如何在 C#/.NET 中执行向量搜索?

最后更新时间 2024 年 4 月 20 日

目标

了解如何在 .NET 平台上使用 C# 编程语言将 Redis 作为向量数据库进行向量搜索。

解决方案

以下是一个简单示例,演示如何将 Redis 用作向量数据库,并通过 .NET 的 NRedisStack 客户端库对句子进行向量嵌入建模。要开始使用示例,请先学习如何 设置一个 C#/.NET 项目 以使用 Redis 作为向量数据库。

进入项目文件夹并安装以下库,本文讨论的示例需要这些库。

cd vector-test

dotnet add package NRedisStack
dotnet add package Microsoft.ML

现在编辑项目文件夹中的 Program.cs 文件,粘贴以下内容

using NRedisStack;
using NRedisStack.RedisStackCommands;
using NRedisStack.Search;
using NRedisStack.Search.Aggregation;
using NRedisStack.Search.Literals.Enums;
using StackExchange.Redis;
using static NRedisStack.Search.Schema;

using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Transforms.Text;
using System.Text.Json;


namespace Redis.SemanticSearch
{
    /// <summary>
    /// Minimal vector-similarity-search example: creates a RediSearch vector index,
    /// stores three sentences as Redis hashes together with their ML.NET word
    /// embeddings, and runs a KNN query for a test sentence.
    /// </summary>
    public static class VssExample
    {
        // Single endpoint definition so every method talks to the same server.
        private const string RedisEndpoint = "localhost:6379";

        private const string IndexName = "vector_idx";

        // StackExchange.Redis multiplexers are meant to be shared and reused;
        // one lazily-created instance replaces the per-method connections that
        // were previously opened and never disposed.
        private static readonly Lazy<ConnectionMultiplexer> s_redis =
            new Lazy<ConnectionMultiplexer>(() => ConnectionMultiplexer.Connect(RedisEndpoint));

        // Building the embedding pipeline loads the pretrained model, so it is
        // built once and shared between indexing and querying.
        private static readonly Lazy<PredictionEngine<TextData, TransformedTextData>> s_predictionEngine =
            new Lazy<PredictionEngine<TextData, TransformedTextData>>(GetPredictionEngine);

        static void Main()
        {
            try
            {
                CreateIndex();
                ModelSentences();
                TestSentence();
            }
            finally
            {
                // Close the shared connection only if something actually opened it.
                if (s_redis.IsValueCreated)
                {
                    s_redis.Value.Dispose();
                }
            }
        }

        /// <summary>
        /// Creates the <c>vector_idx</c> index over hashes whose keys start with
        /// <c>doc:</c>. An already-existing index is left in place so the example
        /// can be run more than once.
        /// </summary>
        private static void CreateIndex()
        {
            IDatabase db = s_redis.Value.GetDatabase();

            var schema = new Schema()
                .AddTextField(new FieldName("content", "content"))
                .AddTagField(new FieldName("genre", "genre"))
                .AddVectorField("embedding", VectorField.VectorAlgo.HNSW,
                    new Dictionary<string, object>()
                    {
                        ["TYPE"] = "FLOAT32",
                        ["DIM"] = "150",
                        ["DISTANCE_METRIC"] = "L2"
                    }
                );

            SearchCommands ft = db.FT();
            try
            {
                ft.Create(
                    IndexName,
                    new FTCreateParams().On(IndexDataType.HASH).Prefix("doc:"),
                    schema);
            }
            catch (RedisServerException ex)
                when (ex.Message.IndexOf("Index already exists", StringComparison.OrdinalIgnoreCase) >= 0)
            {
                // The index was created by a previous run; reuse it instead of crashing.
                // NOTE(review): this assumes the existing index has the same schema.
            }
        }

        /// <summary>
        /// Converts a sentence into the binary blob stored in (or queried against)
        /// the <c>embedding</c> vector field.
        /// </summary>
        /// <param name="model">Prediction engine that maps text to an embedding vector.</param>
        /// <param name="sentence">The text to embed.</param>
        /// <returns>The embedding's float values as raw bytes, 4 bytes per value.</returns>
        private static byte[] GetEmbedding(PredictionEngine<TextData, TransformedTextData> model, string sentence)
        {
            // Call the prediction API to convert the text into embedding vector.
            var prediction = model.Predict(new TextData { Text = sentence });

            // Features is already a float[], so it can be copied to bytes directly.
            // BlockCopy uses the platform byte order.
            // NOTE(review): presumably Redis expects little-endian FLOAT32 - confirm for big-endian hosts.
            float[] features = prediction.Features;
            byte[] blob = new byte[features.Length * sizeof(float)];
            Buffer.BlockCopy(features, 0, blob, 0, blob.Length);

            return blob;
        }

        /// <summary>
        /// Builds a new ML.NET prediction engine that normalizes text, tokenizes it
        /// into words and applies the pretrained sentiment-specific word embedding.
        /// </summary>
        /// <returns>A prediction engine producing the <c>Features</c> embedding vector.</returns>
        private static PredictionEngine<TextData, TransformedTextData> GetPredictionEngine()
        {
            // Create a new ML context, for ML.NET operations. It can be used for
            // exception tracking and logging, as well as the source of randomness.
            var mlContext = new MLContext();

            // The transforms are pretrained, so an empty data set is enough to fit the pipeline.
            var emptyDataView = mlContext.Data.LoadFromEnumerable(new List<TextData>());

            // A pipeline for converting text into a 150-dimension embedding vector
            var textPipeline = mlContext.Transforms.Text.NormalizeText("Text")
                .Append(mlContext.Transforms.Text.TokenizeIntoWords("Tokens", "Text"))
                .Append(mlContext.Transforms.Text.ApplyWordEmbedding("Features", "Tokens",
                    WordEmbeddingEstimator.PretrainedModelKind.SentimentSpecificWordEmbedding));

            // Fit to data.
            var textTransformer = textPipeline.Fit(emptyDataView);

            // Create the prediction engine to get the embedding vector from the input text/string.
            return mlContext.Model.CreatePredictionEngine<TextData, TransformedTextData>(textTransformer);
        }

        /// <summary>
        /// Stores the three sample sentences as hashes <c>doc:1</c>..<c>doc:3</c>,
        /// each with its content, genre and embedding.
        /// </summary>
        public static void ModelSentences()
        {
            IDatabase db = s_redis.Value.GetDatabase();
            var predictionEngine = s_predictionEngine.Value;

            StoreSentence(db, predictionEngine, "doc:1", "That is a very happy person", "persons");
            StoreSentence(db, predictionEngine, "doc:2", "That is a happy dog", "pets");
            StoreSentence(db, predictionEngine, "doc:3", "Today is a sunny day", "weather");
        }

        // Writes one sentence, its genre tag and its embedding blob to the hash at `key`.
        private static void StoreSentence(
            IDatabase db,
            PredictionEngine<TextData, TransformedTextData> model,
            string key,
            string content,
            string genre)
        {
            db.HashSet(key, new HashEntry[]
            {
                new HashEntry("content", content),
                new HashEntry("genre", genre),
                new HashEntry("embedding", GetEmbedding(model, content)),
            });
        }

        /// <summary>
        /// Runs a KNN 3 query for the test sentence against the index and prints
        /// each returned document id followed by its returned field values.
        /// </summary>
        private static void TestSentence()
        {
            IDatabase db = s_redis.Value.GetDatabase();
            var predictionEngine = s_predictionEngine.Value;

            var query = new Query("*=>[KNN 3 @embedding $query_vec AS score]")
                .AddParam("query_vec", GetEmbedding(predictionEngine, "That is a happy person"))
                .ReturnFields(new FieldName("content", "content"), new FieldName("score", "score"))
                .SetSortBy("score")
                .Dialect(2);

            SearchCommands ft = db.FT();
            var res = ft.Search(IndexName, query);

            foreach (var doc in res.Documents)
            {
                Console.Write($"id: {doc.Id}, ");
                foreach (var item in doc.GetProperties())
                {
                    Console.Write($" {item.Value}");
                }
                Console.WriteLine();
            }
        }

        // Input row for the ML.NET pipeline: the raw sentence text.
        private class TextData
        {
            public string Text { get; set; }
        }

        // Output row of the ML.NET pipeline: the input text plus its embedding vector.
        private class TransformedTextData : TextData
        {
            public float[] Features { get; set; }
        }
    }
}

您现在可以执行项目

dotnet run

请注意,示例在第一次执行时似乎会挂起,但它只是需要一些时间来下载嵌入模型。

示例将三个句子(“那是一个非常快乐的人”、“那是一只快乐的狗”、“今天是一个阳光明媚的日子”)存储为 Redis 哈希,然后计算测试句子“那是一个快乐的人”(That is a happy person)与这些已建模句子之间的相似度。向量搜索配置为返回三个结果(KNN 3);正如预期的那样,距离最小的结果对应与测试句子语义相似度最高的句子。

id: doc:1,  4.30777168274 That is a very happy person
id: doc:2,  25.9752807617 That is a happy dog
id: doc:3,  68.8638000488 Today is a sunny day

参考资料

学习使用 C#/.NET 编程的 Redis 资源

更多关于 ML 模型的信息