Java SpringAI

Spring AI LLM Gateway Pattern — Intelligent Model Routing and Fallback

Spring AI LLM Gateway Pattern — Intelligent Model Routing and Fallback

As AI usage grows, relying on a single LLM provider creates both a single point of failure and a source of unnecessary cost. The LLM Gateway pattern centralizes all AI calls behind a routing layer that selects the best model based on task type, cost, latency requirements, and provider availability. This tutorial builds a production-ready LLM gateway with Spring AI.

Why You Need an LLM Gateway

Without Gateway:
  Services directly call providers → scattered config → no fallback
  Team A: always uses gpt-4o (expensive)
  Team B: hardcoded claude-sonnet
  Team C: no fallback when OpenAI is down

With Gateway:
  Smart routing: "classify" → gpt-4o-mini ($0.15/M)
  Smart routing: "legal-review" → claude-opus ($15/M)
  Fallback: OpenAI down → automatically routes to Gemini
  Rate limit breached → queue request + notify caller
  Budget exceeded → switch to cheaper model automatically

Gateway Configuration

/**
 * Gateway settings bound from properties under the {@code ai.gateway} prefix.
 *
 * <p>Getters and setters are provided so that JavaBean-style property binding can
 * populate the fields and so that collaborators can read the routes and the
 * fallback order.
 */
@ConfigurationProperties(prefix = "ai.gateway")
@Component
public class GatewayConfig {

    /** Route name (for example {@code fast} or {@code quality}) mapped to its settings. */
    private Map<String, RouteConfig> routes = new HashMap<>();

    /** Provider names in the order they are tried when a call fails. */
    private List<String> fallbackOrder = List.of("openai", "anthropic", "gemini");

    /** Spending threshold in USD; defaults to 100. */
    private int budgetAlertThresholdUsd = 100;

    /** @return the configured routes keyed by route name */
    public Map<String, RouteConfig> getRoutes() {
        return routes;
    }

    /** @param routes the routes keyed by route name */
    public void setRoutes(Map<String, RouteConfig> routes) {
        this.routes = routes;
    }

    /** @return provider names in fallback order */
    public List<String> getFallbackOrder() {
        return fallbackOrder;
    }

    /** @param fallbackOrder provider names in the order they should be tried */
    public void setFallbackOrder(List<String> fallbackOrder) {
        this.fallbackOrder = fallbackOrder;
    }

    /** @return the budget alert threshold in USD */
    public int getBudgetAlertThresholdUsd() {
        return budgetAlertThresholdUsd;
    }

    /** @param budgetAlertThresholdUsd the budget alert threshold in USD */
    public void setBudgetAlertThresholdUsd(int budgetAlertThresholdUsd) {
        this.budgetAlertThresholdUsd = budgetAlertThresholdUsd;
    }

    /**
     * Settings for one named route.
     *
     * @param provider         name of the primary provider
     * @param model            model to request from the primary provider
     * @param fallbackProvider name of the route-specific fallback provider; may be absent
     * @param fallbackModel    model to request from the fallback provider; may be absent
     * @param maxTokens        token limit for the route
     * @param temperature      sampling temperature for the route
     */
    public record RouteConfig(
            String provider,
            String model,
            String fallbackProvider,
            String fallbackModel,
            int maxTokens,
            double temperature
    ) {}
}
# application.properties
# Route definitions under the ai.gateway.routes.<route-name>.* prefix.

# "fast" route: OpenAI gpt-4o-mini, falling back to Anthropic claude-haiku-4-5.
ai.gateway.routes.fast.provider=openai
ai.gateway.routes.fast.model=gpt-4o-mini
ai.gateway.routes.fast.fallback-provider=anthropic
ai.gateway.routes.fast.fallback-model=claude-haiku-4-5

# "quality" route: Anthropic claude-sonnet-4-6, falling back to OpenAI gpt-4o.
ai.gateway.routes.quality.provider=anthropic
ai.gateway.routes.quality.model=claude-sonnet-4-6
ai.gateway.routes.quality.fallback-provider=openai
ai.gateway.routes.quality.fallback-model=gpt-4o

# "code" route: Anthropic claude-opus-4-8, falling back to OpenAI gpt-4o.
ai.gateway.routes.code.provider=anthropic
ai.gateway.routes.code.model=claude-opus-4-8
ai.gateway.routes.code.fallback-provider=openai
ai.gateway.routes.code.fallback-model=gpt-4o

# "classification" route: OpenAI gpt-4o-mini.
# NOTE: no fallback-provider / fallback-model is set for this route.
ai.gateway.routes.classification.provider=openai
ai.gateway.routes.classification.model=gpt-4o-mini

Multi-Provider ChatClient Configuration

@Configuration
public class MultiProviderConfig {

    @Bean("openaiClient")
    public ChatClient openaiClient(
            @Qualifier("openAiChatModel") ChatModel openAiModel) {
        return ChatClient.builder(openAiModel).build();
    }

    @Bean("anthropicClient")
    public ChatClient anthropicClient(
            @Qualifier("anthropicChatModel") ChatModel anthropicModel) {
        return ChatClient.builder(anthropicModel).build();
    }

    @Bean("geminiClient")
    public ChatClient geminiClient(
            @Qualifier("vertexAiGeminiChatModel") ChatModel geminiModel) {
        return ChatClient.builder(geminiModel).build();
    }
}

LLM Gateway Service

/**
 * Routes chat requests to an LLM provider selected by route name, switching to
 * other providers when the preferred one is unhealthy, over budget, or fails.
 *
 * <p>Each provider is attempted at most once per request: the starting provider
 * first, then every other healthy provider in the configured fallback order.
 */
@Service
public class LlmGateway {

    /** Route used when the requested route name is not configured. */
    private static final String DEFAULT_ROUTE = "quality";

    /**
     * Model label used for providers that have no entry in
     * {@link #DEFAULT_FALLBACK_MODELS}; {@link #buildOptions} sets no model for
     * such providers, so the label only appears in the response and budget record.
     */
    private static final String PROVIDER_DEFAULT_MODEL = "default";

    /** Model requested from a provider reached through the global fallback order. */
    private static final Map<String, String> DEFAULT_FALLBACK_MODELS = Map.of(
            "openai",    "gpt-4o-mini",
            "anthropic", "claude-haiku-4-5");

    private final Map<String, ChatClient> clients;
    private final GatewayConfig           config;
    private final ProviderHealthMonitor   healthMonitor;
    private final BudgetTracker           budgetTracker;

    public LlmGateway(
            @Qualifier("openaiClient")    ChatClient openai,
            @Qualifier("anthropicClient") ChatClient anthropic,
            @Qualifier("geminiClient")    ChatClient gemini,
            GatewayConfig config,
            ProviderHealthMonitor healthMonitor,
            BudgetTracker budgetTracker) {
        this.clients = Map.of(
                "openai",    openai,
                "anthropic", anthropic,
                "gemini",    gemini);
        this.config        = config;
        this.healthMonitor = healthMonitor;
        this.budgetTracker = budgetTracker;
    }

    /**
     * Sends a prompt through the named route.
     *
     * @param routeName    configured route name; unknown names use the {@code quality} route
     * @param systemPrompt system message for the model
     * @param userMessage  user message for the model
     * @return the response of the first provider that answered
     * @throws IllegalArgumentException if neither the named route nor the default route exists
     * @throws IllegalStateException if the resolved route has no provider
     * @throws AllProvidersUnavailableException if every candidate provider failed or was skipped
     */
    public GatewayResponse route(String routeName, String systemPrompt, String userMessage) {
        GatewayConfig.RouteConfig route = config.getRoutes().get(routeName);
        if (route == null) {
            route = config.getRoutes().get(DEFAULT_ROUTE);  // default route
        }
        if (route == null) {
            throw new IllegalArgumentException("Unknown route '" + routeName
                    + "' and no default route '" + DEFAULT_ROUTE + "' is configured");
        }

        String primary = route.provider();
        if (primary == null) {
            throw new IllegalStateException("Route '" + routeName + "' has no provider configured");
        }

        // 1. Use the primary provider when it is healthy and within budget.
        if (healthMonitor.isHealthy(primary) && !budgetTracker.isOverBudget(primary)) {
            return callProvider(primary, route.model(), route, systemPrompt, userMessage, routeName);
        }

        // 2. Otherwise prefer the route's own fallback, if one is configured.
        String routeFallback = route.fallbackProvider();
        if (routeFallback != null) {
            System.out.printf("Primary %s unavailable, using fallback %s%n",
                    primary, routeFallback);
            return callProvider(routeFallback, fallbackModelFor(routeFallback, route),
                    route, systemPrompt, userMessage, routeName);
        }

        // 3. No route-level fallback: go straight to the global fallback order.
        return tryFallback(primary, route, systemPrompt, userMessage, routeName,
                new IllegalStateException("Provider " + primary + " is unhealthy or over budget"));
    }

    /** Attempts the given provider and, if it fails, walks the global fallback order. */
    private GatewayResponse callProvider(String provider, String model,
                                         GatewayConfig.RouteConfig routeConfig,
                                         String system, String user, String route) {
        try {
            return invoke(provider, model, system, user, route);
        } catch (Exception e) {
            healthMonitor.recordFailure(provider, e);
            System.out.printf("Provider %s failed: %s%n", provider, e.getMessage());

            // Try next provider in fallback chain
            return tryFallback(provider, routeConfig, system, user, route, e);
        }
    }

    /**
     * Tries every healthy provider in the fallback order except the one that just
     * failed. Each candidate gets a single attempt, so the chain always terminates.
     */
    private GatewayResponse tryFallback(String failedProvider,
                                        GatewayConfig.RouteConfig routeConfig,
                                        String system, String user, String route, Exception cause) {
        for (String fallbackProvider : config.getFallbackOrder()) {
            if (fallbackProvider.equals(failedProvider)) continue;
            if (!healthMonitor.isHealthy(fallbackProvider)) continue;

            System.out.printf("Trying fallback provider: %s%n", fallbackProvider);
            String fallbackModel = fallbackModelFor(fallbackProvider, routeConfig);
            try {
                return invoke(fallbackProvider, fallbackModel, system, user, route);
            } catch (Exception e2) {
                healthMonitor.recordFailure(fallbackProvider, e2);
                System.out.printf("Fallback %s also failed%n", fallbackProvider);
                if (e2 != cause) {
                    cause.addSuppressed(e2);  // keep later failures visible on the first one
                }
            }
        }
        throw new AllProvidersUnavailableException("All AI providers failed", cause);
    }

    /**
     * Performs exactly one call against one provider and records usage and success.
     * Failures propagate to the caller, which decides whether to fall back.
     */
    private GatewayResponse invoke(String provider, String model,
                                   String system, String user, String route) {
        // Map.of(...) rejects null keys, so guard before the lookup.
        ChatClient client = provider == null ? null : clients.get(provider);
        if (client == null) {
            throw new IllegalStateException("No ChatClient registered for provider '" + provider + "'");
        }

        long start = System.nanoTime();
        ChatResponse response = client.prompt()
                .system(system)
                .user(user)
                .options(buildOptions(provider, model))
                .call()
                .chatResponse();

        String content = response.getResult().getOutput().getContent();
        Usage usage = response.getMetadata().getUsage();
        long latency = (System.nanoTime() - start) / 1_000_000;

        budgetTracker.record(provider, model, usage.getPromptTokens(),
                usage.getGenerationTokens());
        healthMonitor.recordSuccess(provider, latency);

        return new GatewayResponse(content, provider, model, route, latency, true, null);
    }

    /**
     * Picks the model for a fallback provider: the route's configured fallback model
     * when the provider is the route's fallback provider, otherwise the provider's
     * entry in {@link #DEFAULT_FALLBACK_MODELS}.
     */
    private String fallbackModelFor(String provider, GatewayConfig.RouteConfig routeConfig) {
        if (provider.equals(routeConfig.fallbackProvider()) && routeConfig.fallbackModel() != null) {
            return routeConfig.fallbackModel();
        }
        return DEFAULT_FALLBACK_MODELS.getOrDefault(provider, PROVIDER_DEFAULT_MODEL);
    }

    /** Builds provider-specific options; providers other than openai/anthropic get no model set. */
    private ChatOptions buildOptions(String provider, String model) {
        return switch (provider) {
            case "openai" -> OpenAiChatOptions.builder().withModel(model).build();
            case "anthropic" -> AnthropicChatOptions.builder().withModel(model).build();
            default -> ChatOptionsBuilder.builder().build();
        };
    }
}

/**
 * Outcome of one request sent through the gateway.
 *
 * @param content   text returned by the model
 * @param provider  name of the provider that produced the response
 * @param model     model name passed to that provider
 * @param route     route name the caller requested
 * @param latencyMs elapsed time of the provider call in milliseconds
 * @param success   whether the call succeeded
 * @param error     error description; presumably {@code null} on success
 */
public record GatewayResponse(
        String content,
        String provider,
        String model,
        String route,
        long latencyMs,
        boolean success,
        String error) {
}

Provider Health Monitor

/**
 * Minimal per-provider circuit breaker.
 *
 * <p>A provider's circuit opens once {@code ERROR_THRESHOLD_PER_MINUTE} failures
 * have been recorded without an intervening success, and stays open until
 * {@code CIRCUIT_OPEN_MILLIS} have passed since the most recent failure. Timing
 * the cool-down from the last failure (not the last success) means a provider
 * that has never succeeded can still be taken out of rotation.
 */
@Component
public class ProviderHealthMonitor {

    private static final int  ERROR_THRESHOLD_PER_MINUTE = 5;
    private static final long CIRCUIT_OPEN_MILLIS        = 60_000;

    /** Failures recorded per provider since its last success or circuit reset. */
    private final Map<String, AtomicInteger> errorCounts = new ConcurrentHashMap<>();

    /** Wall-clock time in milliseconds of each provider's most recent failure. */
    private final Map<String, AtomicLong>    lastFailure = new ConcurrentHashMap<>();

    /**
     * Reports whether calls may currently be sent to the provider.
     *
     * @param provider provider name
     * @return {@code false} while the provider's circuit is open, {@code true} otherwise
     */
    public boolean isHealthy(String provider) {
        AtomicInteger errors = errorCounts.get(provider);
        if (errors == null || errors.get() < ERROR_THRESHOLD_PER_MINUTE) {
            return true;
        }

        AtomicLong failedAt = lastFailure.get(provider);
        long lastFailureMs = failedAt == null ? 0L : failedAt.get();
        if (System.currentTimeMillis() - lastFailureMs < CIRCUIT_OPEN_MILLIS) {
            return false;  // circuit open
        }

        errors.set(0);  // cool-down elapsed: close the circuit and allow traffic again
        return true;
    }

    /**
     * Records a successful call and clears the provider's failure count.
     *
     * @param provider  provider name
     * @param latencyMs latency of the successful call; currently not used
     */
    public void recordSuccess(String provider, long latencyMs) {
        errorCounts.computeIfAbsent(provider, k -> new AtomicInteger()).set(0);
    }

    /**
     * Records a failed call.
     *
     * @param provider provider name
     * @param e        the failure; currently not inspected
     */
    public void recordFailure(String provider, Exception e) {
        // Stamp the time before counting so the timestamp is already in place
        // when the count reaches the threshold.
        lastFailure.computeIfAbsent(provider, k -> new AtomicLong())
                   .set(System.currentTimeMillis());
        errorCounts.computeIfAbsent(provider, k -> new AtomicInteger()).incrementAndGet();
    }
}

Output

// Normal routing
gateway.route("fast", "Answer briefly.", "What is Java?")
→ Provider: openai, Model: gpt-4o-mini, Latency: 423ms

// OpenAI is down
gateway.route("fast", "Answer briefly.", "What is Java?")
→ Primary openai unavailable, using fallback anthropic
→ Provider: anthropic, Model: claude-haiku-4-5, Latency: 611ms

// Code review task
gateway.route("code", "Review this code.", code)
→ Provider: anthropic, Model: claude-opus-4-8, Latency: 3201ms

Key Points

  • The Gateway pattern decouples application code from specific LLM providers — services call gateway.route("quality", ...), never specific provider APIs
  • Circuit breakers at the gateway level prevent cascading failures when a provider has an outage
  • Budget tracking per provider lets you automatically downgrade to cheaper models when spending thresholds are hit
  • Route names like "fast", "quality", and "code" give product teams a semantic API without exposing model details
  • Run the health monitor's error count reset on a schedule (@Scheduled) rather than time-since-last-success for more predictable circuit breaker behavior
Topics: Java SpringAI
← Newer Post Older Post →