Spring AI Conversation Persistence — Store Chat History in PostgreSQL or Redis
By default, InMemoryChatMemory stores conversation history in the JVM heap and loses it on restart. For production chatbots, you need persistent conversation history so that users can resume sessions after a restart and so that every instance sees the same history when the application is scaled out. This tutorial covers PostgreSQL-backed and Redis-backed conversation persistence.
The Problem with InMemoryChatMemory
InMemoryChatMemory problems:
✘ Lost on application restart
✘ Lost when container is replaced (Kubernetes rolling update)
✘ Not shared across multiple instances (horizontal scaling)
✘ No ability to load previous sessions for returning users
✘ Memory leak if many unique conversationIds accumulate
Production requirement:
✔ Persistent across restarts
✔ Shared across all application instances
✔ Loadable by conversationId at any time
✔ Retention policy (delete conversations older than 30 days)
Custom PostgreSQL ChatMemory
// Entity for storing messages
/**
 * One persisted chat message, mapped to the {@code chat_messages} table.
 *
 * <p>Rows belonging to the same conversation share a {@code conversationId} and are
 * ordered by {@code sequenceNumber}. The accessors below are what the JPA-backed chat
 * memory uses to populate and read entities.
 */
@Entity
@Table(name = "chat_messages")
public class ChatMessageEntity {

    /** Generated primary key. */
    @Id
    @GeneratedValue
    private UUID id;

    /** Identifier of the conversation this message belongs to. */
    @Column(nullable = false)
    private String conversationId;

    /** Role of the message; persisted by enum name rather than ordinal. */
    @Column(nullable = false)
    @Enumerated(EnumType.STRING)
    private MessageType messageType; // USER, ASSISTANT, SYSTEM

    /** Message text. */
    @Column(columnDefinition = "TEXT", nullable = false)
    private String content;

    /** Time at which the row was created; the retention query filters on this value. */
    @Column(nullable = false)
    private LocalDateTime createdAt;

    /** Position of the message within its conversation (ascending). */
    private int sequenceNumber;

    public UUID getId() {
        return id;
    }

    public void setId(UUID id) {
        this.id = id;
    }

    public String getConversationId() {
        return conversationId;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public MessageType getMessageType() {
        return messageType;
    }

    public void setMessageType(MessageType messageType) {
        this.messageType = messageType;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }

    public int getSequenceNumber() {
        return sequenceNumber;
    }

    public void setSequenceNumber(int sequenceNumber) {
        this.sequenceNumber = sequenceNumber;
    }
}
/** Role of a stored chat message. */
public enum MessageType {
    USER,
    ASSISTANT,
    SYSTEM
}
// Spring Data repository
public interface ChatMessageRepository
extends JpaRepository<ChatMessageEntity, UUID> {
List<ChatMessageEntity> findByConversationIdOrderBySequenceNumberAsc(
String conversationId);
void deleteByConversationId(String conversationId);
// For retention policy
void deleteByCreatedAtBefore(LocalDateTime cutoff);
long countByConversationId(String conversationId);
}
// Implement ChatMemory backed by PostgreSQL
@Component
public class JpaChatMemory implements ChatMemory {
private final ChatMessageRepository repository;
public JpaChatMemory(ChatMessageRepository repository) {
this.repository = repository;
}
@Override
public void add(String conversationId, List<Message> messages) {
long nextSeq = repository.countByConversationId(conversationId);
List<ChatMessageEntity> entities = new ArrayList<>();
for (Message msg : messages) {
ChatMessageEntity entity = new ChatMessageEntity();
entity.setConversationId(conversationId);
entity.setContent(msg.getContent());
entity.setMessageType(toMessageType(msg));
entity.setCreatedAt(LocalDateTime.now());
entity.setSequenceNumber((int) nextSeq++);
entities.add(entity);
}
repository.saveAll(entities);
}
@Override
public List<Message> get(String conversationId, int lastN) {
List<ChatMessageEntity> rows = repository
.findByConversationIdOrderBySequenceNumberAsc(conversationId);
// Take last N messages
int skip = Math.max(0, rows.size() - lastN);
return rows.subList(skip, rows.size())
.stream()
.map(this::toMessage)
.toList();
}
@Override
public void clear(String conversationId) {
repository.deleteByConversationId(conversationId);
}
private MessageType toMessageType(Message msg) {
return switch (msg) {
case UserMessage m -> MessageType.USER;
case AssistantMessage m -> MessageType.ASSISTANT;
default -> MessageType.SYSTEM;
};
}
private Message toMessage(ChatMessageEntity entity) {
return switch (entity.getMessageType()) {
case USER -> new UserMessage(entity.getContent());
case ASSISTANT -> new AssistantMessage(entity.getContent());
case SYSTEM -> new SystemMessage(entity.getContent());
};
}
}
Wire JpaChatMemory into ChatClient
/**
 * Chat service whose conversation history is stored in {@link JpaChatMemory}.
 *
 * <p>The {@link ChatClient} is built with a fixed system prompt and a
 * {@link MessageChatMemoryAdvisor} backed by the injected memory; each call selects the
 * conversation through an advisor parameter.
 */
@Service
public class PersistentChatService {

    private final ChatClient chatClient;
    private final JpaChatMemory memory;

    public PersistentChatService(ChatClient.Builder builder, JpaChatMemory memory) {
        this.memory = memory;
        this.chatClient = builder
                .defaultSystem("You are a helpful Java assistant.")
                .defaultAdvisors(new MessageChatMemoryAdvisor(memory))
                .build();
    }

    /**
     * Sends a user message within the given conversation and returns the reply text.
     *
     * @param conversationId conversation whose history the advisor should use
     * @param userMessage    text of the user's message
     * @return content of the model's response
     */
    public String chat(String conversationId, String userMessage) {
        return chatClient.prompt()
                .user(userMessage)
                .advisors(a -> a.param(
                        MessageChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
                        conversationId))
                .call()
                .content();
    }

    /**
     * Returns the full stored history of a conversation.
     *
     * @param conversationId conversation to read
     * @return all stored messages, as returned by the memory for an unbounded window
     */
    public List<Message> getHistory(String conversationId) {
        return memory.get(conversationId, Integer.MAX_VALUE);
    }

    /**
     * Removes the stored history of a conversation.
     *
     * @param conversationId conversation to clear
     */
    public void clearHistory(String conversationId) {
        memory.clear(conversationId);
    }
}
Redis-Backed ChatMemory (Faster for High Volume)
<!-- Spring Boot starter for Spring Data Redis, needed for the RedisTemplate used by
     RedisChatMemory. No <version> is declared in this snippet, so it is presumably
     resolved by the build's dependency management (confirm against your pom). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
@Component
public class RedisChatMemory implements ChatMemory {
private final RedisTemplate<String, String> redis;
private final ObjectMapper objectMapper;
private static final Duration TTL = Duration.ofDays(7);
public RedisChatMemory(RedisTemplate<String, String> redis, ObjectMapper mapper) {
this.redis = redis;
this.objectMapper = mapper;
}
private String key(String conversationId) {
return "chat:messages:" + conversationId;
}
@Override
public void add(String conversationId, List<Message> messages) {
String listKey = key(conversationId);
for (Message msg : messages) {
try {
String json = objectMapper.writeValueAsString(
Map.of("type", msg.getClass().getSimpleName(),
"content", msg.getContent()));
redis.opsForList().rightPush(listKey, json);
} catch (Exception e) {
throw new RuntimeException("Failed to save message to Redis", e);
}
}
redis.expire(listKey, TTL); // reset TTL on each update
}
@Override
public List<Message> get(String conversationId, int lastN) {
String listKey = key(conversationId);
long size = Optional.ofNullable(redis.opsForList().size(listKey)).orElse(0L);
long start = Math.max(0, size - lastN);
List<String> jsonList = redis.opsForList().range(listKey, start, -1);
if (jsonList == null) return List.of();
return jsonList.stream()
.map(json -> {
try {
Map<String, String> map = objectMapper.readValue(json, Map.class);
return switch (map.get("type")) {
case "UserMessage" -> (Message) new UserMessage(map.get("content"));
case "AssistantMessage" -> new AssistantMessage(map.get("content"));
default -> new SystemMessage(map.get("content"));
};
} catch (Exception e) {
return (Message) new SystemMessage("");
}
})
.toList();
}
@Override
public void clear(String conversationId) {
redis.delete(key(conversationId));
}
}
Conversation History API
/**
 * REST endpoints under {@code /chat} for talking to the assistant and managing a
 * session's stored history.
 */
@RestController
@RequestMapping("/chat")
public class ChatController {

    private final PersistentChatService chatService;

    /** Constructor injection; required because {@code chatService} is final. */
    public ChatController(PersistentChatService chatService) {
        this.chatService = chatService;
    }

    /**
     * {@code POST /chat}: sends the message within the given session.
     *
     * @param req session id and user message
     * @return the reply together with the session id it belongs to
     */
    @PostMapping
    public ChatResponse chat(@RequestBody ChatRequest req) {
        String reply = chatService.chat(req.sessionId(), req.message());
        return new ChatResponse(reply, req.sessionId());
    }

    /**
     * {@code GET /chat/{sessionId}/history}: returns the stored messages of a session.
     *
     * @param sessionId session to read
     * @return one entry per message; the role is the simple class name of the message
     */
    @GetMapping("/{sessionId}/history")
    public List<MessageDto> history(@PathVariable String sessionId) {
        return chatService.getHistory(sessionId)
                .stream()
                .map(msg -> new MessageDto(
                        msg.getClass().getSimpleName(),
                        msg.getContent()))
                .toList();
    }

    /**
     * {@code DELETE /chat/{sessionId}}: removes the stored history of a session.
     *
     * @param sessionId session to clear
     */
    @DeleteMapping("/{sessionId}")
    public void clearHistory(@PathVariable String sessionId) {
        chatService.clearHistory(sessionId);
    }
}
/** Request body of {@code POST /chat}: the session to continue and the user's message. */
record ChatRequest(
        String sessionId,
        String message) {
}
/** Response body of {@code POST /chat}: the reply text and the session it belongs to. */
record ChatResponse(
        String reply,
        String sessionId) {
}
/** One history entry returned by the history endpoint: the message's role and its text. */
record MessageDto(
        String role,
        String content) {
}
Output
// POST /chat { "sessionId": "user-123", "message": "What is Spring AI?" }
{ "reply": "Spring AI is a framework...", "sessionId": "user-123" }
// POST /chat { "sessionId": "user-123", "message": "How does it compare to LangChain?" }
// (AI remembers first message — persistent context)
{ "reply": "Compared to what I explained earlier, LangChain is a Python framework...", ... }
// GET /chat/user-123/history
[
{"role": "UserMessage", "content": "What is Spring AI?"},
{"role": "AssistantMessage", "content": "Spring AI is a framework..."},
{"role": "UserMessage", "content": "How does it compare to LangChain?"},
{"role": "AssistantMessage", "content": "Compared to what I explained earlier..."}
]
Key Points
- PostgreSQL-backed memory is the right choice for most applications — it supports complex queries, foreign keys to user tables, and long-term retention policies
- Redis-backed memory is faster (microsecond reads vs millisecond JPA reads) but requires a Redis instance and loses history on TTL expiry
- Always reset the Redis TTL on each add() call so active conversations don't expire mid-session
- Limit get() to the last N messages (e.g., 20) to avoid sending the entire conversation history as tokens — it adds cost and pushes relevant context out of the context window
- Add a scheduled job to delete old conversations: annotate a method with @Scheduled(cron = "0 0 3 * * ?") and have it call repository.deleteByCreatedAtBefore(LocalDateTime.now().minusDays(30))
Comments