Hello, my goal is to fine-tune the Gemini models 1.5 Pro and 1.5 Flash (and perhaps Gemini 1.0 Pro Vision, if possible) to improve their output when scoring the similarity between two .png images. I am not experienced in this space, but with Claude's help I have generated the Python scripts below.
[Image: tactile map template]
[Image: hand drawing of tactile map]
For each training example, the input is TWO .png images and the output is a similarity score from 1 to 10. The inputs reach the tuning step through an unusual route: I hand-scored roughly 1,000 image comparisons in MATLAB, which stores the file paths of each hand drawing and its template as the input, and the similarity score as the output, in a .mat file. The Python scripts below convert the contents of that .mat file into JSON intended to match the input structure the Gemini tuning expects.
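For reference, each record the conversion script writes into training_data.json looks like this (base64 strings truncated here, and the score value is just an example):

{"hand_drawing": "iVBORw0KGgoAAAANSUhEUg...", "template": "iVBORw0KGgoAAAANSUhEUg...", "similarity_score": 7}

I don't actually know whether this flat structure is what the tuning API wants, so that could be part of the problem.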
I am running into problems when running it, and I am not sure what the issue is. If anyone can spot what is wrong and suggest potential solutions, please let me know! Thanks in advance.
Below are the two Python scripts: the main one directly below, and the secondary script (training_script.py) at the end.
import os
import scipy.io
import numpy as np
import base64
import re
import time
import traceback
import logging
import json
from google.cloud import aiplatform
from google.cloud import storage
from google.auth import default
# Set up logging
logging.basicConfig(filename='tactile_map_tuning_log.txt', level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')
def init_vertex_ai():
    try:
        aiplatform.init(project='image-similarity-435221',
                        location='us-central1',
                        staging_bucket='gs://human_rated_tactile_map_similarity_scores')
        logging.info("Vertex AI initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing Vertex AI: {str(e)}")
        raise
def create_valid_id(base_id):
    # Keep only lowercase letters, digits, and hyphens
    valid_id = re.sub(r'[^a-z0-9-]', '', base_id.lower())
    # Guard against an empty result before indexing valid_id[0]
    if not valid_id or not valid_id[0].isalpha():
        valid_id = 'model-' + valid_id
    return valid_id[:40]
def truncate_text(text, max_length=32000):
    return text[:max_length]
def image_to_base64(image_path):
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        logging.error(f"File not found: {image_path}")
        return None
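def log_encoded_size(image_path):
    # Debug helper I added while testing (not part of the original pipeline):
    # base64 inflates each image by roughly 4/3, and every training example
    # embeds two images, so I log how large the encoded strings get.
    encoded = image_to_base64(image_path)
    if encoded is not None:
        logging.debug(f"{image_path}: {os.path.getsize(image_path)} bytes on disk "
                      f"-> {len(encoded)} base64 characters")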
def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    gcs_path = f"gs://{bucket_name}/{destination_blob_name}"
    logging.info(f"File {source_file_name} uploaded to {gcs_path}.")
    return gcs_path
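def create_uri_training_example(item, bucket_name='human_rated_tactile_map_similarity_scores'):
    # Hypothetical alternative to create_training_example (below), in case
    # inlining base64 into the JSON is the problem: stage both .png files in
    # Cloud Storage via upload_to_gcs and keep gs:// URIs instead of encoded
    # bytes. Not wired into main() yet; the images/ prefix is my own choice.
    raw_file = item['rawFileName'].item()
    template_file = item['templateFileName'].item()
    return {
        'hand_drawing_uri': upload_to_gcs(bucket_name, raw_file, f"images/{os.path.basename(raw_file)}"),
        'template_uri': upload_to_gcs(bucket_name, template_file, f"images/{os.path.basename(template_file)}"),
    }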
def create_training_example(item):
    try:
        score_full_map = item['scoreFullMap'].item()
        score_shortest_path = item['scoreShortestPath'].item()
        # .item() can still return a nested array; unwrap once more if so
        if isinstance(score_full_map, np.ndarray):
            score_full_map = score_full_map.item() if score_full_map.size > 0 else np.nan
        if isinstance(score_shortest_path, np.ndarray):
            score_shortest_path = score_shortest_path.item() if score_shortest_path.size > 0 else np.nan
        # Prefer the full-map score; fall back to the shortest-path score
        score = score_full_map if not np.isnan(score_full_map) else score_shortest_path
        if np.isnan(score):
            return None
        raw_file = item['rawFileName'].item()
        template_file = item['templateFileName'].item()
        raw_image_base64 = image_to_base64(raw_file)
        template_image_base64 = image_to_base64(template_file)
        if raw_image_base64 is None or template_image_base64 is None:
            return None
        return {
            'hand_drawing': raw_image_base64,
            'template': template_image_base64,
            'similarity_score': int(score)
        }
    except Exception as e:
        logging.error(f"Error processing item: {str(e)}")
        logging.error(f"Item content: {item}")
        return None
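def inspect_mat_structure(mat_file_path):
    # Debug helper: reload the .mat with squeeze_me so the 1x1 numpy wrappers
    # collapse, and struct_as_record=False for attribute-style access, then
    # log the first element's fields (names match my file) as a sanity check.
    mat = scipy.io.loadmat(mat_file_path, squeeze_me=True, struct_as_record=False)
    first = np.atleast_1d(mat['data']).flat[0]
    logging.debug(f"rawFileName={first.rawFileName}, "
                  f"templateFileName={first.templateFileName}, "
                  f"scoreFullMap={first.scoreFullMap}, "
                  f"scoreShortestPath={first.scoreShortestPath}")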
def save_training_data(training_data, filename='training_data.json'):
    with open(filename, 'w') as f:
        json.dump(training_data, f)
    logging.info(f"Training data saved to {filename}")
def main():
    try:
        init_vertex_ai()
        mat_file_path = r"L:\NEI\Navigate\Training\Recordings\fMRI HandScoring\done_fMRI_scores_Full_Group_ML"
        mat_data = scipy.io.loadmat(mat_file_path)
        data = mat_data['data']
        logging.info(f"Type of data: {type(data)}")
        logging.info(f"Shape of data: {data.shape}")
        logging.info(f"Data dtype: {data.dtype}")
        training_data = []
        for row in range(data.shape[0]):
            for col in range(data.shape[1]):
                item = data[row, col]
                example = create_training_example(item)
                if example:
                    training_data.append(example)
                    if len(training_data) % 10 == 0:
                        logging.info(f"Processed {len(training_data)} valid examples")
        logging.info(f"Total training examples: {len(training_data)}")
        if len(training_data) > 0:
            logging.info("\nPreparing to start model tuning...")
            try:
                current_dir = os.path.dirname(os.path.abspath(__file__))
                local_file_name = os.path.join(current_dir, 'training_data.json')
                save_training_data(training_data, local_file_name)
                bucket_name = 'human_rated_tactile_map_similarity_scores'
                destination_blob_name = 'training_data.json'
                gcs_data_path = upload_to_gcs(bucket_name, local_file_name, destination_blob_name)
                # Create custom training job
                job = aiplatform.CustomTrainingJob(
                    display_name=f"tactile_map_tuning_{int(time.time())}",
                    script_path="training_script.py",
                    container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-8:latest",  # Changed to CPU version
                    requirements=["tensorflow", "transformers", "google-cloud-aiplatform", "google-generativeai"],
                    model_serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",  # Changed to CPU version
                    staging_bucket=f"gs://{bucket_name}"
                )
                logging.info("Training job defined")
                logging.info("Starting model training...")
                model = job.run(
                    model_display_name=f"TactileMap_Similarity_Model_{int(time.time())}",
                    args=[
                        "--train-data", gcs_data_path,
                        "--model-dir", f"gs://{bucket_name}/model_output",
                        "--epochs", "20",
                        "--batch-size", "16",
                        "--learning-rate", "0.001",
                        "--base-model", "gemini-1.5-pro-002"
                    ],
                    replica_count=1,
                    machine_type="n1-standard-4",
                    # Removed accelerator_type and accelerator_count
                )
                logging.info(f"Model training completed. Model resource name: {model.resource_name}")
            except Exception as e:
                logging.error(f"Error during model tuning: {str(e)}")
                logging.error("Full error details:")
                traceback.print_exc()
        else:
            logging.warning("No valid training examples were generated. Please check your data and file paths.")
    except Exception as e:
        logging.error(f"An unexpected error occurred: {str(e)}")
        logging.error("Full error details:")
        traceback.print_exc()

if __name__ == "__main__":
    main()
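One more thing I tried to verify: from what I can tell in the Vertex AI docs, supervised tuning of the Gemini 1.5 models does not go through CustomTrainingJob at all; it seems to expect a JSONL dataset in Cloud Storage, with images referenced by gs:// URI rather than inlined, and the vertexai.tuning.sft API. This is my untested understanding, so treat the snippet below as a sketch (the gs:// paths and prompt wording are placeholders of mine):

import json
import vertexai
from vertexai.tuning import sft

def write_jsonl_record(f, raw_uri, template_uri, score):
    # One JSONL line: a user turn holding the two images plus an instruction,
    # and a model turn holding the expected score as text.
    record = {
        "contents": [
            {"role": "user", "parts": [
                {"fileData": {"mimeType": "image/png", "fileUri": raw_uri}},
                {"fileData": {"mimeType": "image/png", "fileUri": template_uri}},
                {"text": "Rate the similarity of the hand drawing (first image) "
                         "to the template (second image) on a scale of 1 to 10."},
            ]},
            {"role": "model", "parts": [{"text": str(score)}]},
        ]
    }
    f.write(json.dumps(record) + "\n")

vertexai.init(project='image-similarity-435221', location='us-central1')
tuning_job = sft.train(
    source_model="gemini-1.5-pro-002",
    train_dataset="gs://human_rated_tactile_map_similarity_scores/training_data.jsonl",
)

Is that roughly the right direction, or can this be made to work through a custom training job?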
And here is the secondary script, training_script.py:

import argparse
import os
import json
import google.generativeai as genai
from google.cloud import storage

def load_data(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)

def download_blob(bucket_name, source_blob_name, destination_file_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)
def main(args):
    # Download the training data
    local_file = "/tmp/training_data.json"
    bucket_name = args.train_data.split("/")[2]
    source_blob_name = "/".join(args.train_data.split("/")[3:])
    download_blob(bucket_name, source_blob_name, local_file)
    # Load the training data
    training_data = load_data(local_file)
    # Configure the GenerativeAI library
    genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
    # Create the fine-tuning job
    model = genai.GenerativeModel(args.base_model)
    tuning_job = model.create_tuning_job(
        training_data=[
            {
                "input_text": f"Compare the similarity between two images. Image 1 (hand drawing): {item['hand_drawing']} Image 2 (template): {item['template']}",
                "output_text": str(item['similarity_score'])
            } for item in training_data
        ],
        options={
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "epochs": args.epochs
        }
    )
    print(f"Fine-tuning job created: {tuning_job.name}")
    # Wait for the job to complete
    tuning_job = tuning_job.result()
    if tuning_job.state == "SUCCEEDED":
        print(f"Fine-tuning completed. Model name: {tuning_job.tuned_model.name}")
        # Save model details
        with open(os.path.join(args.model_dir, 'model_details.txt'), 'w') as f:
            f.write(f"Model name: {tuning_job.tuned_model.name}\n")
            f.write(f"Model display name: {args.base_model}-fine-tuned\n")
    else:
        print(f"Fine-tuning job failed or was cancelled: {tuning_job.state}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-data', type=str, help='GCS path to training data')
    parser.add_argument('--model-dir', type=str, help='Directory to save model details')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size')
    parser.add_argument('--learning-rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--base-model', type=str, default='gemini-1.5-pro-002', help='Base model to fine-tune')
    args = parser.parse_args()
    main(args)
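Lastly, while debugging the secondary script I could not find create_tuning_job anywhere in the google.generativeai reference; the closest call I can find is genai.create_tuned_model, and it appears to accept only text records with text_input/output keys (no images), which may be the real blocker. An untested sketch of what I think that call looks like:

import google.generativeai as genai

# Untested: the tuning entry point I can actually find in the
# google.generativeai docs. It seems to be text-only, and the docs list
# gemini-1.0-pro-001 (not the 1.5 models) as the tunable base model.
operation = genai.create_tuned_model(
    source_model="models/gemini-1.0-pro-001",
    training_data=[
        {"text_input": "example input", "output": "7"},
    ],
    epoch_count=20,
    batch_size=16,
    learning_rate=0.001,
)
tuned = operation.result()  # blocks until the tuning job finishes
print(tuned.name)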