
Spring AI with Apache Kafka — Async AI Processing at Scale

A direct AI API call blocks the HTTP request thread for as long as the model takes to answer, typically 2-30 seconds. For high-throughput applications that process thousands of documents, user events, or content moderation requests, hand the work to a message queue and process it asynchronously instead. This tutorial shows how to build a Kafka-based AI pipeline with Spring AI and Spring Kafka.

Architecture — Async AI Processing

HTTP Request
     │
     ▼
[Producer Service]  ──→  Kafka Topic: ai.requests
                                  │
                         [AI Processor Consumer]
                                  │
                          Spring AI ChatClient
                                  │
                         Kafka Topic: ai.responses
                                  │
                    [Notification Service / WebSocket push]

Maven Dependencies

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.kafka</groupId>
    <artifactId>spring-kafka</artifactId>
</dependency>

application.properties

spring.ai.openai.api-key=${OPENAI_API_KEY}

# Kafka
spring.kafka.bootstrap-servers=localhost:9092
spring.kafka.consumer.group-id=ai-processors
spring.kafka.consumer.auto-offset-reset=earliest
spring.kafka.consumer.key-deserializer=org.apache.kafka.common.serialization.StringDeserializer
spring.kafka.consumer.value-deserializer=org.springframework.kafka.support.serializer.JsonDeserializer
spring.kafka.producer.key-serializer=org.apache.kafka.common.serialization.StringSerializer
spring.kafka.producer.value-serializer=org.springframework.kafka.support.serializer.JsonSerializer
# Trusting every package is convenient for a demo; list your DTO package explicitly in production
spring.kafka.consumer.properties.spring.json.trusted.packages=*

Message DTOs

public record AiRequest(
        String requestId,
        String userId,
        String taskType,    // "summarize", "classify", "generate", "analyze"
        String content,
        Map<String, String> metadata
) {}

public record AiResponse(
        String requestId,
        String userId,
        String taskType,
        String result,
        boolean success,
        String errorMessage,
        long processingTimeMs
) {}
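
Because these records travel through Kafka as JSON, a malformed message only shows up when the consumer tries to use it. One option is to validate in the record's compact constructor, so an invalid request fails at construction time. The sketch below is a hypothetical variant (ValidatedAiRequest is not part of the tutorial code). It assumes you want to reject task types outside the four listed in the comment above, which differs from the consumer's default branch that passes unknown types to a generic prompt.

```java
import java.util.Map;
import java.util.Objects;
import java.util.Set;

// Hypothetical variant of AiRequest that validates itself on construction.
record ValidatedAiRequest(
        String requestId,
        String userId,
        String taskType,
        String content,
        Map<String, String> metadata
) {
    // Task types taken from the comment on AiRequest.taskType.
    static final Set<String> KNOWN_TASK_TYPES =
            Set.of("summarize", "classify", "generate", "analyze");

    // Compact constructor: runs before the fields are assigned.
    ValidatedAiRequest {
        Objects.requireNonNull(requestId, "requestId");
        Objects.requireNonNull(userId, "userId");
        Objects.requireNonNull(content, "content");
        if (taskType == null || !KNOWN_TASK_TYPES.contains(taskType)) {
            throw new IllegalArgumentException("Unknown taskType: " + taskType);
        }
        // Defensive, immutable copy; a missing map becomes an empty one.
        metadata = (metadata == null) ? Map.of() : Map.copyOf(metadata);
    }
}
```

Copying the metadata map also means later changes to the caller's map cannot alter a request that has already been queued.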

Producer — Submit AI Tasks

@Service
public class AiTaskProducer {

    private final KafkaTemplate<String, AiRequest> kafkaTemplate;

    public AiTaskProducer(KafkaTemplate<String, AiRequest> kafkaTemplate) {
        this.kafkaTemplate = kafkaTemplate;
    }

    public String submitTask(String userId, String taskType, String content) {
        String requestId = UUID.randomUUID().toString();

        AiRequest request = new AiRequest(
                requestId, userId, taskType, content,
                Map.of("submittedAt", Instant.now().toString()));

        kafkaTemplate.send("ai.requests", requestId, request);
        System.out.println("Submitted AI task: " + requestId);

        return requestId;  // return immediately — task is async
    }
}
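
submitTask ignores the value returned by kafkaTemplate.send, so a broker failure goes unnoticed and the "Submitted" line is printed either way. In Spring Kafka 3.x, send returns a CompletableFuture (a ListenableFuture in 2.x), and a callback can record the outcome. The sketch below shows that pattern with plain JDK types so it runs without a broker. TrackedSubmitter, Submission, and the sender function are hypothetical stand-ins for the template, not Spring Kafka API.

```java
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;

// Sketch: react to the outcome of an asynchronous send instead of ignoring it.
class TrackedSubmitter {

    // Hypothetical stand-in for kafkaTemplate.send(topic, key, value): takes
    // (key, payload) and returns a future that completes once the send is acknowledged.
    private final BiFunction<String, String, CompletableFuture<Void>> sender;

    TrackedSubmitter(BiFunction<String, String, CompletableFuture<Void>> sender) {
        this.sender = sender;
    }

    // Returns the request ID right away; the future inside reports the send outcome.
    Submission submit(String content) {
        String requestId = UUID.randomUUID().toString();
        CompletableFuture<Boolean> acknowledged = sender.apply(requestId, content)
                .handle((ignored, error) -> {
                    if (error != null) {
                        System.err.println("Send failed for " + requestId + ": " + error.getMessage());
                        return false;
                    }
                    System.out.println("Submitted AI task: " + requestId);
                    return true;
                });
        return new Submission(requestId, acknowledged);
    }

    record Submission(String requestId, CompletableFuture<Boolean> acknowledged) {}
}
```

The HTTP handler can still return the request ID immediately. The callback is the place to mark the task as failed so that a later status poll does not wait forever.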

Consumer — Process AI Tasks

@Service
public class AiTaskConsumer {

    private final ChatClient chatClient;
    private final KafkaTemplate<String, AiResponse> responseTemplate;

    public AiTaskConsumer(ChatClient.Builder builder,
                          KafkaTemplate<String, AiResponse> responseTemplate) {
        this.chatClient       = builder.build();
        this.responseTemplate = responseTemplate;
    }

    @KafkaListener(topics = "ai.requests", concurrency = "5")  // 5 parallel consumers
    public void processAiRequest(AiRequest request) {
        long start = System.currentTimeMillis();
        System.out.println("Processing: " + request.requestId() +
                " [" + request.taskType() + "]");

        try {
            String result = processTask(request);

            AiResponse response = new AiResponse(
                    request.requestId(), request.userId(),
                    request.taskType(), result, true, null,
                    System.currentTimeMillis() - start);

            responseTemplate.send("ai.responses", request.requestId(), response);

        } catch (Exception e) {
            AiResponse errorResponse = new AiResponse(
                    request.requestId(), request.userId(),
                    request.taskType(), null, false, e.getMessage(),
                    System.currentTimeMillis() - start);

            responseTemplate.send("ai.responses", request.requestId(), errorResponse);
        }
    }

    private String processTask(AiRequest request) {
        return switch (request.taskType()) {
            case "summarize" -> summarize(request.content());
            case "classify"  -> classify(request.content());
            case "analyze"   -> analyze(request.content());
            default          -> chatClient.prompt().user(request.content()).call().content();
        };
    }

    private String summarize(String content) {
        return chatClient.prompt()
                .system("Summarize the following content in 3 bullet points.")
                .user(content)
                .call()
                .content();
    }

    private String classify(String content) {
        return chatClient.prompt()
                .system("Classify this content. Output ONLY one label: positive/negative/neutral/spam")
                .user(content)
                .call()
                .content();
    }

    private String analyze(String content) {
        return chatClient.prompt()
                .system("Analyze this Java code for bugs, security issues, and improvements.")
                .user(content)
                .call()
                .content();
    }
}
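
The catch block in processAiRequest turns every failure into an error response, including transient ones such as a rate-limit rejection or a network timeout. A small retry wrapper around the model call is one way to absorb those before giving up. This is a minimal sketch using only the JDK. RetryingCall and its parameters are hypothetical, and it retries on any RuntimeException because the exception types your AI client throws for transient errors are not specified here.

```java
import java.util.function.Supplier;

// Sketch: retry a blocking call a few times with exponential backoff.
final class RetryingCall {

    private RetryingCall() {}

    static <T> T withRetry(Supplier<T> call, int maxAttempts, long initialDelayMs) {
        if (maxAttempts < 1) {
            throw new IllegalArgumentException("maxAttempts must be >= 1");
        }
        long delay = initialDelayMs;
        RuntimeException last = null;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                return call.get();
            } catch (RuntimeException e) {
                last = e;
                if (attempt == maxAttempts) {
                    break;  // out of attempts: rethrow below
                }
                try {
                    Thread.sleep(delay);
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();  // preserve the interrupt flag
                    throw new IllegalStateException("Interrupted while waiting to retry", ie);
                }
                delay *= 2;  // double the wait before the next attempt
            }
        }
        throw last;  // the exception from the final attempt
    }
}
```

In the consumer this could wrap the existing call, for example `RetryingCall.withRetry(() -> processTask(request), 3, 500)`. Keep the total wait well below the consumer's max.poll.interval.ms (5 minutes by default), otherwise the broker treats the consumer as failed and rebalances the group.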

Response Consumer + Status Tracker

@Service
public class AiResponseConsumer {

    // In-memory status map (use Redis for multi-instance deployments)
    private final Map<String, AiResponse> completedTasks = new ConcurrentHashMap<>();

    @KafkaListener(topics = "ai.responses")
    public void handleResponse(AiResponse response) {
        completedTasks.put(response.requestId(), response);
        System.out.printf("Completed %s in %dms%n",
                response.requestId(), response.processingTimeMs());
    }

    public Optional<AiResponse> getResult(String requestId) {
        return Optional.ofNullable(completedTasks.get(requestId));
    }
}
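
The completedTasks map never removes entries, so it grows for as long as the service runs, and a redelivered response silently overwrites the earlier one. The sketch below shows one way to handle both with JDK types only: it keeps the first result per request ID and drops entries older than a time-to-live. ExpiringResultStore is a hypothetical helper, and the current time is passed in explicitly to keep the example deterministic.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Sketch: an in-memory result store that ignores duplicates and forgets old entries.
final class ExpiringResultStore<V> {

    private record Entry<T>(T value, Instant storedAt) {}

    private final Map<String, Entry<V>> entries = new ConcurrentHashMap<>();
    private final Duration ttl;

    ExpiringResultStore(Duration ttl) {
        this.ttl = ttl;
    }

    // Stores the first result seen for a request ID; returns false for a duplicate.
    boolean putIfAbsent(String requestId, V value, Instant now) {
        return entries.putIfAbsent(requestId, new Entry<>(value, now)) == null;
    }

    // Returns the result unless it is missing or older than the TTL.
    Optional<V> get(String requestId, Instant now) {
        Entry<V> entry = entries.get(requestId);
        if (entry == null || isExpired(entry, now)) {
            return Optional.empty();
        }
        return Optional.of(entry.value());
    }

    // Removes expired entries; call periodically, for example from a scheduled task.
    void evictExpired(Instant now) {
        entries.values().removeIf(e -> isExpired(e, now));
    }

    int size() {
        return entries.size();
    }

    private boolean isExpired(Entry<V> entry, Instant now) {
        return entry.storedAt().plus(ttl).isBefore(now);
    }
}
```

In handleResponse this would replace the plain put, passing Instant.now() as the timestamp. A store such as Redis offers the same two behaviours through set-if-absent and key expiry.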

REST Controller — Submit and Poll

@RestController
@RequestMapping("/ai-tasks")
public class AiTaskController {

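    private final AiTaskProducer producer;
    private final AiResponseConsumer consumer;

    // Constructor injection: the final fields must be assigned for the class to compile
    public AiTaskController(AiTaskProducer producer, AiResponseConsumer consumer) {
        this.producer = producer;
        this.consumer = consumer;
    }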

    @PostMapping
    public Map<String, String> submit(@RequestBody TaskRequest req) {
        String requestId = producer.submitTask(
                req.userId(), req.taskType(), req.content());
        return Map.of("requestId", requestId, "status", "queued");
    }

    @GetMapping("/{requestId}")
    public ResponseEntity<?> getResult(@PathVariable String requestId) {
        return consumer.getResult(requestId)
                .map(ResponseEntity::ok)
                .orElse(ResponseEntity.accepted().build());  // 202 = still processing
    }
}

record TaskRequest(String userId, String taskType, String content) {}
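
The GET endpoint answers 202 for any ID that is not in the completed map, so a mistyped or never-submitted ID looks like a task that is still running. Recording IDs at submission time lets the poll tell the two cases apart. The sketch below is a JDK-only illustration. TaskStatusRegistry and its status-to-HTTP mapping are assumptions, not part of the tutorial code.

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Sketch: track submitted IDs so a poll can tell "still processing" from "never submitted".
final class TaskStatusRegistry {

    enum Status {
        QUEUED(202), DONE(200), UNKNOWN(404);

        final int httpStatus;

        Status(int httpStatus) {
            this.httpStatus = httpStatus;
        }
    }

    private final Set<String> submitted = ConcurrentHashMap.newKeySet();
    private final Map<String, String> results = new ConcurrentHashMap<>();

    // Call when the request is handed to the producer.
    void markSubmitted(String requestId) {
        submitted.add(requestId);
    }

    // Call when the response listener receives the result.
    void markDone(String requestId, String result) {
        results.put(requestId, result);
    }

    Status statusOf(String requestId) {
        if (results.containsKey(requestId)) {
            return Status.DONE;
        }
        if (submitted.contains(requestId)) {
            return Status.QUEUED;
        }
        return Status.UNKNOWN;
    }
}
```

As with the completed-task map, this state has to live in shared storage once more than one instance serves the endpoint.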

Output

// POST /ai-tasks { "userId": "u1", "taskType": "summarize", "content": "Spring AI is..." }
{ "requestId": "abc-123", "status": "queued" }    ← responds immediately, before the model is called

// GET /ai-tasks/abc-123  (immediately)
HTTP 202 Accepted                                  ← still processing

// GET /ai-tasks/abc-123  (after ~3 seconds)
{
  "requestId": "abc-123",
  "userId": "u1",
  "taskType": "summarize",
  "result": "• Spring AI provides unified API across LLM providers\n• RAG support built-in\n• Integrates with Spring Boot ecosystem",
  "success": true,
  "errorMessage": null,
  "processingTimeMs": 2847
}

Consumer logs:
Processing: abc-123 [summarize]
Completed abc-123 in 2847ms

Key Points

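  • Set concurrency = "5" on @KafkaListener to process up to 5 AI requests in parallel per application instance. This only helps if ai.requests has at least 5 partitions, because each partition is read by a single consumer in the group. Tune the value to your AI API rate limits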
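  • Always send a response to the response topic for both success and failure, so every request eventually gets a terminal result. Kafka delivers at least once by default, so a response can occasionally arrive twice; treat requestId as an idempotency key downstream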
  • For multi-instance deployments, replace the in-memory completedTasks map with Redis so any instance can answer the status poll
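  • Add dead-letter topics (ai.requests.DLT) to capture messages that fail after all retry attempts, and inspect them to detect systematic failures. The listener above catches every exception and reports it on ai.responses, so a record reaches the DLT only if an exception escapes the listener method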
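  • Keying messages by requestId sends every record with the same key to the same partition of a topic, which preserves their order within that topic. It does not by itself route a request and its response to the same consumer instance, because ai.requests and ai.responses are read by different listeners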
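
To make the partition-key point concrete, the sketch below maps a key to a partition number the way a hash partitioner does: hash the key bytes, then take the result modulo the partition count. It is a simplified stand-in. Kafka's default partitioner uses a murmur2 hash of the serialized key, while this example uses Arrays.hashCode only to stay short, and KeyPartitioning is a hypothetical helper.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Sketch: how a record key is mapped to a partition number.
final class KeyPartitioning {

    private KeyPartitioning() {}

    // Simplified stand-in for Kafka's default partitioner (which uses murmur2):
    // the same key and partition count always give the same partition.
    static int partitionFor(String key, int numPartitions) {
        byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8);
        int hash = Arrays.hashCode(keyBytes);
        return Math.floorMod(hash, numPartitions);  // always in [0, numPartitions)
    }
}
```

The result depends on the partition count, so the same requestId can land on different partition numbers in ai.requests and ai.responses if the two topics are sized differently.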