AdaptJax™ Technology¶
The Translation Layer That Unlocks Google's TPU Supercomputers
AdaptJax™ is Adaptensor's proprietary JAX-native embedding engine that runs directly on Google Cloud TPUs without TensorFlow or PyTorch dependencies.
The Problem: The Static Shape Barrier¶
Google's TPUs deliver exceptional performance, 180+ TFLOPS per v2-8 device, but they're notoriously difficult to use. TPUs require (a short recompilation sketch follows this list):
- Static tensor shapes for XLA compilation
- Fixed batch sizes determined at compile time
- Predictable control flow without dynamic branching
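To make the static-shape requirement concrete: jax.jit lowers code through XLA, the same compiler the TPU requires, and every previously unseen input shape triggers a fresh compilation. A minimal sketch (the toy encode function and the timing loop are illustrative, not part of AdaptJax):

import time

import jax
import jax.numpy as jnp

@jax.jit
def encode(x):
    # Toy stand-in for a model forward pass
    return jnp.tanh(x @ x.T)

for seq_len in (100, 101, 102):
    x = jnp.ones((seq_len, 64))
    start = time.perf_counter()
    encode(x).block_until_ready()   # each new shape triggers a fresh XLA compilation
    print(seq_len, f"{time.perf_counter() - start:.3f}s")

On a toy function the recompiles cost milliseconds; on a large model targeting a TPU they become the 30-60 second stalls described below.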
Most AI workloads (PyTorch models, JAX pipelines, real-world inference traffic) are inherently dynamic:
- Variable-length sequences (text, audio, time series)
- Changing batch sizes per request
- Conditional model execution
- Ragged inputs from production systems
This mismatch causes:
| Problem | Impact |
|---|---|
| Compilation stalls | 30-60 second delays mid-inference |
| Shape explosions | Memory exhaustion from cached graphs |
| Latency spikes | Unbounded response times |
| TPU crashes | "Shape mismatch" errors |
Result: The world's most powerful AI hardware sits underutilized behind a usability wall.
The Solution: AdaptJax Middleware¶
AdaptJax inserts a translation layer between your code and the TPU that:
- Normalizes shapes via adaptive bucketing
- Enforces XLA-friendly control flow patterns
- Optimizes memory for TPU HBM architecture

┌─────────────────────────────────────────────┐
│          Your Dynamic Workloads             │
│    (variable sequences, ragged batches)     │
└─────────────────┬───────────────────────────┘
                  │
                  ▼
┌─────────────────────────────────────────────┐
│            AdaptJax™ Middleware             │
│   ┌─────────────────────────────────────┐   │
│   │   Adaptive Bucketing & Padding      │   │
│   │   • Variable inputs → Fixed shapes  │   │
│   │   • {128, 256, 512} token buckets   │   │
│   │   • Automatic mask generation       │   │
│   └─────────────────────────────────────┘   │
│                    │                        │
│                    ▼                        │
│   ┌─────────────────────────────────────┐   │
│   │   JAX/Flax Transformer              │   │
│   │   • XLA-optimized forward pass      │   │
│   │   • Batched inference               │   │
│   │   • Compiled graph reuse            │   │
│   └─────────────────────────────────────┘   │
│                    │                        │
│                    ▼                        │
│   ┌─────────────────────────────────────┐   │
│   │   TPU v2-8 Execution                │   │
│   │   • 180+ TFLOPS compute             │   │
│   │   • 64GB HBM memory                 │   │
│   │   • 8 TPU cores parallel            │   │
│   └─────────────────────────────────────┘   │
└─────────────────────────────────────────────┘
                  │
                  ▼
┌─────────────────────────────────────────────┐
│        384-dim Semantic Embeddings          │
└─────────────────────────────────────────────┘
Three Pillars of AdaptJax¶
Pillar 1: Adaptive Bucketing & Padding¶
Converts dynamic inputs into a small set of fixed bucket sizes:
# Internal bucketing logic (simplified)
PAD_TOKEN = 0
BUCKETS = [128, 256, 512]

def pad_to_bucket(sequence):
    """Pad a token sequence up to the smallest bucket that fits it."""
    bucket_size = next(b for b in BUCKETS if b >= len(sequence))
    padding = bucket_size - len(sequence)
    mask = [1] * len(sequence) + [0] * padding   # 1 = real token, 0 = padding
    return sequence + [PAD_TOKEN] * padding, mask
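Because every input is snapped to one of three shapes, a jitted forward pass compiles at most three graph variants. A rough sketch building on pad_to_bucket above (embed_fn is a toy stand-in for the real model, not the shipped API):

import jax
import jax.numpy as jnp

@jax.jit
def embed_fn(token_ids):
    # Toy stand-in for the TPU forward pass
    return jnp.tanh(token_ids.astype(jnp.float32))

for raw in ([7] * 90, [7] * 300, [7] * 130):   # ragged production inputs
    padded, mask = pad_to_bucket(raw)
    embed_fn(jnp.array(padded))                # hits the cached 128-, 512-, or 256-token graph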
Benefits:
| Without AdaptJax | With AdaptJax |
|---|---|
| Every unique shape triggers compilation | Only 3 graph variants (one per bucket) |
| Unbounded compilation time | Compiled once, reused forever |
| 10-30% TPU utilization | Near-100% utilization |
Pillar 2: Tensor Adapters (PEFT/LoRA-style)¶
Hot-swap small adapter weights per customer while keeping one backbone model resident. The adapted weight takes the standard low-rank (LoRA-style) form:

W_adapted = W + scale * (A @ B)

Where:
- W = frozen backbone weight (billions of parameters)
- A, B = small adapter matrices (millions of parameters)
- scale = normalization factor
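A minimal JAX sketch of applying this update per tenant; the tenant_adapters dict, shapes, and scale value are illustrative placeholders, not the shipped API:

import jax
import jax.numpy as jnp

HIDDEN, RANK = 384, 8
scale = 0.5                                            # illustrative normalization factor

# One frozen backbone weight shared by every tenant
W = jax.random.normal(jax.random.PRNGKey(0), (HIDDEN, HIDDEN))

# Small per-tenant adapters (A: HIDDEN x RANK, B: RANK x HIDDEN)
tenant_adapters = {
    "acme":   (jnp.zeros((HIDDEN, RANK)), jnp.zeros((RANK, HIDDEN))),
    "globex": (jnp.zeros((HIDDEN, RANK)), jnp.zeros((RANK, HIDDEN))),
}

@jax.jit
def adapted_forward(x, A, B):
    """Apply W + scale * (A @ B) without materializing a per-tenant copy of W."""
    return x @ W + scale * ((x @ A) @ B)

def embed_for(tenant, x):
    A, B = tenant_adapters[tenant]   # context switch = picking two small matrices
    return adapted_forward(x, A, B)

Because only A and B change between requests, the backbone stays resident and the compiled graph is reused across tenants.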
Benefits:
- Fast context switches between users (~10ms)
- Single backbone serves all tenants
- 90% memory reduction vs per-user models
Pillar 3: Elastic Compute (Entropy-Based Early Exit)¶
Simple queries exit early, complex ones use full depth:
import jax
import jax.numpy as jnp

def entropy(logits):
    """Low entropy = high confidence = exit early."""
    probs = jax.nn.softmax(logits)
    return -jnp.sum(probs * jnp.log(probs + 1e-9))
- "What is 2+2?" → exits after 2 layers
- "Explain quantum entanglement" → uses all 12 layers
- XLA-compatible via fixed loop bounds + masking (sketched below)
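A minimal, hedged sketch of this fixed-bound pattern, using a toy matmul in place of a real transformer layer (layer_weights, head_weights, and THRESHOLD are illustrative, not the production model):

import jax
import jax.numpy as jnp

NUM_LAYERS = 12
THRESHOLD = 0.5   # illustrative confidence cut-off

def per_example_entropy(logits):
    probs = jax.nn.softmax(logits, axis=-1)
    return -jnp.sum(probs * jnp.log(probs + 1e-9), axis=-1)

def encode_with_early_exit(layer_weights, head_weights, x):
    """Always runs NUM_LAYERS iterations (fixed bound, so XLA compiles one graph),
    but freezes examples whose entropy has already dropped below THRESHOLD."""
    def body(i, carry):
        h, halted = carry
        new_h = jnp.tanh(h @ layer_weights[i])        # toy stand-in for one transformer layer
        h = jnp.where(halted[:, None], h, new_h)      # masking: halted rows keep frozen state
        ent = per_example_entropy(h @ head_weights)   # toy classifier head
        halted = halted | (ent < THRESHOLD)
        return h, halted

    halted = jnp.zeros(x.shape[0], dtype=bool)
    h, _ = jax.lax.fori_loop(0, NUM_LAYERS, body, (x, halted))
    return h

Because the loop bound is fixed and the exit is expressed as masking, the whole function compiles to a single TPU graph while still spending less useful compute on easy inputs.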
Performance Specifications¶
| Metric | Value |
|---|---|
| Embedding Throughput | 10,000+ chunks/second |
| Query Latency | <20ms average |
| Batch Size | Up to 512 sequences |
| Max Sequence Length | 512 tokens |
| Embedding Dimension | 384 |
| Model Architecture | all-MiniLM-L6-v2 (optimized) |
Comparison: AdaptJax vs Alternatives¶
AdaptJax vs OpenAI Embeddings¶
| Feature | AdaptJax™ | OpenAI API |
|---|---|---|
| Privacy | 100% on-premise | Data sent to OpenAI |
| Cost Model | Fixed (TPU time) | Per-token pricing |
| Latency | <20ms | 200-500ms |
| Rate Limits | None | Yes (TPM limits) |
| Custom Models | Yes | No |
| Offline Use | Yes | No |
AdaptJax vs Sentence Transformers (GPU)¶
| Feature | AdaptJax™ (TPU) | Sentence Transformers (GPU) |
|---|---|---|
| Throughput | 10,000+ chunks/s | 1,000-3,000 chunks/s |
| Hourly Cost | $4.50/hr (v2-8) | $8+/hr (A100) |
| Parallelism | 8 cores native | Single GPU |
| Memory | 64GB HBM | 40-80GB |
| Setup Complexity | Managed | Self-managed |
Code Example: AdaptJax Embedder¶
"""
AdaptJax Embedder - JAX-Native Document Embeddings for TPU
==========================================================
Pure JAX/Flax implementation that runs natively on TPU without
TensorFlow or PyTorch dependencies.
"""
import jax
import jax.numpy as jnp
from flax import linen as nn
class TransformerBlock(nn.Module):
    """Single transformer block with self-attention and FFN."""
    hidden_size: int = 384
    num_heads: int = 6
    mlp_dim: int = 1536

    @nn.compact
    def __call__(self, x, mask=None):
        # Self-attention with residual
        attn_out = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.hidden_size,
        )(x, x, mask=mask)
        x = nn.LayerNorm()(x + attn_out)
        # Feed-forward with residual
        ff_out = nn.Dense(self.mlp_dim)(x)
        ff_out = nn.gelu(ff_out)
        ff_out = nn.Dense(self.hidden_size)(ff_out)
        x = nn.LayerNorm()(x + ff_out)
        return x
class AdaptJaxEmbedder(nn.Module):
    """Lightweight transformer encoder for TPU inference."""
    vocab_size: int = 30522
    hidden_size: int = 384
    num_layers: int = 6

    @nn.compact
    def __call__(self, input_ids, attention_mask):
        x = nn.Embed(self.vocab_size, self.hidden_size)(input_ids)
        # Expand the [batch, seq] padding mask to an attention mask
        attn_mask = nn.make_attention_mask(attention_mask, attention_mask)
        for _ in range(self.num_layers):
            x = TransformerBlock()(x, mask=attn_mask)
        # Masked mean pooling: average only over real (unpadded) tokens
        mask = attention_mask[:, :, None]
        return jnp.sum(x * mask, axis=1) / jnp.maximum(jnp.sum(mask, axis=1), 1e-9)
# Instantiate the model and JIT-compile the forward pass for TPU
model = AdaptJaxEmbedder()

@jax.jit
def embed_batch(params, input_ids, attention_mask):
    return model.apply(params, input_ids, attention_mask)

# Parallel across all 8 TPU cores (params and inputs carry a leading device axis)
embeddings = jax.pmap(embed_batch)(params, batched_inputs, batched_masks)
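A hedged usage sketch of parameter initialization and sharding for pmap; input_ids and attention_mask stand for an already-tokenized batch whose size divides evenly across the local devices (shapes are illustrative):

# Initialize parameters once with dummy inputs
rng = jax.random.PRNGKey(0)
dummy_ids = jnp.zeros((1, 512), dtype=jnp.int32)
dummy_mask = jnp.ones((1, 512), dtype=jnp.int32)
params = model.init(rng, dummy_ids, dummy_mask)

# Shard the batch across the 8 local TPU cores and replicate the parameters
n_dev = jax.local_device_count()
batched_inputs = input_ids.reshape(n_dev, -1, 512)
batched_masks = attention_mask.reshape(n_dev, -1, 512)
replicated_params = jax.device_put_replicated(params, jax.local_devices())
embeddings = jax.pmap(embed_batch)(replicated_params, batched_inputs, batched_masks)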
Patent Protection¶
AdaptJax™ technology is protected under U.S. Patent Application 63/930,053.
The application covers:
"AdaptensorCore middleware - the translation layer that unlocks Google's TPU supercomputers for the 99% of AI teams who don't want to rewrite their code."
Next Steps¶
- Quick Start Guide - Get running in 5 minutes
- API Reference - Complete endpoint documentation
- Python SDK - Client library reference