
Can't run Batch Prediction job with TUNED gemini model

I can run a batch prediction job just fine with the base gemini-2.0-flash-001 model.

However, when I replace the model parameter with a custom tuned gemini-2.0-flash model from my project, every request fails:

{"status":"Internal error occurred. Failed to get generateContentResponse: {\"error\": {\"code\": 404, \"message\": \"Endpoint `projects/rd820bc50f78008f2-tp/locations/us-central1/endpoints/llm-bp-endpoint-job-7303480853353463808` not found.\", \"status\": \"NOT_FOUND\"}}","processed_time":"2025-04-30T15:25:13.077+00:00",

Here is my code:

#!/usr/bin/env python3
import os
import sys
import time
import argparse
import datetime
from google import genai
from google.genai.types import CreateBatchJobConfig, JobState, HttpOptions

# Constants
PROJECT_ID = "564504826453"
LOCATION = "us-central1"
BUCKET_NAME = "aerial-bucket"
MODEL_ID = "projects/564504826453/locations/us-central1/models/9122741434145832960@1"
#MODEL_ID = "gemini-2.0-flash-001"
DEFAULT_INPUT_JSONL = "dataset/batch_prediction_requests.jsonl"
DEFAULT_OUTPUT_PATH = f"gs://{BUCKET_NAME}/batch_inference/results"

def log_with_timestamp(message):
    """Add timestamp to log messages"""
    timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]
    print(f"[{timestamp}] {message}")

def run_batch_prediction(input_jsonl_gcs, output_gcs_path):
    """Start and monitor a batch prediction job"""
    # Initialize client with proper authentication for Vertex AI
    client = genai.Client(
        vertexai=True,
        project=PROJECT_ID,
        location=LOCATION,
        http_options=HttpOptions(api_version="v1")
    )
    
    log_with_timestamp(f"Starting batch prediction job using:")
    log_with_timestamp(f"  Model: {MODEL_ID}")
    log_with_timestamp(f"  Input JSONL: {input_jsonl_gcs}")
    log_with_timestamp(f"  Output path: {output_gcs_path}")
    
    # Create batch job
    job = client.batches.create(
        model=MODEL_ID,
        src=input_jsonl_gcs,
        config=CreateBatchJobConfig(dest=output_gcs_path),
    )
    
    job_name = job.name
    log_with_timestamp(f"Job created with name: {job_name}")
    log_with_timestamp(f"Initial job state: {job.state}")
    
    # Define completed states
    completed_states = {
        JobState.JOB_STATE_SUCCEEDED,
        JobState.JOB_STATE_FAILED,
        JobState.JOB_STATE_CANCELLED,
        JobState.JOB_STATE_PAUSED,
    }
    
    # Monitor job progress
    start_time = time.time()
    last_state = job.state
    
    try:
        while job.state not in completed_states:
            time.sleep(30)  # Check status every 30 seconds
            job = client.batches.get(name=job_name)
            
            # Only log if state changed
            if job.state != last_state:
                elapsed_time = time.time() - start_time
                hours, remainder = divmod(elapsed_time, 3600)
                minutes, seconds = divmod(remainder, 60)
                
                log_with_timestamp(f"Job state changed: {last_state} → {job.state} " 
                                  f"(Elapsed: {int(hours)}h {int(minutes)}m {int(seconds)}s)")
                last_state = job.state
            
            # Log progress roughly every 5 minutes regardless of state change
            elapsed_time = time.time() - start_time
            if elapsed_time % 300 < 30:
                hours, remainder = divmod(elapsed_time, 3600)
                minutes, seconds = divmod(remainder, 60)
                
                log_with_timestamp(f"Job still running: {job.state} "
                                  f"(Elapsed: {int(hours)}h {int(minutes)}m {int(seconds)}s)")
        
        # Job completed
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        
        log_with_timestamp(f"Job completed with state: {job.state}")
        log_with_timestamp(f"Total execution time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
        
        # Show output location
        if job.state == JobState.JOB_STATE_SUCCEEDED:
            log_with_timestamp(f"Results available at: {output_gcs_path}")
            
            # If status is available, show it
            if hasattr(job, 'status') and job.status:
                log_with_timestamp(f"Job status: {job.status}")
            
            return True
        else:
            log_with_timestamp(f"Job did not complete successfully. Final state: {job.state}")
            
            # If error is available, show it
            if hasattr(job, 'error') and job.error:
                log_with_timestamp(f"Error: {job.error}")
            
            return False
            
    except KeyboardInterrupt:
        log_with_timestamp("Job monitoring interrupted. The job will continue running.")
        log_with_timestamp(f"You can check the status later using job name: {job_name}")
        return False
    except Exception as e:
        log_with_timestamp(f"Error monitoring job: {e}")
        log_with_timestamp(f"Job may still be running. Job name: {job_name}")
        return False

def upload_jsonl_to_gcs(local_jsonl_path):
    """Upload the JSONL file to GCS if it's not already there"""
    if local_jsonl_path.startswith("gs://"):
        return local_jsonl_path
    
    # If we need to upload the file to GCS
    from google.cloud import storage
    
    # Initialize storage client
    storage_client = storage.Client(project=PROJECT_ID)
    bucket = storage_client.bucket(BUCKET_NAME)
    
    # Generate GCS path
    filename = os.path.basename(local_jsonl_path)
    blob_name = f"batch_inference/input/{filename}"
    blob = bucket.blob(blob_name)
    
    log_with_timestamp(f"Uploading {local_jsonl_path} to GCS...")
    blob.upload_from_filename(local_jsonl_path)
    
    gcs_uri = f"gs://{BUCKET_NAME}/{blob_name}"
    log_with_timestamp(f"Uploaded to {gcs_uri}")
    
    return gcs_uri

def main():
    parser = argparse.ArgumentParser(description='Run a batch prediction job using Google GenAI API')
    parser.add_argument('--input_jsonl', default=DEFAULT_INPUT_JSONL, 
                        help='Path to the input JSONL file (local or GCS URI)')
    parser.add_argument('--output_gcs', default=DEFAULT_OUTPUT_PATH, 
                        help='GCS path for output results')
    args = parser.parse_args()
    
    # If input is a local file, upload it to GCS first
    if not args.input_jsonl.startswith("gs://"):
        if not os.path.exists(args.input_jsonl):
            log_with_timestamp(f"Error: Input file {args.input_jsonl} not found")
            sys.exit(1)
        
        input_gcs_path = upload_jsonl_to_gcs(args.input_jsonl)
    else:
        input_gcs_path = args.input_jsonl
    
    # Run the batch prediction
    success = run_batch_prediction(input_gcs_path, args.output_gcs)
    
    if success:
        log_with_timestamp("Batch prediction completed successfully")
    else:
        log_with_timestamp("Batch prediction did not complete successfully")
        sys.exit(1)

if __name__ == "__main__":
    main() 
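
For completeness, if I interrupt monitoring, I check on the job later with the same client, using the job name the script logs (the job name below is a placeholder):

from google import genai
from google.genai.types import HttpOptions

client = genai.Client(
    vertexai=True,
    project="564504826453",
    location="us-central1",
    http_options=HttpOptions(api_version="v1"),
)

# Re-fetch the job by the name logged earlier, e.g.
# projects/564504826453/locations/us-central1/batchPredictionJobs/...
job = client.batches.get(name="<JOB_NAME_FROM_LOGS>")
print(job.state)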

I cannot find anywhere whether Batch Prediction is supported with custom tuned models.

I should also mention that my requests are multimodal (image and text inputs).
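
In case the request format matters, this is roughly how I build each line of the input JSONL (the image path and prompt text are placeholders):

import json

# Each JSONL line wraps one generateContent-style request body.
request_line = {
    "request": {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"fileData": {"fileUri": "gs://aerial-bucket/images/example.jpg",
                                  "mimeType": "image/jpeg"}},
                    {"text": "Describe this aerial image."},
                ],
            }
        ]
    }
}

with open("dataset/batch_prediction_requests.jsonl", "a") as f:
    f.write(json.dumps(request_line) + "\n")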

I hope I can get some help.

Thank you 

1 REPLY

Hi @luispl77,

Welcome to Google Cloud Community!

Currently, batch prediction for Gemini supports only the following models:

  • gemini-2.0-flash-lite-001
  • gemini-2.0-flash-001

You can file a feature request for this enhancement. Before filing, please review what to expect when opening an issue. I also recommend keeping an eye on the issue tracker and checking the release notes and documentation for the latest updates. In the meantime, you can still call your tuned model online, as sketched below.
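
Here is a minimal, untested sketch of calling the tuned model online instead of via batch, assuming the tuned model is deployed to an endpoint (ENDPOINT_ID is a placeholder for that endpoint's ID):

from google import genai
from google.genai.types import HttpOptions

client = genai.Client(
    vertexai=True,
    project="564504826453",
    location="us-central1",
    http_options=HttpOptions(api_version="v1"),
)

# For tuned models, the endpoint resource name is passed as the model.
response = client.models.generate_content(
    model="projects/564504826453/locations/us-central1/endpoints/ENDPOINT_ID",
    contents="Describe this aerial image.",
)
print(response.text)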

Was this helpful? If so, please accept this answer as “Solution”. If you need additional assistance, reply here within 2 business days and I’ll be happy to help.