Building Streaming AI Responses with Axum and SSE
How we stream Claude and OpenAI responses token-by-token to the browser using Server-Sent Events in Rust.
Non-streaming AI responses feel slow. A 300-token answer takes 2–3 seconds to generate — that's 2–3 seconds of the user staring at a spinner. Streaming delivers tokens as they're generated, making the experience feel instantaneous even for long responses.
Here's how we built streaming in the aiassist.chat backend using Axum and Server-Sent Events — including the full handler, error handling, reconnection, and testing.
Why SSE over WebSockets, and Why Not Polling
There are three common ways to deliver tokens to the client: WebSockets, SSE, and polling.
Polling (repeated GET /response?conversation_id=X) is the worst option. It introduces latency equal to your poll interval, hammers your server with unnecessary requests, and the implementation complexity rivals SSE without any of the benefits. If you're polling for AI responses, migrate.
WebSockets are bidirectional. A chat widget sending one message and receiving one streaming response doesn't need bidirectionality — it's a fundamentally one-way flow after the message is sent. WebSockets also require a connection upgrade, which some corporate proxies and load balancers block or reset. Managing WebSocket connection state across reconnects adds meaningful complexity.
SSE is simpler, works over standard HTTP/2, doesn't require a connection upgrade, and has built-in reconnection semantics via the Last-Event-ID header. For a chat widget embedded on third-party sites — where you have no control over the network infrastructure between you and the visitor — SSE is the right primitive.
SSE isn't just simpler than WebSockets for this use case — it's more reliable in the hostile network environments that embedded widgets encounter.
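On the wire, a streamed reply is just a long-lived text/event-stream response body: each event is a data: line (optionally preceded by an id: line) followed by a blank line, and lines starting with a colon are comments the browser ignores. With the StreamEvent JSON shape defined in the handler below, and the event IDs added in the reconnection section, a response looks roughly like this:

id: 0
data: {"type":"token","text":"Our refund"}

id: 1
data: {"type":"token","text":" policy allows"}

:ping

data: {"type":"done","conversation_id":"0b9f6f4e-..."}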
The Full Axum Handler
Axum's Sse response type pairs naturally with a tokio::sync::mpsc channel: a spawned task pushes events into the sender while the handler returns a stream built from the receiver. Here's the complete handler with authentication, rate limiting, error propagation, and token budget management:
use axum::{
    extract::{Extension, Json},
    response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
#[derive(Deserialize)]
pub struct ChatStreamRequest {
pub message: String,
pub conversation_id: Uuid,
}
// Deserialize is included so the integration tests below can parse events back out of the SSE body.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
Token { text: String },
Done { conversation_id: Uuid },
Error { message: String },
}
pub async fn chat_stream_handler(
Extension(tenant): Extension<AuthenticatedTenant>,
Extension(state): Extension<AppState>,
Json(req): Json<ChatStreamRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, AppError> {
// Check rate limit before allocating the channel
state
.rate_limiter
.check(tenant.id, 60, tenant.plan.requests_per_minute())
.await
.map_err(|_| AppError::RateLimitExceeded)?;
// Check token budget
let remaining_budget = state
.billing
.get_remaining_budget(tenant.id)
.await?;
if remaining_budget == 0 {
return Err(AppError::BudgetExhausted);
}
let (tx, rx) = mpsc::channel::<StreamEvent>(32);
let state = state.clone();
let tenant_id = tenant.id;
tokio::spawn(async move {
let result = run_rag_and_stream(
&state,
tenant_id,
&req.message,
req.conversation_id,
tx.clone(),
)
.await;
if let Err(e) = result {
let _ = tx
.send(StreamEvent::Error {
message: e.to_string(),
})
.await;
}
});
let event_stream = ReceiverStream::new(rx).map(|evt| {
let data = serde_json::to_string(&evt).unwrap_or_default();
Ok(Event::default().data(data))
});
Ok(Sse::new(event_stream).keep_alive(
KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text("ping"),
))
}
The keep_alive with a 15-second interval is essential. Without it, connections behind proxies or load balancers with shorter idle timeouts (AWS ALB defaults to 60 seconds, but some corporate proxies are much shorter) will terminate before a long response finishes. The ping maintains the connection without sending visible content to the client.
Redis-Based Session State for Streaming
For multi-instance deployments, streaming state needs to be coordinated across instances. If a user reconnects to a different backend instance, that instance needs to know the conversation context.
We store active streaming sessions in Redis with a short TTL:
#[derive(Serialize, Deserialize)]
pub struct StreamingSession {
pub conversation_id: Uuid,
pub tenant_id: Uuid,
pub started_at: i64,
pub tokens_emitted: u32,
pub last_event_id: Option<String>,
}
pub async fn create_streaming_session(
redis: &RedisPool,
session: &StreamingSession,
) -> Result<()> {
let key = format!("stream_session:{}", session.conversation_id);
let value = serde_json::to_string(session)?;
// TTL of 10 minutes — longer than any realistic streaming response
redis.set_ex(&key, &value, 600).await?;
Ok(())
}
When the stream completes or errors, we delete the session key. If it expires naturally, the conversation is considered complete. This gives us clean state management without any background cleanup jobs.
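The delete itself is a one-liner; a minimal sketch, assuming the RedisPool wrapper used above exposes a del call:

pub async fn close_streaming_session(
    redis: &RedisPool,
    conversation_id: Uuid,
) -> Result<()> {
    // Called from both the Done and Error paths; the 10-minute TTL is only a
    // safety net for connections that drop before reaching either terminal event.
    let key = format!("stream_session:{}", conversation_id);
    redis.del(&key).await?;
    Ok(())
}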
Reconnection Handling When SSE Drops
The SSE protocol has a built-in reconnection mechanism: the browser automatically reconnects after a short delay (around 3 seconds by default in most browsers, configurable via the retry: field) and sends the Last-Event-ID header with the ID of the last event it received.
To support seamless reconnection, emit event IDs on your token events:
// Emit tokens with sequential IDs
let event = Event::default()
.id(token_index.to_string())
.data(serde_json::to_string(&StreamEvent::Token { text: token })?);
On reconnection, your handler checks the Last-Event-ID and can resume from that point:
pub async fn chat_stream_handler(
headers: HeaderMap,
// ... other extractors
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, AppError> {
let last_event_id: Option<u32> = headers
.get("last-event-id")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok());
// If reconnecting mid-stream, replay tokens from last_event_id + 1
// (stored in your Redis session)
// ...
}
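Replay only works if the tokens themselves are buffered somewhere every instance can read. A minimal sketch, assuming each emitted token is also pushed onto a hypothetical stream_tokens:{conversation_id} Redis list as it's sent, and that the RedisPool wrapper exposes lrange:

async fn replay_missed_tokens(
    redis: &RedisPool,
    conversation_id: Uuid,
    last_event_id: u32,
    tx: &mpsc::Sender<StreamEvent>,
) -> Result<()> {
    let key = format!("stream_tokens:{}", conversation_id);
    // Fetch everything after the last event the client acknowledged.
    let missed: Vec<String> = redis
        .lrange(&key, last_event_id as isize + 1, -1)
        .await?;
    for text in missed {
        // If the receiver is gone, the client dropped again; stop replaying.
        if tx.send(StreamEvent::Token { text }).await.is_err() {
            break;
        }
    }
    Ok(())
}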
In practice, most streaming responses complete in under 10 seconds. Reconnection mid-stream is rare. But implementing it correctly prevents a poor UX where a network hiccup produces a truncated response with no recovery path.
Testing Streaming Endpoints
Streaming endpoints are the test surface most codebases skip. Here's how we test them:
#[tokio::test]
async fn test_chat_stream_emits_tokens_and_done_event() {
let app = build_test_app().await;
let tenant = create_test_tenant(&app).await;
let response = app
.authenticated_request(tenant.api_key)
.post("/chat/stream")
.json(&json!({
"message": "What is your refund policy?",
"conversation_id": Uuid::new_v4(),
}))
.send()
.await
.expect("request failed");
assert_eq!(response.status(), 200);
assert_eq!(
response.headers()["content-type"],
"text/event-stream"
);
// Collect all SSE events
let body = response.text().await.unwrap();
let events: Vec<StreamEvent> = parse_sse_events(&body);
// At least one token event
assert!(events.iter().any(|e| matches!(e, StreamEvent::Token { .. })));
// Exactly one Done event at the end
let done_events: Vec<_> = events
.iter()
.filter(|e| matches!(e, StreamEvent::Done { .. }))
.collect();
assert_eq!(done_events.len(), 1);
assert!(matches!(events.last().unwrap(), StreamEvent::Done { .. }));
}
#[tokio::test]
async fn test_chat_stream_sends_error_event_on_failure() {
let app = build_test_app_with_failing_llm().await;
let tenant = create_test_tenant(&app).await;
let response = app
.authenticated_request(tenant.api_key)
.post("/chat/stream")
// ...
.send()
.await
.unwrap();
let body = response.text().await.unwrap();
let events: Vec<StreamEvent> = parse_sse_events(&body);
// Must terminate with an error event, not a silent truncation
assert!(matches!(events.last().unwrap(), StreamEvent::Error { .. }));
}
The key invariants to test: the stream always terminates (with either Done or Error), it never silently truncates, and error events carry a human-readable message.
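Both tests lean on a parse_sse_events helper that isn't shown above; a minimal sketch is below. It only handles what this backend emits: data: lines, ignoring id: fields and keep-alive comments. (It's also why StreamEvent derives Deserialize.)

fn parse_sse_events(body: &str) -> Vec<StreamEvent> {
    body.lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .filter_map(|json| serde_json::from_str(json.trim()).ok())
        .collect()
}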
Token Budget Management During Streaming
Streaming complicates billing because you don't know the total token count until the stream ends. Our approach:
- Reserve a token budget at stream start based on the max expected response length
- Count actual tokens emitted during the stream
- Commit the actual count at stream end, releasing the unused portion of the reservation
// At stream start: reserve max budget
let reservation = billing.reserve_tokens(tenant_id, MAX_RESPONSE_TOKENS).await?;
// During stream: count tokens emitted
let mut tokens_emitted = 0u32;
while let Some(token) = llm_stream.next().await {
    tokens_emitted += token.len() as u32; // byte length as a rough proxy; use a real tokenizer (e.g. tiktoken) for accuracy
tx.send(StreamEvent::Token { text: token }).await?;
}
// At stream end: commit actual usage
billing.commit_tokens(tenant_id, reservation.id, tokens_emitted).await?;
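// After committing usage, emit the terminal Done event (conversation_id is the
// value passed into run_rag_and_stream) so the stream never just ends silently.
let _ = tx.send(StreamEvent::Done { conversation_id }).await;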
This prevents over-billing on short responses while still enforcing budget limits. The reservation ensures a tenant can't start 50 concurrent streams that together would exceed their budget — the reservation checks against remaining budget atomically.
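The billing store isn't shown in this post, and the atomic check-and-reserve is the part worth getting right. One way to do it (a sketch only, assuming budgets live in Postgres via sqlx, with hypothetical token_budgets and token_reservations tables, and AppError: From<sqlx::Error>):

pub async fn reserve_tokens(
    pool: &sqlx::PgPool,
    tenant_id: Uuid,
    amount: i64,
) -> Result<Uuid, AppError> {
    let reservation_id = Uuid::new_v4();
    // A single conditional UPDATE both checks the remaining budget and takes
    // the reservation, so 50 concurrent streams can't each pass a separate
    // read-then-write check.
    let rows = sqlx::query(
        "UPDATE token_budgets
         SET reserved = reserved + $2
         WHERE tenant_id = $1 AND remaining - reserved >= $2",
    )
    .bind(tenant_id)
    .bind(amount)
    .execute(pool)
    .await?
    .rows_affected();
    if rows == 0 {
        return Err(AppError::BudgetExhausted);
    }
    // Record the reservation so commit_tokens can later convert it into actual
    // usage and release whatever wasn't consumed.
    sqlx::query("INSERT INTO token_reservations (id, tenant_id, amount) VALUES ($1, $2, $3)")
        .bind(reservation_id)
        .bind(tenant_id)
        .bind(amount)
        .execute(pool)
        .await?;
    Ok(reservation_id)
}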