How do I perform vector search in Java with the Jedis client library?

Last updated: April 20, 2024

Question

How do I perform vector search in Java with the Jedis client library?

Answer

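Create a Java Maven project (see the instructions for building a scaffold project) and include the following dependencies, specifying the versions you need.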

    <dependency>
      <groupId>redis.clients</groupId>
      <artifactId>jedis</artifactId>
      <version>5.0.1</version>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>0.24.0</version>
    </dependency>
    <dependency>
      <groupId>ai.djl.huggingface</groupId>
      <artifactId>tokenizers</artifactId>
      <version>0.24.0</version>
    </dependency>

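This example stores three sentences ("That is a very happy person", "That is a happy dog", "Today is a sunny day") as Redis hashes and measures the similarity between the test sentence "That is a happy person" and each stored sentence. The vector search is configured to return three results (KNN 3).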

package com.redis.app;

import redis.clients.jedis.UnifiedJedis;
import redis.clients.jedis.search.*;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;


public class App {
    public static byte[] floatArrayToByteArray(float[] input) {
        byte[] bytes = new byte[Float.BYTES * input.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
        return bytes;
    }

    public static byte[] longArrayToByteArray(long[] input) {
        return floatArrayToByteArray(longArrayToFloatArray(input));
    }

    public static float[] longArrayToFloatArray(long[] input) {
        float[] floats = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            floats[i] = input[i];
        }
        return floats;
    }

    public static void main(String[] args) {
        // Connect to Redis
        UnifiedJedis unifiedjedis = new UnifiedJedis(System.getenv().getOrDefault("REDIS_URL", "redis://localhost:6379"));

        // Create the index
        IndexDefinition definition = new IndexDefinition().setPrefixes(new String[]{"doc:"});
        Map<String, Object> attr = new HashMap<>();
        attr.put("TYPE", "FLOAT32");
        attr.put("DIM", 768);
        attr.put("DISTANCE_METRIC", "L2");
        attr.put("INITIAL_CAP", 3);
        Schema schema = new Schema().addTextField("content", 1).addTagField("genre").addHNSWVectorField("embedding", attr);                      

        // Catch exceptions if the index exists
        try {
            unifiedjedis.ftCreate("vector_idx", IndexOptions.defaultOptions().setDefinition(definition), schema);
        }
        catch(Exception e) {
            System.out.println(e.getMessage());
        }

        // Load the tokenizer used to turn each sentence into a vector
        Map<String, String> options = Map.of("maxLength", "768",  "modelMaxLength", "768");
        HuggingFaceTokenizer sentenceTokenizer = HuggingFaceTokenizer.newInstance("sentence-transformers/all-mpnet-base-v2", options);

        // Store the sentences as hashes; the "embedding" field is built from the tokenizer's token IDs
        String sentence1 = "That is a very happy person";
        unifiedjedis.hset("doc:1", Map.of(  "content", sentence1, "genre", "persons"));
        unifiedjedis.hset("doc:1".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence1).getIds()));

        String sentence2 = "That is a happy dog";
        unifiedjedis.hset("doc:2", Map.of(  "content", sentence2, "genre", "pets"));
        unifiedjedis.hset("doc:2".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence2).getIds()));

        String sentence3 = "Today is a sunny day";
        Map<String, String> doc3 = Map.of(  "content", sentence3, "genre", "weather");
        unifiedjedis.hset("doc:3", doc3);
        unifiedjedis.hset("doc:3".getBytes(), "embedding".getBytes(), longArrayToByteArray(sentenceTokenizer.encode(sentence3).getIds()));

        // This is the test sentence
        String sentence = "That is a happy person";

        int K = 3;
        Query q = new Query("*=>[KNN $K @embedding $BLOB AS score]").
                            returnFields("content", "score").
                            addParam("K", K).
                            addParam("BLOB", longArrayToByteArray(sentenceTokenizer.encode(sentence).getIds())).
                            dialect(2);

        // Execute the query
        List<Document> docs = unifiedjedis.ftSearch("vector_idx", q).getDocuments();
        System.out.println(docs);
    }
}
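The index declares the vector field as FLOAT32, and the listing packs each vector into a little-endian byte array before storing it. The sketch below repeats that encoding and adds a decoder to check the round trip. The class name VectorBytes and the decode helper are illustrative and not part of the original example.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class VectorBytes {
    // Same encoding as floatArrayToByteArray in the listing: 4 bytes per float, little-endian.
    public static byte[] encode(float[] input) {
        byte[] bytes = new byte[Float.BYTES * input.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(input);
        return bytes;
    }

    // Hypothetical inverse helper: reads the bytes back into floats.
    public static float[] decode(byte[] bytes) {
        float[] floats = new float[bytes.length / Float.BYTES];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(floats);
        return floats;
    }

    public static void main(String[] args) {
        float[] vector = {1.0f, 2.5f, -3.0f};
        byte[] blob = encode(vector);
        System.out.println(blob.length);                   // prints 12
        System.out.println(Arrays.toString(decode(blob))); // prints [1.0, 2.5, -3.0]
    }
}
```

With DIM set to 768 and TYPE set to FLOAT32, a stored blob would need to be 768 * 4 = 3072 bytes long, so checking the blob length before calling hset is a cheap sanity test.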

Make sure your Redis Stack instance (or Redis Cloud database) is running and that you have set the REDIS_URL environment variable, if needed. For example:

export REDIS_URL=redis://user:password@host:port

By default, the client attempts to connect to a local Redis Stack instance on port 6379.
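The fallback comes from the getOrDefault call in the listing. As a minimal sketch, the same lookup can be pulled into a small helper and tested without a running server. The class name RedisUrl and the resolve method are illustrative, not part of Jedis.

```java
import java.net.URI;
import java.util.Map;

public class RedisUrl {
    // Default used by the listing when REDIS_URL is not set.
    static final String DEFAULT_URL = "redis://localhost:6379";

    // Hypothetical helper: returns REDIS_URL from the given environment, or the default.
    public static String resolve(Map<String, String> env) {
        return env.getOrDefault("REDIS_URL", DEFAULT_URL);
    }

    public static void main(String[] args) {
        // Print only host and port so credentials in the URL are not echoed.
        URI uri = URI.create(resolve(System.getenv()));
        System.out.println(uri.getHost() + ":" + uri.getPort());
    }
}
```

Passing the environment in as a map keeps the lookup easy to test, and the resolved string can be handed to the UnifiedJedis constructor exactly as the listing does.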

This example is provided as a Maven project, which you can compile with:

mvn package

and run with:

mvn exec:java -Dexec.mainClass=com.redis.app.App

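As expected, the smallest distance corresponds to the highest semantic similarity between two sentences. The distance is reported in the score property, which is the AS score alias from the query.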

[id:doc:1, score: 1.0, properties:[score=9301635, content=That is a very happy person], id:doc:2, score: 1.0, properties:[score=1411344, content=That is a happy dog], id:doc:3, score: 1.0, properties:[score=67178800, content=Today is a sunny day]]
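To make the L2 metric concrete, the sketch below computes a squared Euclidean distance between small float vectors. It assumes the reported score is the squared form of the distance. Taking the square root would not change the ranking. The class name L2Distance and the sample vectors are illustrative.

```java
public class L2Distance {
    // Squared Euclidean distance: sum of squared component differences.
    public static double squaredL2(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same dimension");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] query = {1f, 2f, 3f};
        float[] near = {1f, 2f, 5f};
        float[] far = {4f, 6f, 3f};
        System.out.println(squaredL2(query, near)); // prints 4.0
        System.out.println(squaredL2(query, far));  // prints 25.0
    }
}
```

A smaller value means the stored vector is closer to the query vector, which is how the results above should be read.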

References