Java SpringAI

Spring AI Advanced RAG — Hybrid Search, Re-ranking, and Query Expansion

Basic RAG retrieves the top-K most similar chunks and injects them into the prompt. Advanced RAG combines several retrieval strategies, expands the query, and re-ranks the results so that the chunks reaching the prompt are more relevant. This tutorial covers hybrid search (semantic + keyword), query expansion, and LLM-based re-ranking with Spring AI.

Why Basic RAG Falls Short

Problem 1: Semantic search misses exact matches
  Query: "NullPointerException in getUserById"
  Vector search: returns general null handling docs (semantically similar but wrong)
  Keyword search: returns the chunk containing the exact method name

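Problem 2: Top-K retrieval may miss relevant chunks
  Query needs 3 specific chunks that are semantically distant from each other
  Basic top-5 retrieval returns only 2 of the 3

Solution: Hybrid search (semantic + keyword, e.g. BM25 or PostgreSQL full-text search) for exact matches, plus query expansion and re-ranking to find and prioritize the remaining chunks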
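The keyword side of hybrid search is often scored with BM25, which rewards documents containing rare query terms (such as a method name) while dampening repeated terms and long documents. The sketch below is illustrative only: the Bm25 class is hypothetical, tokenization is a simple lowercase split on non-alphanumeric characters (so getUserById stays one token), and k1 = 1.2, b = 0.75 are the usual default parameters.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;

/** Hypothetical in-memory BM25 ranker (illustrative sketch, not production code). */
class Bm25 {
    private static final double K1 = 1.2;   // term-frequency saturation
    private static final double B  = 0.75;  // document-length normalization

    private final List<List<String>> docs;                     // tokenized documents
    private final Map<String, Integer> docFreq = new HashMap<>(); // term -> number of docs containing it
    private final double avgLen;                                // average document length in tokens

    Bm25(List<String> corpus) {
        docs = corpus.stream().map(Bm25::tokenize).toList();
        for (List<String> doc : docs) {
            for (String term : new HashSet<>(doc)) {
                docFreq.merge(term, 1, Integer::sum);
            }
        }
        avgLen = docs.stream().mapToInt(List::size).average().orElse(0);
    }

    // Lowercase and split on anything that is not a letter or digit
    static List<String> tokenize(String text) {
        return Arrays.stream(text.toLowerCase().split("[^\\p{L}\\p{N}]+"))
                .filter(t -> !t.isEmpty())
                .toList();
    }

    /** BM25 score of one document for the query; terms absent from the corpus contribute nothing. */
    double score(String query, int docIndex) {
        List<String> doc = docs.get(docIndex);
        double n = docs.size();
        double total = 0;
        for (String term : tokenize(query)) {
            int df = docFreq.getOrDefault(term, 0);
            if (df == 0) {
                continue;
            }
            long tf = doc.stream().filter(term::equals).count();
            double idf  = Math.log((n - df + 0.5) / (df + 0.5) + 1);   // Lucene-style IDF, always > 0
            double norm = K1 * (1 - B + B * doc.size() / avgLen);
            total += idf * tf * (K1 + 1) / (tf + norm);
        }
        return total;
    }

    /** Document indices ordered by descending BM25 score (ties keep corpus order). */
    List<Integer> rank(String query) {
        double[] scores = new double[docs.size()];
        for (int i = 0; i < scores.length; i++) {
            scores[i] = score(query, i);
        }
        return IntStream.range(0, scores.length).boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .toList();
    }
}
```

For a corpus of a general null-handling chunk, a chunk mentioning both UserService.getUserById and NullPointerException, and an unrelated connection-pool chunk, the query "NullPointerException in getUserById" ranks the getUserById chunk first because it matches both rare identifiers. PostgreSQL's built-in full-text ranking is not BM25; this sketch only illustrates the general idea of keyword scoring.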

Hybrid Search Implementation

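import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;

// Written against the pre-1.0 Spring AI API used throughout this post (SearchRequest.query(...).withTopK(...)).
// Spring AI 1.0 GA uses SearchRequest.builder().query(...).topK(...).build() and Document.getText().
@Service
public class HybridSearchService {

    private final VectorStore  vectorStore;    // semantic search
    private final JdbcTemplate jdbc;           // keyword search via PostgreSQL full-text search

    public HybridSearchService(VectorStore vectorStore, JdbcTemplate jdbc) {
        this.vectorStore = vectorStore;
        this.jdbc        = jdbc;
    }

    public List<Document> hybridSearch(String query, int topK) {
        // 1. Semantic search
        List<Document> semanticResults = vectorStore.similaritySearch(
                SearchRequest.query(query)
                        .withTopK(topK * 2)   // get extra candidates for re-ranking
                        .withSimilarityThreshold(0.5)
        );

        // 2. Keyword search via PostgreSQL full-text search, best matches first.
        //    Assumes the default PgVectorStore table (vector_store: id, content, metadata, embedding);
        //    a GIN index on to_tsvector('english', content) avoids full table scans.
        String sql = """
                SELECT id, content
                FROM vector_store
                WHERE to_tsvector('english', content) @@ plainto_tsquery('english', ?)
                ORDER BY ts_rank(to_tsvector('english', content), plainto_tsquery('english', ?)) DESC
                LIMIT ?
                """;

        // Keep the stored id so the same chunk is deduplicated across both searches
        // (new Document(content) would generate a fresh random id). Metadata is omitted here.
        List<Document> keywordResults = jdbc.query(sql,
                (rs, row) -> new Document(rs.getString("id"), rs.getString("content"), new HashMap<>()),
                query, query, topK * 2);

        // 3. Merge and deduplicate: semantic hits come first, keyword-only hits are appended after them
        Map<String, Document> merged = new LinkedHashMap<>();
        semanticResults.forEach(d -> merged.put(d.getId(), d));
        keywordResults.forEach(d  -> merged.putIfAbsent(d.getId(), d));

        return new ArrayList<>(merged.values()).subList(0, Math.min(topK, merged.size()));
    }
}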
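The merge above lets semantic hits fill the list first, so keyword-only matches are dropped whenever semantic search alone returns topK results. Reciprocal Rank Fusion (RRF) is a widely used alternative: each document scores the sum of 1 / (k + rank) over every list it appears in, so a chunk ranked well by both searches rises to the top. The RankFusion class below is a hypothetical plain-Java sketch over document IDs, assuming both searches return the same stored ID for the same chunk; k = 60 is the constant most often used.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical Reciprocal Rank Fusion over ranked lists of document IDs (illustrative sketch). */
class RankFusion {
    static final int K = 60;  // smoothing constant; 60 is the value commonly used for RRF

    /** Fuses ranked ID lists (best first) and returns the topK IDs by fused score. */
    @SafeVarargs
    static List<String> fuse(int topK, List<String>... rankings) {
        // LinkedHashMap keeps first-seen order, so equal scores stay in a predictable order
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> ranking : rankings) {
            for (int rank = 0; rank < ranking.size(); rank++) {
                // rank is 0-based here, so the best hit in each list contributes 1 / (K + 1)
                scores.merge(ranking.get(rank), 1.0 / (K + rank + 1), Double::sum);
            }
        }
        return scores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(topK)
                .map(Map.Entry::getKey)
                .toList();
    }
}
```

Fusing semantic IDs [a, b, c] with keyword IDs [d, b] puts b first because it appears in both lists. In HybridSearchService you would pass the getId() values of both result lists and then look the fused IDs up in a Map from ID to Document.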

Query Expansion — Multiple Reformulations

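import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class QueryExpansionService {

    private static final Logger log = LoggerFactory.getLogger(QueryExpansionService.class);

    private final ChatClient  chatClient;
    private final VectorStore vectorStore;

    public QueryExpansionService(ChatClient.Builder builder, VectorStore vectorStore) {
        this.chatClient  = builder.build();
        this.vectorStore = vectorStore;
    }

    // Generate multiple reformulations of the user's question
    private List<String> expandQuery(String originalQuery) {
        String expanded = chatClient.prompt()
                .user("""
                      Generate 3 different reformulations of this search query.
                      Output only the queries, one per line, no numbering or explanation.

                      Query: %s
                      """.formatted(originalQuery))
                .call()
                .content();

        List<String> queries = new ArrayList<>();
        queries.add(originalQuery);            // always search with the original wording too
        if (expanded != null) {
            expanded.lines()
                    .map(String::trim)
                    .filter(s -> !s.isBlank())
                    .forEach(queries::add);
        }
        return queries;
    }

    public List<Document> searchWithExpansion(String query, int topK) {
        List<String> queries = expandQuery(query);
        log.debug("Expanded queries: {}", queries);

        // Search with each query variant
        List<List<Document>> perQuery = new ArrayList<>();
        for (String q : queries) {
            perQuery.add(vectorStore.similaritySearch(SearchRequest.query(q).withTopK(topK)));
        }

        // Interleave round-robin (rank 0 of every variant, then rank 1, ...) and deduplicate by id,
        // so every variant contributes before the list is cut to topK
        Map<String, Document> merged = new LinkedHashMap<>();
        for (int rank = 0; rank < topK; rank++) {
            for (List<Document> results : perQuery) {
                if (rank < results.size()) {
                    Document doc = results.get(rank);
                    merged.putIfAbsent(doc.getId(), doc);
                }
            }
        }

        return merged.values().stream()
                .limit(topK)
                .toList();
    }
}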
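The expansion step trusts the model to follow the "no numbering" instruction. Models sometimes number their lines or wrap them in quotes anyway, and may return more variants than requested. A more defensive parser might look like the sketch below; ExpansionParser and its parse method are hypothetical, and the marker and quote rules are assumptions about typical LLM output.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical helper: turns a raw LLM response into the original query plus cleaned variants. */
class ExpansionParser {

    static List<String> parse(String originalQuery, String llmResponse, int maxVariants) {
        // LinkedHashSet keeps the original query first and drops exact duplicates
        Set<String> queries = new LinkedHashSet<>();
        queries.add(originalQuery.trim());
        if (llmResponse == null) {
            return new ArrayList<>(queries);
        }
        for (String line : llmResponse.split("\\R")) {
            if (queries.size() > maxVariants) {
                break;  // original query + maxVariants reformulations already collected
            }
            String cleaned = line.trim()
                    .replaceFirst("^(\\d+[.)]|[-*])\\s*", "")  // leading "1." / "2)" / "-" / "*"
                    .replaceAll("^\"|\"$", "")                 // surrounding double quotes
                    .trim();
            if (!cleaned.isEmpty()) {
                queries.add(cleaned);
            }
        }
        return new ArrayList<>(queries);
    }
}
```

Inside expandQuery, a call like ExpansionParser.parse(originalQuery, expanded, 3) would replace the manual splitting and also guard against a null response.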

Output of Query Expansion

Original query: "How do I handle database connection errors in Spring?"

Expanded queries:
1. "How do I handle database connection errors in Spring?"
2. "Spring Boot database connection failure exception handling"
3. "Spring DataAccessException connection refused retry"
4. "Handle SQLException DataSource unavailable Spring application"

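Result: four searches instead of one, so chunks that match only one phrasing (for example, ones that mention DataAccessException or SQLException rather than "connection errors") can still be retrieved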

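Re-ranking with an LLM Scorer (Cross-Encoder Style)

This service uses the chat model itself as the scorer: like a cross-encoder, it reads the query and each candidate together, but it is a general-purpose LLM rather than a dedicated cross-encoder model.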

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.IntStream;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;

@Service
public class RerankingService {

    private static final int     SNIPPET_CHARS = 300;   // characters of each candidate sent to the scorer
    private static final Pattern NUMBER        = Pattern.compile("\\d+(?:\\.\\d+)?");

    private final ChatClient chatClient;

    public RerankingService(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    // Ask the LLM to score every (query, document) pair in one call, then sort by score
    public List<Document> rerank(String query, List<Document> candidates, int topK) {
        if (candidates.isEmpty()) {
            return List.of();
        }

        StringBuilder scoringPrompt = new StringBuilder();
        scoringPrompt.append("Rate each document's relevance to the query from 0-10.\n");
        scoringPrompt.append("Query: ").append(query).append("\n\n");

        for (int i = 0; i < candidates.size(); i++) {
            String content = candidates.get(i).getContent();
            scoringPrompt.append("Document ").append(i).append(":\n");
            scoringPrompt.append(content, 0, Math.min(SNIPPET_CHARS, content.length()));
            scoringPrompt.append("\n\n");
        }
        scoringPrompt.append("Output ONLY a JSON array of ").append(candidates.size())
                .append(" scores in document order, e.g. [8, 3, 9]");

        String scoresJson = chatClient.prompt()
                .user(scoringPrompt.toString())
                .call()
                .content();

        // Sort candidate indices by score, highest first (stable sort, so ties keep retrieval order)
        double[] scores = parseScores(scoresJson, candidates.size());
        return IntStream.range(0, candidates.size())
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .limit(topK)
                .map(candidates::get)
                .toList();
    }

    // Read numbers from the first [...] array in the response (decimals allowed);
    // missing or unparseable entries keep a neutral default of 5
    private double[] parseScores(String response, int count) {
        double[] scores = new double[count];
        Arrays.fill(scores, 5.0);
        if (response == null) {
            return scores;
        }
        int start = response.indexOf('[');
        int end   = response.indexOf(']', start + 1);
        if (start < 0 || end < 0) {
            return scores;
        }
        Matcher m = NUMBER.matcher(response.substring(start + 1, end));
        for (int i = 0; i < count && m.find(); i++) {
            scores[i] = Double.parseDouble(m.group());
        }
        return scores;
    }
}

Full Advanced RAG Pipeline

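import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;

@Service
public class AdvancedRagService {

    private final HybridSearchService   hybridSearch;
    private final QueryExpansionService expansion;
    private final RerankingService      reranker;
    private final ChatClient            chatClient;

    public AdvancedRagService(HybridSearchService hybridSearch, QueryExpansionService expansion,
                              RerankingService reranker, ChatClient.Builder builder) {
        this.hybridSearch = hybridSearch;
        this.expansion    = expansion;
        this.reranker     = reranker;
        this.chatClient   = builder.build();
    }

    public String answer(String question) {
        // 1. Gather candidates: expanded semantic queries + hybrid (semantic + keyword), deduplicated by id
        Map<String, Document> pool = new LinkedHashMap<>();
        expansion.searchWithExpansion(question, 20).forEach(d -> pool.putIfAbsent(d.getId(), d));
        hybridSearch.hybridSearch(question, 20).forEach(d -> pool.putIfAbsent(d.getId(), d));
        List<Document> candidates = new ArrayList<>(pool.values());

        // 2. Re-rank and keep the 5 most relevant
        List<Document> topDocs = reranker.rerank(question, candidates, 5);

        // 3. Build context string
        String context = topDocs.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n\n---\n\n"));

        // 4. Generate answer
        return chatClient.prompt()
                .system("Answer only from the provided context. If the context does not contain the answer, say so. Be precise.")
                .user("Context:\n" + context + "\n\nQuestion: " + question)
                .call()
                .content();
    }
}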
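Step 3 of the pipeline joins every re-ranked chunk with no size limit. If prompt size matters, one option is to label each chunk with a number (so the model can refer to its sources) and stop once a character budget is reached. ContextBuilder below is a hypothetical plain-Java sketch that takes chunk texts (for example, Document::getContent values) in ranked order; the budget and the "[n]" label format are assumptions.

```java
import java.util.List;

/** Hypothetical helper: joins ranked chunks into a numbered context within a character budget. */
class ContextBuilder {

    static String build(List<String> rankedChunks, int maxChars) {
        StringBuilder context = new StringBuilder();
        int n = 1;
        for (String chunk : rankedChunks) {
            String entry = "[" + n + "] " + chunk.strip() + "\n\n";
            // Chunks arrive best-first, so stop at the first one that would overflow the budget
            if (context.length() + entry.length() > maxChars) {
                break;
            }
            context.append(entry);
            n++;
        }
        return context.toString().strip();
    }
}
```

With the chunks "alpha", "beta", "gamma" and a 25-character budget, it returns "[1] alpha\n\n[2] beta" and drops the third chunk. Stopping at the first overflow keeps the ranking intact; skipping an oversized chunk and trying smaller ones later is an alternative.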

Key Points

  • Hybrid search (semantic + keyword) can improve recall over semantic-only search, especially for queries containing exact identifiers such as method names or exception types; measure the gain on your own data
  • Query expansion with 3-4 reformulations catches relevant chunks that a single query would miss
  • Re-ranking adds latency (one extra AI call) but usually improves precision; use it when answer quality is critical
  • Retrieve more candidates (2x-3x topK) before re-ranking to give the re-ranker enough material to work with
  • Advanced RAG increases token usage (an expansion call plus a re-ranking prompt that contains every candidate); benchmark the cost against the quality improvement for your use case before enabling all techniques
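One minimal way to run that benchmark is to label a small set of questions with the IDs of the chunks that should be retrieved, then compare recall@k and precision@k across pipelines (semantic-only, hybrid, expansion plus re-ranking). RetrievalMetrics below is a hypothetical plain-Java sketch of both metrics, assuming the retrieved IDs are distinct.

```java
import java.util.List;
import java.util.Set;

/** Hypothetical evaluation helpers over chunk IDs (illustrative sketch; assumes distinct retrieved IDs). */
class RetrievalMetrics {

    /** Fraction of the relevant IDs found among the first k retrieved IDs. */
    static double recallAtK(List<String> retrieved, Set<String> relevant, int k) {
        if (relevant.isEmpty() || k <= 0) {
            return 0.0;
        }
        long hits = retrieved.stream().limit(k).filter(relevant::contains).count();
        return (double) hits / relevant.size();
    }

    /** Fraction of the first k retrieved positions that hold a relevant ID. */
    static double precisionAtK(List<String> retrieved, Set<String> relevant, int k) {
        if (k <= 0) {
            return 0.0;
        }
        long hits = retrieved.stream().limit(k).filter(relevant::contains).count();
        return (double) hits / k;
    }
}
```

A recall@5 of 2/3 is exactly the "top-5 retrieval returns only 2 of the 3 chunks" situation described at the start; averaging these values over the labelled questions gives one number per pipeline to weigh against its token cost.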