Batch predict

This is an experimental feature and is not recommended for use in a production environment. For updates on the progress of the feature or if you want to leave feedback, see the associated GitHub issue.

ML Commons can perform inference on large datasets in an offline asynchronous mode using a model deployed on external model servers. To use the Batch Predict API, you must provide the model_id for an externally hosted model. Amazon SageMaker, Cohere, and OpenAI are currently the only verified external servers that support this API.

For information about user access for this API, see Model access control considerations.

For information about externally hosted models, see Connecting to externally hosted models.

For instructions on how to set up batch inference, see the connector blueprints for the supported platforms.

Path and HTTP methods

POST /_plugins/_ml/models/<model_id>/_batch_predict

Prerequisites

Before using the Batch Predict API, you need to create a connector to the externally hosted model. For each action, specify the action_type parameter that describes the action:

  • batch_predict: Runs the batch predict operation.
  • batch_predict_status: Checks the batch predict operation status.
  • cancel_batch_predict: Cancels the batch predict operation.

For example, to create a connector to an OpenAI text-embedding-ada-002 model, send the following request. The cancel_batch_predict action is optional and supports canceling the batch job running on OpenAI:

POST /_plugins/_ml/connectors/_create
{
  "name": "OpenAI Embedding model",
  "description": "OpenAI embedding model for testing offline batch",
  "version": "1",
  "protocol": "http",
  "parameters": {
    "model": "text-embedding-ada-002",
    "input_file_id": "<your input file id in OpenAI>",
    "endpoint": "/v1/embeddings"
  },
  "credential": {
    "openAI_key": "<your openAI key>"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/embeddings",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      },
      "request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
      "pre_process_function": "connector.pre_process.openai.embedding",
      "post_process_function": "connector.post_process.openai.embedding"
    },
    {
      "action_type": "batch_predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/batches",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      },
      "request_body": "{ \"input_file_id\": \"${parameters.input_file_id}\", \"endpoint\": \"${parameters.endpoint}\", \"completion_window\": \"24h\" }"
    },
    {
      "action_type": "batch_predict_status",
      "method": "GET",
      "url": "https://api.openai.com/v1/batches/${parameters.id}",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      }
    },
    {
      "action_type": "cancel_batch_predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/batches/${parameters.id}/cancel",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      }
    }
  ]
}


The response contains a connector ID that you’ll use in the next steps:

{
  "connector_id": "XU5UiokBpXT9icfOM0vt"
}
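
For reference, the input_file_id parameter in the connector points to a batch input file that you have already uploaded to OpenAI using its Files API. The file format is defined by the OpenAI Batch API, not by ML Commons. As a minimal illustration (the custom_id values and input strings are placeholders), each line of the .jsonl input file is a standalone embedding request:

{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "text-embedding-ada-002", "input": "What is an OpenSearch index?"}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "text-embedding-ada-002", "input": "What is a k-NN search?"}}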

Next, register an externally hosted model and provide the connector ID of the created connector:

POST /_plugins/_ml/models/_register?deploy=true
{
  "name": "OpenAI model for realtime embedding and offline batch inference",
  "function_name": "remote",
  "description": "OpenAI text embedding model",
  "connector_id": "XU5UiokBpXT9icfOM0vt"
}


The response contains the task ID for the register operation:

{
  "task_id": "rMormY8B8aiZvtEZIO_j",
  "status": "CREATED",
  "model_id": "lyjxwZABNrAVdFa9zrcZ"
}

To check the status of the operation, provide the task ID to the Tasks API. Once the registration is complete, the task state changes to COMPLETED.
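
For example, the following request checks the registration task returned in the preceding response:

GET /_plugins/_ml/tasks/rMormY8B8aiZvtEZIO_j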

Example request

Once you have completed the prerequisite steps, you can call the Batch Predict API. The parameters in the batch predict request override those defined in the connector:

POST /_plugins/_ml/models/lyjxwZABNrAVdFa9zrcZ/_batch_predict
{
  "parameters": {
    "model": "text-embedding-3-large"
  }
}


Example response

The response contains the task ID for the batch predict operation:

{
  "task_id": "KYZSv5EBqL2d0mFvs80C",
  "status": "CREATED"
}

To check the status of the batch predict job, provide the task ID to the Tasks API. You can find the job details in the remote_job field in the task. Once the prediction is complete, the task state changes to COMPLETED.

Example request

GET /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C


Example response

The response contains the batch predict operation details in the remote_job field:

{
  "model_id": "JYZRv5EBqL2d0mFvKs1E",
  "task_type": "BATCH_PREDICTION",
  "function_name": "REMOTE",
  "state": "RUNNING",
  "input_type": "REMOTE",
  "worker_node": [
    "Ee5OCIq0RAy05hqQsNI1rg"
  ],
  "create_time": 1725491751455,
  "last_update_time": 1725491751455,
  "is_async": false,
  "remote_job": {
    "cancelled_at": null,
    "metadata": null,
    "request_counts": {
      "total": 3,
      "completed": 3,
      "failed": 0
    },
    "input_file_id": "file-XXXXXXXXXXXX",
    "output_file_id": "file-XXXXXXXXXXXXX",
    "error_file_id": null,
    "created_at": 1725491753,
    "in_progress_at": 1725491753,
    "expired_at": null,
    "finalizing_at": 1725491757,
    "completed_at": null,
    "endpoint": "/v1/embeddings",
    "expires_at": 1725578153,
    "cancelling_at": null,
    "completion_window": "24h",
    "id": "batch_XXXXXXXXXXXXXXX",
    "failed_at": null,
    "errors": null,
    "object": "batch",
    "status": "in_progress"
  }
}

For the definition of each field in the result, see OpenAI Batch API. Once the batch inference is complete, you can download the output by calling the OpenAI Files API and providing the file ID specified in the output_file_id field of the response.
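
For example, assuming the output_file_id value shown in the preceding response, a request similar to the following downloads the batch output directly from OpenAI (note that this request is sent to the OpenAI API, not to your OpenSearch cluster):

curl https://api.openai.com/v1/files/file-XXXXXXXXXXXXX/content \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -o batch_output.jsonl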

Canceling a batch predict job

You can also cancel a batch predict operation running on the remote platform by using the task ID returned by the batch predict request. To enable this capability, include an action with the action_type set to cancel_batch_predict in the connector configuration when you create the connector.

Example request

POST /_plugins/_ml/tasks/KYZSv5EBqL2d0mFvs80C/_cancel_batch

copy

Example response

{
  "status": "OK"
}
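
To confirm that the job was canceled on the remote platform, you can call the Tasks API again with the same task ID; the status in the remote_job field, which mirrors the OpenAI batch object, should move to cancelling and then cancelled.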