Spring AI Advanced RAG — Hybrid Search, Re-ranking, and Query Expansion
Basic RAG retrieves the top-K most similar chunks and injects them into the prompt. Advanced RAG combines several retrieval strategies, expands queries, and re-ranks results to improve answer quality. This tutorial covers hybrid search (semantic + keyword), query expansion, and re-ranking with Spring AI. The re-ranker shown here scores documents with an LLM call, standing in for a dedicated cross-encoder model.
Why Basic RAG Falls Short
Problem 1: Semantic search misses exact matches
Query: "NullPointerException in getUserById"
Vector search: returns general null handling docs (semantically similar but wrong)
Keyword search: returns exact method containing the bug
Problem 2: Top-K retrieval may miss relevant chunks
Query needs 3 specific chunks that are semantically distant from each other
Basic top-5 retrieval gets 2/3 correct
Solution: Hybrid search (semantic + BM25 keyword) + re-ranking
Hybrid Search Implementation
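import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;

// Written against the pre-1.0 milestone API (SearchRequest.query, Document.getContent);
// Spring AI 1.0 uses SearchRequest.builder() and Document.getText() instead.
@Service
public class HybridSearchService {
    private final VectorStore vectorStore; // semantic search
    private final JdbcTemplate jdbc;       // keyword search via PostgreSQL full-text

    public HybridSearchService(VectorStore vectorStore, JdbcTemplate jdbc) {
        this.vectorStore = vectorStore;
        this.jdbc = jdbc;
    }

    public List<Document> hybridSearch(String query, int topK) {
        // 1. Semantic search
        List<Document> semanticResults = vectorStore.similaritySearch(
                SearchRequest.query(query)
                        .withTopK(topK * 2) // get extra candidates for re-ranking
                        .withSimilarityThreshold(0.5)
        );
        // 2. Keyword search via PostgreSQL full-text search, best matches first
        String sql = """
                SELECT id, content
                FROM vector_store
                WHERE to_tsvector('english', content) @@ plainto_tsquery('english', ?)
                ORDER BY ts_rank(to_tsvector('english', content), plainto_tsquery('english', ?)) DESC
                LIMIT ?
                """;
        // Reuse the stored id: a Document built from content alone gets a random id,
        // which would defeat the deduplication below.
        List<Document> keywordResults = jdbc.query(sql,
                (rs, row) -> new Document(rs.getString("id"), rs.getString("content"), new HashMap<>()),
                query, query, topK * 2);
        // 3. Merge and deduplicate, alternating between the two lists so that
        //    keyword hits are not cut off by the topK limit
        Map<String, Document> merged = new LinkedHashMap<>();
        for (int i = 0; i < Math.max(semanticResults.size(), keywordResults.size()); i++) {
            if (i < semanticResults.size()) {
                merged.putIfAbsent(semanticResults.get(i).getId(), semanticResults.get(i));
            }
            if (i < keywordResults.size()) {
                merged.putIfAbsent(keywordResults.get(i).getId(), keywordResults.get(i));
            }
        }
        return new ArrayList<>(merged.values()).subList(0, Math.min(topK, merged.size()));
    }
}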
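The merge step in hybrid search can also be done with a rank-based score instead of simple list order. The following is a minimal sketch of Reciprocal Rank Fusion (RRF), a common way to combine ranked lists whose raw scores are not comparable. It is not part of the service above: the class name `RrfFusion`, the method `fuse`, and the constant `K = 60` are illustrative assumptions, and the sketch works on plain document IDs rather than Spring AI `Document` objects.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not a Spring AI class.
class RrfFusion {

    // 60 is a commonly used smoothing constant; treat it as a tunable assumption.
    static final int K = 60;

    /** Fuses ranked lists of document IDs (each ordered best-first) into one ranking. */
    static List<String> fuse(List<List<String>> rankedLists, int topK) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> ranked : rankedLists) {
            for (int i = 0; i < ranked.size(); i++) {
                // Rank is 1-based, so position i contributes 1 / (K + i + 1)
                scores.merge(ranked.get(i), 1.0 / (K + i + 1), Double::sum);
            }
        }
        List<Map.Entry<String, Double>> entries = new ArrayList<>(scores.entrySet());
        // List.sort is stable, so IDs with equal scores keep first-seen order
        entries.sort(Map.Entry.<String, Double>comparingByValue().reversed());
        return entries.stream().limit(topK).map(Map.Entry::getKey).toList();
    }

    public static void main(String[] args) {
        List<String> semantic = List.of("x", "y"); // IDs from vector search
        List<String> keyword = List.of("y", "z");  // IDs from full-text search
        System.out.println(fuse(List.of(semantic, keyword), 2)); // prints [y, x]
    }
}
```

A document that appears in both lists accumulates two contributions, so it tends to outrank documents found by only one retriever, which is usually the behavior you want from a hybrid merge.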
Query Expansion — Multiple Reformulations
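import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

@Service
public class QueryExpansionService {
    private final ChatClient chatClient;
    private final VectorStore vectorStore;

    public QueryExpansionService(ChatClient.Builder builder, VectorStore vectorStore) {
        this.chatClient = builder.build();
        this.vectorStore = vectorStore;
    }

    // Generate multiple reformulations of the user's question
    private List<String> expandQuery(String originalQuery) {
        String expanded = chatClient.prompt()
                .user("""
                        Generate 3 different reformulations of this search query.
                        Output only the queries, one per line, no numbering or explanation.
                        Query: %s
                        """.formatted(originalQuery))
                .call()
                .content();
        List<String> queries = new ArrayList<>();
        queries.add(originalQuery);
        if (expanded != null) { // content() can be null if the model returns nothing
            Arrays.stream(expanded.split("\n"))
                    .map(String::trim)
                    .filter(s -> !s.isBlank())
                    .limit(3) // ignore anything beyond the 3 requested variants
                    .forEach(queries::add);
        }
        return queries;
    }

    public List<Document> searchWithExpansion(String query, int topK) {
        List<String> queries = expandQuery(query);
        System.out.println("Expanded queries: " + queries);
        // Search with each query variant
        List<List<Document>> perQuery = new ArrayList<>();
        for (String q : queries) {
            perQuery.add(vectorStore.similaritySearch(SearchRequest.query(q).withTopK(topK)));
        }
        // Interleave by rank and deduplicate by id. Appending whole result lists
        // would let the original query fill all topK slots before any variant contributes.
        Map<String, Document> allResults = new LinkedHashMap<>();
        for (int rank = 0; rank < topK; rank++) {
            for (List<Document> results : perQuery) {
                if (rank < results.size()) {
                    allResults.putIfAbsent(results.get(rank).getId(), results.get(rank));
                }
            }
        }
        return allResults.values().stream()
                .limit(topK)
                .toList();
    }
}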
Output of Query Expansion
Original query: "How do I handle database connection errors in Spring?"
Expanded queries:
1. "How do I handle database connection errors in Spring?"
2. "Spring Boot database connection failure exception handling"
3. "Spring DataAccessException connection refused retry"
4. "Handle SQLException DataSource unavailable Spring application"
Result: four query variants instead of one, so retrieval draws candidates from more of the knowledge base
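Models do not always follow the "no numbering" instruction, and the sample output above shows numbered, quoted lines. A small clean-up step before searching can help. The sketch below is one possible approach using only the JDK; the class name `ReformulationParser` and the method `parse` are hypothetical and not part of Spring AI.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper for cleaning the raw LLM reply before it is used as search input.
class ReformulationParser {

    /** Returns the original query followed by up to maxVariants cleaned, distinct reformulations. */
    static List<String> parse(String original, String llmOutput, int maxVariants) {
        Set<String> queries = new LinkedHashSet<>(); // keeps order, drops exact duplicates
        queries.add(original.trim());
        for (String line : llmOutput.split("\\R")) {
            String cleaned = line.trim()
                    .replaceFirst("^(?:\\d+[.)]|[-*])\\s*", "") // leading "1." / "2)" / "-" / "*"
                    .replaceAll("^[\"']+|[\"']+$", "")           // surrounding quotes
                    .trim();
            // size() counts the original query, so this allows maxVariants additions
            if (!cleaned.isEmpty() && queries.size() <= maxVariants) {
                queries.add(cleaned);
            }
        }
        return new ArrayList<>(queries);
    }

    public static void main(String[] args) {
        String reply = "1. \"Spring Boot database connection failure\"\n\n- Spring DataAccessException retry";
        System.out.println(parse("How do I handle database connection errors in Spring?", reply, 3));
    }
}
```

Deduplication here is exact-match only; near-duplicate reformulations still cost one vector search each, so capping the number of variants matters more than the clean-up itself.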
Re-ranking with an LLM Scorer (Cross-Encoder Stand-in)
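import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;

@Service
public class RerankingService {
    private final ChatClient chatClient;

    public RerankingService(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    // Use AI to score each document's relevance to the query
    public List<Document> rerank(String query, List<Document> candidates, int topK) {
        if (candidates.isEmpty()) {
            return List.of(); // nothing to score, so skip the model call
        }
        StringBuilder scoringPrompt = new StringBuilder();
        scoringPrompt.append("Rate each document's relevance to the query from 0-10.\n");
        scoringPrompt.append("Query: ").append(query).append("\n\n");
        for (int i = 0; i < candidates.size(); i++) {
            String content = candidates.get(i).getContent();
            scoringPrompt.append("Document ").append(i).append(":\n");
            // Only the first 300 characters are sent, to keep the prompt small
            scoringPrompt.append(content, 0, Math.min(300, content.length()));
            scoringPrompt.append("\n\n");
        }
        scoringPrompt.append("Output ONLY a JSON array of scores: [8, 3, 9, 1, 6]");
        String scoresJson = chatClient.prompt()
                .user(scoringPrompt.toString())
                .call()
                .content(); // may be null; parseScores handles that

        // Parse scores and sort
        int[] scores = parseScores(scoresJson, candidates.size());
        List<Map.Entry<Integer, Document>> scored = new ArrayList<>();
        for (int i = 0; i < candidates.size(); i++) {
            scored.add(Map.entry(scores[i], candidates.get(i)));
        }
        return scored.stream()
                .sorted(Map.Entry.<Integer, Document>comparingByKey(Comparator.reverseOrder())) // highest score first
                .limit(topK)
                .map(Map.Entry::getValue)
                .toList();
    }

    private int[] parseScores(String json, int count) {
        // Missing or unparseable scores fall back to a neutral 5
        int[] scores = new int[count];
        Arrays.fill(scores, 5);
        if (json == null) {
            return scores;
        }
        // Only look inside the first [...] so digits elsewhere in the reply are ignored
        Matcher array = Pattern.compile("\\[([^\\]]*)\\]").matcher(json);
        if (!array.find()) {
            return scores;
        }
        String[] parts = array.group(1).split(",");
        try {
            for (int i = 0; i < Math.min(count, parts.length); i++) {
                long rounded = Math.round(Double.parseDouble(parts[i].trim()));
                scores[i] = (int) Math.max(0, Math.min(10, rounded)); // clamp to 0-10
            }
        } catch (NumberFormatException e) {
            Arrays.fill(scores, 5); // default score on parse failure
        }
        return scores;
    }
}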
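The parse-and-sort step can be checked in isolation, without a model call. The sketch below uses the sample reply from the prompt, `[8, 3, 9, 1, 6]`, and shows which candidate indices survive a top-3 cut. It is a stricter variant than the service above: it rejects a reply whose score count does not match the candidate count instead of filling in defaults. The class `ScoreOrdering` and its methods are hypothetical names for illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical stand-alone check of the scoring logic; no Spring AI types involved.
class ScoreOrdering {

    /** Parses the first [...] group; returns null unless it holds exactly `expected` integers. */
    static int[] parseScores(String response, int expected) {
        Matcher array = Pattern.compile("\\[([^\\]]*)\\]").matcher(response);
        if (!array.find()) {
            return null;
        }
        Matcher number = Pattern.compile("\\d+").matcher(array.group(1));
        List<Integer> values = new ArrayList<>();
        try {
            while (number.find()) {
                values.add(Math.min(10, Integer.parseInt(number.group()))); // cap at 10
            }
        } catch (NumberFormatException e) {
            return null; // absurdly long digit run
        }
        if (values.size() != expected) {
            return null; // count mismatch: the caller should retry or skip re-ranking
        }
        return values.stream().mapToInt(Integer::intValue).toArray();
    }

    /** Indices of the topK highest scores, best first; ties keep the original retrieval order. */
    static List<Integer> topIndices(int[] scores, int topK) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingInt((Integer i) -> scores[i]).reversed());
        return indices.subList(0, Math.min(topK, indices.size()));
    }

    public static void main(String[] args) {
        int[] scores = parseScores("Scores: [8, 3, 9, 1, 6]", 5);
        System.out.println(topIndices(scores, 3)); // prints [2, 0, 4]
    }
}
```

Returning null on a mismatch makes the failure visible to the caller, which can then fall back to the retrieval order rather than sorting on partly invented scores.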
Full Advanced RAG Pipeline
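import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;

@Service
public class AdvancedRagService {
    private final HybridSearchService hybridSearch;
    private final QueryExpansionService expansion;
    private final RerankingService reranker;
    private final ChatClient chatClient;

    // The final fields need constructor injection; ChatClient is built as in the other services
    public AdvancedRagService(HybridSearchService hybridSearch, QueryExpansionService expansion,
                              RerankingService reranker, ChatClient.Builder builder) {
        this.hybridSearch = hybridSearch;
        this.expansion = expansion;
        this.reranker = reranker;
        this.chatClient = builder.build();
    }

    public String answer(String question) {
        // 1. Gather candidates: expanded queries plus hybrid (keyword) matches, deduplicated by id
        Map<String, Document> byId = new LinkedHashMap<>();
        expansion.searchWithExpansion(question, 20).forEach(d -> byId.putIfAbsent(d.getId(), d));
        hybridSearch.hybridSearch(question, 20).forEach(d -> byId.putIfAbsent(d.getId(), d));
        List<Document> candidates = new ArrayList<>(byId.values());
        // 2. Re-rank → select most relevant top-5
        List<Document> topDocs = reranker.rerank(question, candidates, 5);
        // 3. Build context string
        String context = topDocs.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n\n---\n\n"));
        // 4. Generate answer
        return chatClient.prompt()
                .system("Answer only from the provided context. Be precise.")
                .user("Context:\n" + context + "\n\nQuestion: " + question)
                .call()
                .content();
    }
}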
Key Points
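- Hybrid search (semantic + keyword) improves recall over semantic-only search, most visibly for exact identifiers and error messages; treat the 20-40% figure as a rough guide and measure on your own corpus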
- Query expansion with 3-4 reformulations catches relevant chunks that a single query would miss
- Re-ranking adds latency (one extra AI call) but significantly improves precision — use it when answer quality is critical
- Retrieve more candidates (2x-3x topK) before re-ranking — gives the re-ranker enough material to work with
- Advanced RAG roughly doubles total token usage — benchmark the cost vs quality improvement for your use case before enabling all techniques
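To make the cost point concrete, prompt sizes can be estimated before enabling each stage. The sketch below is back-of-the-envelope arithmetic only: the 4-characters-per-token ratio is a rough heuristic for English text rather than a real tokenizer, the 2000-character chunk size and 60-character query are assumed example inputs, and the class and method names are hypothetical.

```java
// Hypothetical estimator; real token counts depend on the model's tokenizer.
class RagCostEstimate {

    // Rough heuristic for English text; an assumption, not a measurement.
    static final int CHARS_PER_TOKEN = 4;

    /** Converts a character count to an approximate token count, rounding up. */
    static int tokens(int chars) {
        return (chars + CHARS_PER_TOKEN - 1) / CHARS_PER_TOKEN;
    }

    /** Approximate prompt tokens for the re-ranking call; each candidate is cut to previewChars. */
    static int rerankPromptTokens(int candidates, int previewChars, int queryChars) {
        return tokens(queryChars + candidates * previewChars);
    }

    /** Approximate prompt tokens for the final answer call, which sends full chunks. */
    static int answerPromptTokens(int topDocs, int avgChunkChars, int queryChars) {
        return tokens(queryChars + topDocs * avgChunkChars);
    }

    public static void main(String[] args) {
        int queryChars = 60;                                   // assumed query length
        int rerank = rerankPromptTokens(20, 300, queryChars);  // e.g. 20 candidates, 300-char previews
        int answer = answerPromptTokens(5, 2000, queryChars);  // e.g. 5 chunks of about 2000 chars
        System.out.printf("re-rank ~%d tokens, answer ~%d tokens%n", rerank, answer);
        // prints: re-rank ~1515 tokens, answer ~2515 tokens
    }
}
```

Under these assumed inputs the re-ranking prompt is a sizeable fraction of the answer prompt, and the query expansion call adds a further, smaller request on top. Plugging in your own chunk sizes and candidate counts shows quickly whether the extra stages fit your budget.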