How do I perform vector search in Java with the Jedis client library?

Last updated April 20, 2024

Question

How do I perform vector search in Java with the Jedis client library?

Answer

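Create a Java Maven project (see the instructions for building a scaffold project) and include the following dependencies, specifying the versions you need: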

    <dependency>
      <groupId>redis.clients</groupId>
      <artifactId>jedis</artifactId>
      <version>5.0.1</version>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>0.24.0</version>
    </dependency>
    <dependency>
      <groupId>ai.djl.huggingface</groupId>
      <artifactId>tokenizers</artifactId>
      <version>0.24.0</version>
    </dependency>

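The example stores three sentences ("That is a very happy person", "That is a happy dog", "Today is a sunny day") as Redis hashes and, among these modeled sentences, finds those most similar to the test sentence "That is a happy person". The vector search is configured to return three results (KNN 3).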

package com.redis.app;

import redis.clients.jedis.UnifiedJedis;
import redis.clients.jedis.search.*;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;


public class App {
    public static byte[] floatArrayToByteArray(float[] input) {
        byte[] bytes = new byte[Float.BYTES * input.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
        return bytes;
    }

    public static byte[] longArrayToByteArray(long[] input) {
        return floatArrayToByteArray(longArrayToFloatArray(input));
    }

    public static float[] longArrayToFloatArray(long[] input) {
        float[] floats = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            floats[i] = input[i];
        }
        return floats;
    }

    public static void main(String[] args) {
        // Connect to Redis
        UnifiedJedis unifiedjedis = new UnifiedJedis(System.getenv().getOrDefault("REDIS_URL", "redis://localhost:6379"));

        // Create the index
        IndexDefinition definition = new IndexDefinition().setPrefixes(new String[]{"doc:"});
        Map<String, Object> attr = new HashMap<>();
        attr.put("TYPE", "FLOAT32");
        attr.put("DIM", 768);
        attr.put("DISTANCE_METRIC", "L2");
        attr.put("INITIAL_CAP", 3);
        Schema schema = new Schema()
                .addTextField("content", 1)
                .addTagField("genre")
                .addHNSWVectorField("embedding", attr);

        // Catch exceptions if the index exists
        try {
            unifiedjedis.ftCreate("vector_idx", IndexOptions.defaultOptions().setDefinition(definition), schema);
        }
        catch(Exception e) {
            System.out.println(e.getMessage());
        }

        // Load the Hugging Face tokenizer used to turn each sentence into a numeric vector
        Map<String, String> options = Map.of("maxLength", "768",  "modelMaxLength", "768");
        HuggingFaceTokenizer sentenceTokenizer = HuggingFaceTokenizer.newInstance("sentence-transformers/all-mpnet-base-v2", options);

        // Store each sentence and its vector as a hash under the "doc:" prefix
        String sentence1 = "That is a very happy person";
        unifiedjedis.hset("doc:1", Map.of(  "content", sentence1, "genre", "persons"));
        unifiedjedis.hset("doc:1".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence1).getIds()));

        String sentence2 = "That is a happy dog";
        unifiedjedis.hset("doc:2", Map.of(  "content", sentence2, "genre", "pets"));
        unifiedjedis.hset("doc:2".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence2).getIds()));

        String sentence3 = "Today is a sunny day";
        Map<String, String> doc3 = Map.of(  "content", sentence3, "genre", "weather");
        unifiedjedis.hset("doc:3", doc3);
        unifiedjedis.hset("doc:3".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence3).getIds()));

        // This is the test sentence
        String sentence = "That is a happy person";

        int K = 3;
        Query q = new Query("*=>[KNN $K @embedding $BLOB AS score]")
                .returnFields("content", "score")
                .addParam("K", K)
                .addParam("BLOB", longArrayToByteArray(sentenceTokenizer.encode(sentence).getIds()))
                .dialect(2);

        // Execute the query
        List<Document> docs = unifiedjedis.ftSearch("vector_idx", q).getDocuments();
        System.out.println(docs);
    }
}
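
The floatArrayToByteArray helper in the listing packs each float as four little-endian bytes, which is the binary layout the example stores in the "embedding" field. As a minimal standalone sketch of that round trip, the following uses only the JDK. The class name VectorBlobRoundTrip and its method names are hypothetical and are not part of Jedis.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

// Hypothetical helper class, for illustration only.
final class VectorBlobRoundTrip {

    // Packs each float as 4 little-endian bytes, like floatArrayToByteArray in the listing.
    static byte[] toBlob(float[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * vector.length)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (float value : vector) {
            buffer.putFloat(value);
        }
        return buffer.array();
    }

    // Reads the floats back from a little-endian blob.
    static float[] fromBlob(byte[] blob) {
        float[] vector = new float[blob.length / Float.BYTES];
        ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(vector);
        return vector;
    }

    public static void main(String[] args) {
        float[] vector = {1.0f, -2.5f, 3.25f};
        byte[] blob = toBlob(vector);
        System.out.println(blob.length);                      // 12 (3 floats x 4 bytes)
        System.out.println(Arrays.toString(fromBlob(blob)));  // [1.0, -2.5, 3.25]
    }
}
```

By the same arithmetic, a FLOAT32 vector with DIM 768 would be expected to occupy 768 x 4 = 3072 bytes, so checking the blob length before writing the hash is a cheap sanity test.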

Make sure your Redis Stack instance (or Redis Cloud database) is running and, if necessary, that the REDIS_URL environment variable is set. For example:

export REDIS_URL=redis://user:password@host:port

By default, the client attempts to connect to a Redis Stack instance on localhost, port 6379.
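
To check which endpoint the program will use before running it, the URL can be inspected with the JDK's java.net.URI. This is only a sketch: the class RedisUrlParts and its resolve method are hypothetical, and treating a blank value as unset is an assumption of this example (the listing itself falls back only when the variable is missing).

```java
import java.net.URI;

// Hypothetical helper class, for illustration only.
final class RedisUrlParts {

    static final String DEFAULT_URL = "redis://localhost:6379";

    // Falls back to the default URL when the configured value is missing or blank.
    static URI resolve(String configured) {
        boolean unset = configured == null || configured.trim().isEmpty();
        return URI.create(unset ? DEFAULT_URL : configured);
    }

    public static void main(String[] args) {
        URI uri = resolve(System.getenv("REDIS_URL"));
        System.out.println("host = " + uri.getHost());
        System.out.println("port = " + uri.getPort());
        // Avoid printing the password itself; only report whether credentials are present.
        System.out.println("has credentials = " + (uri.getUserInfo() != null));
    }
}
```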

The example is provided as a Maven project, which you can compile with:

mvn package

and run with:

mvn exec:java -Dexec.mainClass=com.redis.app.App

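As expected, the smallest distance corresponds to the highest semantic similarity between the two sentences being compared. In the output, the score entry inside properties is the vector distance returned by the query under the alias score.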

[id:doc:1, score: 1.0, properties:[score=9301635, content=That is a very happy person], id:doc:2, score: 1.0, properties:[score=1411344, content=That is a happy dog], id:doc:3, score: 1.0, properties:[score=67178800, content=Today is a sunny day]]
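
To make the distance-based ranking concrete, here is a small sketch that assumes plain float vectors and a squared Euclidean distance. The class L2Ranking and the sample vectors are invented for illustration. The source does not say whether the server reports the squared or the rooted value, but both order the candidates identically.

```java
// Hypothetical helper class, for illustration only.
final class L2Ranking {

    // Sum of squared component differences; a smaller value means a closer vector.
    static double squaredL2(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] query = {1f, 2f, 3f};
        float[] near = {1f, 2f, 4f};
        float[] far = {4f, 6f, 3f};
        System.out.println(squaredL2(query, near)); // 1.0
        System.out.println(squaredL2(query, far));  // 25.0
    }
}
```

The candidate with the smaller value ("near" here) is the one a nearest-neighbor search would rank first.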
