Reranking search results using the MS MARCO cross-encoder model

A reranking pipeline can rerank search results, providing a relevance score for each document in the search results with respect to the search query. The relevance score is calculated by a cross-encoder model.

This tutorial illustrates how to use the Hugging Face ms-marco-MiniLM-L-6-v2 model in a reranking pipeline.

Replace the placeholders beginning with the prefix your_ with your own values.

Prerequisite

Before you start, deploy the model on Amazon SageMaker. For better performance, use a GPU instance; the following example uses the CPU instance type ml.m5.xlarge.

Run the following code to deploy the model on Amazon SageMaker:

  import sagemaker
  from sagemaker.huggingface import HuggingFaceModel

  # IAM role that SageMaker uses to create and run the endpoint
  role = sagemaker.get_execution_role()

  # Hugging Face Hub model ID and task
  hub = {
      'HF_MODEL_ID': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
      'HF_TASK': 'text-classification'
  }

  # Define the Hugging Face model container
  huggingface_model = HuggingFaceModel(
      transformers_version='4.37.0',
      pytorch_version='2.1.0',
      py_version='py310',
      env=hub,
      role=role,
  )

  # Deploy the model to a SageMaker inference endpoint
  predictor = huggingface_model.deploy(
      initial_instance_count=1,      # number of instances
      instance_type='ml.m5.xlarge'   # EC2 instance type
  )


Note the model inference endpoint; you’ll use it to create a connector in the next step.
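The connector in the next step takes the endpoint as a full invocation URL. The following is a minimal sketch for composing that URL, assuming the standard SageMaker runtime URL pattern for commercial AWS Regions. The function name and the example endpoint name are hypothetical; with the deployment code above, the endpoint name would be available as predictor.endpoint_name.

```python
def sagemaker_invocation_url(region: str, endpoint_name: str) -> str:
    """Compose the SageMaker runtime invocation URL for an endpoint.

    Assumes the standard URL pattern used in commercial AWS Regions.
    """
    return (
        f"https://runtime.sagemaker.{region}.amazonaws.com"
        f"/endpoints/{endpoint_name}/invocations"
    )


# Hypothetical Region and endpoint name, for illustration only.
url = sagemaker_invocation_url("us-west-2", "my-cross-encoder-endpoint")
print(url)
```

The resulting string is the kind of value that replaces the your_sagemaker_model_inference_endpoint_created_in_last_step placeholder.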

Step 1: Create a connector and register the model

First, create a connector for the model, providing the inference endpoint and your AWS credentials:

  POST /_plugins/_ml/connectors/_create
  {
    "name": "Sagemaker cross-encoder model",
    "description": "Test connector for Sagemaker cross-encoder model",
    "version": 1,
    "protocol": "aws_sigv4",
    "credential": {
      "access_key": "your_access_key",
      "secret_key": "your_secret_key",
      "session_token": "your_session_token"
    },
    "parameters": {
      "region": "your_sagemaker_model_region_like_us-west-2",
      "service_name": "sagemaker"
    },
    "actions": [
      {
        "action_type": "predict",
        "method": "POST",
        "url": "your_sagemaker_model_inference_endpoint_created_in_last_step",
        "headers": {
          "content-type": "application/json"
        },
        "request_body": "{ \"inputs\": ${parameters.inputs} }",
        "pre_process_function": "\n String escape(def input) { \n if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n if (input.contains('\r')) {\n input = input.replace('\r', '\\\\r');\n }\n if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n return input;\n }\n\n String query = params.query_text;\n StringBuilder builder = new StringBuilder('[');\n \n for (int i=0; i<params.text_docs.length; i ++) {\n builder.append('{\"text\":\"');\n builder.append(escape(query));\n builder.append('\", \"text_pair\":\"');\n builder.append(escape(params.text_docs[i]));\n builder.append('\"}');\n if (i<params.text_docs.length - 1) {\n builder.append(',');\n }\n }\n builder.append(']');\n \n def parameters = '{ \"inputs\": ' + builder + ' }';\n return '{\"parameters\": ' + parameters + '}';\n ",
        "post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n \n def resultBuilder = new StringBuilder('[ ');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
      }
    ]
  }


Next, use the connector ID from the response to register and deploy the model:

  POST /_plugins/_ml/models/_register?deploy=true
  {
    "name": "Sagemaker Cross-Encoder model",
    "function_name": "remote",
    "description": "test rerank model",
    "connector_id": "your_connector_id"
  }


Note the model ID in the response; you’ll use it in the following steps.

To test the model, call the Predict API:

  POST _plugins/_ml/models/your_model_id/_predict
  {
    "parameters": {
      "inputs": [
        {
          "text": "I like you",
          "text_pair": "I hate you"
        },
        {
          "text": "I like you",
          "text_pair": "I love you"
        }
      ]
    }
  }


Each item in the inputs array is an object with two fields: text, which contains the query, and text_pair, which contains the document to be scored against the query.

Alternatively, you can test the model as follows:

  POST _plugins/_ml/_predict/text_similarity/your_model_id
  {
    "query_text": "I like you",
    "text_docs": ["I hate you", "I love you"]
  }


The connector pre_process_function transforms the input into the format required by the inputs parameter shown in the previous Predict API request.
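The following Python sketch mirrors what the Painless pre-processing script does, as read from the connector definition: it pairs the query with each document and wraps the pairs in the inputs parameter. It is an illustration only, not code that OpenSearch runs, and the function name build_inputs is hypothetical.

```python
import json


def build_inputs(query_text: str, text_docs: list[str]) -> dict:
    """Pair the query with every document, one {text, text_pair} object each."""
    inputs = [{"text": query_text, "text_pair": doc} for doc in text_docs]
    return {"parameters": {"inputs": inputs}}


payload = build_inputs("I like you", ["I hate you", "I love you"])

# json.dumps takes care of escaping quotes, backslashes, and control
# characters, which the Painless escape() helper does by hand.
print(json.dumps(payload, indent=2))
```

The printed payload has the same shape as the body of the first Predict API request.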

By default, the SageMaker model output is in the following format:

  [
    {
      "label": "LABEL_0",
      "score": 0.054037678986787796
    },
    {
      "label": "LABEL_0",
      "score": 0.5877784490585327
    }
  ]

The connector post_process_function transforms the model output into the following format, which can be interpreted by the rerank processor:

  {
    "inference_results": [
      {
        "output": [
          {
            "name": "similarity",
            "data_type": "FLOAT32",
            "shape": [
              1
            ],
            "data": [
              0.054037678986787796
            ]
          },
          {
            "name": "similarity",
            "data_type": "FLOAT32",
            "shape": [
              1
            ],
            "data": [
              0.5877784490585327
            ]
          }
        ],
        "status_code": 200
      }
    ]
  }

The response contains two similarity outputs, one for each input document. The data array of each output holds the relevance score of that document with respect to the query. The similarity outputs are provided in the order of the input documents: the first similarity output pertains to the first document.
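The mapping from the raw SageMaker output to the similarity outputs can be sketched in Python as follows. This mirrors the loop in the connector's post-processing script and is illustrative only. The function name is hypothetical, and the sketch assumes that the surrounding inference_results wrapper and status_code are added by OpenSearch rather than by the script, which is inferred from the script body.

```python
def to_similarity_outputs(model_output: list[dict]) -> list[dict]:
    """Convert each {label, score} item into a one-element similarity tensor."""
    return [
        {
            "name": "similarity",
            "data_type": "FLOAT32",
            "shape": [1],
            "data": [item["score"]],
        }
        for item in model_output
    ]


# Raw model output taken from the example earlier in this tutorial.
raw = [
    {"label": "LABEL_0", "score": 0.054037678986787796},
    {"label": "LABEL_0", "score": 0.5877784490585327},
]
outputs = to_similarity_outputs(raw)
print(outputs)
```

Because the list comprehension preserves order, the score at position i still belongs to the document at position i.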

Step 2: Configure a reranking pipeline

Follow these steps to configure a reranking pipeline.

Step 2.1: Ingest test data

Send a bulk request to ingest test data:

  POST _bulk
  { "index": { "_index": "my-test-data" } }
  { "passage_text" : "Carson City is the capital city of the American state of Nevada." }
  { "index": { "_index": "my-test-data" } }
  { "passage_text" : "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan." }
  { "index": { "_index": "my-test-data" } }
  { "passage_text" : "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district." }
  { "index": { "_index": "my-test-data" } }
  { "passage_text" : "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." }

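When ingesting more than a handful of passages, the bulk body can be generated rather than written by hand. The following is a minimal sketch that builds the newline-delimited JSON body for the request above. The function name build_bulk_body is hypothetical, and sending the body to the cluster is left out.

```python
import json


def build_bulk_body(index_name: str, passages: list[str]) -> str:
    """Build an NDJSON bulk body: one action line and one document line per passage."""
    lines = []
    for passage in passages:
        lines.append(json.dumps({"index": {"_index": index_name}}))
        lines.append(json.dumps({"passage_text": passage}))
    # The Bulk API requires the body to end with a newline character.
    return "\n".join(lines) + "\n"


body = build_bulk_body(
    "my-test-data",
    [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of "
        "islands in the Pacific Ocean. Its capital is Saipan.",
    ],
)
print(body)
```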

Step 2.2: Create a reranking pipeline

Create a reranking pipeline using the MS MARCO cross-encoder model:

  PUT /_search/pipeline/rerank_pipeline_sagemaker
  {
    "description": "Pipeline for reranking with Sagemaker cross-encoder model",
    "response_processors": [
      {
        "rerank": {
          "ml_opensearch": {
            "model_id": "your_model_id_created_in_step1"
          },
          "context": {
            "document_fields": ["passage_text"]
          }
        }
      }
    ]
  }


If you provide multiple field names in document_fields, then the values of all fields are first concatenated, after which reranking is performed.
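As an illustration of a pipeline definition with more than one document field, the following sketch builds the request body in Python. The function name and the title field are hypothetical (the test index in this tutorial contains only passage_text), and the sketch only constructs the JSON; it does not send it to the cluster.

```python
import json


def build_rerank_pipeline(model_id: str, document_fields: list[str]) -> dict:
    """Build a rerank search pipeline body for the given model and fields."""
    if not document_fields:
        raise ValueError("document_fields must contain at least one field name")
    return {
        "description": "Pipeline for reranking with Sagemaker cross-encoder model",
        "response_processors": [
            {
                "rerank": {
                    "ml_opensearch": {"model_id": model_id},
                    "context": {"document_fields": list(document_fields)},
                }
            }
        ],
    }


# "title" is a hypothetical second field used only for illustration.
pipeline = build_rerank_pipeline("your_model_id", ["title", "passage_text"])
print(json.dumps(pipeline, indent=2))
```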

Step 2.3: Test the reranking

To limit the number of returned results, you can specify the size parameter. For example, set "size": 4 to return the top four documents:

  GET my-test-data/_search?search_pipeline=rerank_pipeline_sagemaker
  {
    "query": {
      "match_all": {}
    },
    "size": 4,
    "ext": {
      "rerank": {
        "query_context": {
          "query_text": "What is the capital of the United States?"
        }
      }
    }
  }


The response contains the four most relevant documents:

  {
    "took": 3,
    "timed_out": false,
    "_shards": {
      "total": 1,
      "successful": 1,
      "skipped": 0,
      "failed": 0
    },
    "hits": {
      "total": {
        "value": 4,
        "relation": "eq"
      },
      "max_score": 0.9997217,
      "hits": [
        {
          "_index": "my-test-data",
          "_id": "U0xye5AB9ZeWZdmDjWZn",
          "_score": 0.9997217,
          "_source": {
            "passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "VExye5AB9ZeWZdmDjWZn",
          "_score": 0.55655104,
          "_source": {
            "passage_text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "UUxye5AB9ZeWZdmDjWZn",
          "_score": 0.115356825,
          "_source": {
            "passage_text": "Carson City is the capital city of the American state of Nevada."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "Ukxye5AB9ZeWZdmDjWZn",
          "_score": 0.00021142483,
          "_source": {
            "passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
          }
        }
      ]
    },
    "profile": {
      "shards": []
    }
  }

To compare these results to results without reranking, run the search without a reranking pipeline:

  GET my-test-data/_search
  {
    "query": {
      "match_all": {}
    },
    "ext": {
      "rerank": {
        "query_context": {
          "query_text": "What is the capital of the United States?"
        }
      }
    }
  }


The first document in the response pertains to Carson City, which is not the capital of the United States:

  {
    "took": 1,
    "timed_out": false,
    "_shards": {
      "total": 1,
      "successful": 1,
      "skipped": 0,
      "failed": 0
    },
    "hits": {
      "total": {
        "value": 4,
        "relation": "eq"
      },
      "max_score": 1,
      "hits": [
        {
          "_index": "my-test-data",
          "_id": "UUxye5AB9ZeWZdmDjWZn",
          "_score": 1,
          "_source": {
            "passage_text": "Carson City is the capital city of the American state of Nevada."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "Ukxye5AB9ZeWZdmDjWZn",
          "_score": 1,
          "_source": {
            "passage_text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "U0xye5AB9ZeWZdmDjWZn",
          "_score": 1,
          "_source": {
            "passage_text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
          }
        },
        {
          "_index": "my-test-data",
          "_id": "VExye5AB9ZeWZdmDjWZn",
          "_score": 1,
          "_source": {
            "passage_text": "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
          }
        }
      ]
    }
  }
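To make the comparison concrete, the following sketch computes how far each document moved between the two responses. The helper names are hypothetical, and the responses are trimmed to the document IDs shown above; in practice the full response dictionaries from a client would be passed in.

```python
def ranked_ids(response: dict) -> list[str]:
    """Return document IDs in the order they appear in a search response."""
    return [hit["_id"] for hit in response["hits"]["hits"]]


def rank_changes(before: dict, after: dict) -> dict[str, int]:
    """Map each document ID to the number of positions it moved up (negative = down)."""
    before_pos = {doc_id: i for i, doc_id in enumerate(ranked_ids(before))}
    return {
        doc_id: before_pos[doc_id] - i
        for i, doc_id in enumerate(ranked_ids(after))
        if doc_id in before_pos
    }


def _response(ids: list[str]) -> dict:
    """Build a trimmed response containing only the fields used here."""
    return {"hits": {"hits": [{"_id": doc_id} for doc_id in ids]}}


# Document order without reranking and with reranking, from the responses above.
without_rerank = _response(
    ["UUxye5AB9ZeWZdmDjWZn", "Ukxye5AB9ZeWZdmDjWZn",
     "U0xye5AB9ZeWZdmDjWZn", "VExye5AB9ZeWZdmDjWZn"]
)
with_rerank = _response(
    ["U0xye5AB9ZeWZdmDjWZn", "VExye5AB9ZeWZdmDjWZn",
     "UUxye5AB9ZeWZdmDjWZn", "Ukxye5AB9ZeWZdmDjWZn"]
)

changes = rank_changes(without_rerank, with_rerank)
print(changes)
```

In this data, the Washington, D.C. passage (ID U0xye5AB9ZeWZdmDjWZn) moves up two positions to the top of the results after reranking.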