Kernel Tuning

Per-tier Metal kernel tuning — matrix tiles, FlashAttention blocks, threadgroup sizes, chunk sizes, and batch multipliers.

PMetal automatically selects kernel parameters based on your device tier and GPU family. For several hot paths, Tuna now persists the resolved specialization per device/problem shape and compiles the Metal shader with matching function constants rather than relying only on host-side heuristics.

| Tier | Apple7–9 | Apple10 (M5+, NAX) |
| --- | --- | --- |
| Base | 32×32×32 | 64×32×32 |
| Pro | 64×32×32 | 64×64×32 |
| Max | 64×64×32 | 128×64×32 |
| Ultra | 64×64×32 | 128×64×32 |
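
As a rough illustration of how a host-side heuristic might consult this table, the sketch below maps a tier and GPU family number to a tile seed. The names `GEMM_TILE_SEEDS` and `gemm_tile_seed` are hypothetical, and reading each entry as an (M, N, K) triple is an assumption; only the values themselves come from the table.

```python
# Hypothetical lookup; values copied from the tier table, names are illustrative.
GEMM_TILE_SEEDS = {
    # tier: (Apple7-9 tile, Apple10 tile); each tile assumed to be (M, N, K)
    "base": ((32, 32, 32), (64, 32, 32)),
    "pro": ((64, 32, 32), (64, 64, 32)),
    "max": ((64, 64, 32), (128, 64, 32)),
    "ultra": ((64, 64, 32), (128, 64, 32)),
}


def gemm_tile_seed(tier, gpu_family):
    """Return the tile seed for a device tier and Apple GPU family number."""
    if gpu_family < 7:
        raise ValueError("table only covers Apple7 and newer")
    key = tier.lower()
    if key not in GEMM_TILE_SEEDS:
        raise ValueError(f"unknown tier: {tier!r}")
    apple7_9, apple10 = GEMM_TILE_SEEDS[key]
    return apple10 if gpu_family >= 10 else apple7_9


print(gemm_tile_seed("Pro", 9))   # (64, 32, 32)
print(gemm_tile_seed("Max", 10))  # (128, 64, 32)
```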

The tier table above applies to the standard Metal GEMM/LoRA kernels. On Apple10/M5 hardware, the Metal 4 / MPP dispatcher now auto-tunes among 32×32 / 1-simdgroup, 64×32 / 2-simdgroup, 32×64 / 2-simdgroup, and 64×64 / 4-simdgroup variants, plus Morton-vs-linear tile walk order, and persists the winning choice. Aligned M/N tiles use static full-tile extents. Apple7-9 continue to use the standard Metal kernels. For 4-bit affine quantized linear inference, Apple10/M5 also benchmarks MLX quantized_matmul against the MPP path and persists the result per device and problem shape.

Standard Metal FlashAttention now benchmarks the known-valid block pairs for each head dimension and persists the winner instead of relying only on the tier table below. For supported no-custom-mask inference attention shapes with head_dim = 64, 80, 96, or 128, including softcapped configs, Apple10/M5 also benchmarks MLX fast SDPA, Metal FlashAttention, and MPP FlashAttention against each other and persists the result, rejecting MPP candidates that diverge numerically from the MLX reference.
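
The benchmark-and-persist flow described above can be sketched roughly as follows. This is not PMetal's implementation: the `run` callback (assumed to return an output plus a measured time), the dictionary field names, the tolerance, and the cache key layout are all assumptions for illustration. The source only states numerical rejection for MPP FlashAttention candidates; the sketch applies the same check generically.

```python
import itertools

# The four tile/simdgroup variants and two walk orders named in the text.
TILE_VARIANTS = [(32, 32, 1), (64, 32, 2), (32, 64, 2), (64, 64, 4)]
WALK_ORDERS = ["morton", "linear"]


def mpp_candidates():
    """Enumerate every tile variant combined with every walk order."""
    return [
        {"tile_m": m, "tile_n": n, "simdgroups": s, "walk": w}
        for (m, n, s), w in itertools.product(TILE_VARIANTS, WALK_ORDERS)
    ]


def autotune(candidates, run, reference, cache, key, tol=1e-3):
    """Return the fastest candidate whose output matches the reference.

    `run(candidate)` is a hypothetical callback returning (output, seconds).
    Results are stored in `cache` under `key` so later calls skip benchmarking.
    """
    if key in cache:
        return cache[key]
    best, best_time = None, float("inf")
    for cand in candidates:
        output, seconds = run(cand)
        max_err = max(abs(a - b) for a, b in zip(output, reference))
        if max_err > tol:
            continue  # reject numerically divergent candidates
        if seconds < best_time:
            best, best_time = cand, seconds
    if best is None:
        raise RuntimeError("no candidate matched the reference output")
    cache[key] = best
    return best
```

Keying the cache on a tuple such as (device identity, M, N, K) would mirror the per-device, per-shape persistence the text describes.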

Baseline block-size seed per head dimension before persisted tuning:

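| Head Dim | Base | Pro | Max | Ultra |
| --- | --- | --- | --- | --- |
| 64 | 64×32 | 64×64 | 64×64 | 64×64 |
| 80 | 32×32 | 64×32 | 64×32 | 64×32 |
| 96 | 32×32 | 64×32 | 64×32 | 64×32 |
| 128 | 32×32 | 32×32 | 64×32 | 64×32 |
| 256 | 16×16 | 16×16 | 32×16 | 32×16 |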
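
A seed lookup for this table might look like the sketch below. The names are hypothetical, and the table does not say which number in each pair is the query block and which is the key block, so the pairs are kept as opaque tuples.

```python
# Hypothetical seed table; values copied from the head-dimension table above.
FA_BLOCK_SEEDS = {
    64: {"base": (64, 32), "pro": (64, 64), "max": (64, 64), "ultra": (64, 64)},
    80: {"base": (32, 32), "pro": (64, 32), "max": (64, 32), "ultra": (64, 32)},
    96: {"base": (32, 32), "pro": (64, 32), "max": (64, 32), "ultra": (64, 32)},
    128: {"base": (32, 32), "pro": (32, 32), "max": (64, 32), "ultra": (64, 32)},
    256: {"base": (16, 16), "pro": (16, 16), "max": (32, 16), "ultra": (32, 16)},
}


def flash_attention_seed(head_dim, tier):
    """Return the baseline block pair used before any persisted tuning."""
    try:
        return FA_BLOCK_SEEDS[head_dim][tier.lower()]
    except KeyError:
        raise ValueError(f"no seed for head_dim={head_dim}, tier={tier!r}") from None


print(flash_attention_seed(128, "Pro"))  # (32, 32)
```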

Tuna now benchmarks and persists THREADS_PER_TOKEN plus the tiled/non-tiled path choice for the fused RMSNorm + LoRA kernel. The table below is the heuristic seed used to order the benchmark candidates.

| Tier | Threads / Token | Tiled Path |
| --- | --- | --- |
| Base | 128 | out_features > 256 |
| Pro | 256 | out_features > 256 |
| Max | 512 | out_features > 256 |
| Ultra | 512 | out_features > 256 |
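
Read as code, the seed in this table reduces to a per-tier thread count and a single threshold on out_features. The function and dictionary names below are hypothetical; the benchmark may still override either value.

```python
# Hypothetical helper; values copied from the table above.
THREADS_PER_TOKEN_SEED = {"base": 128, "pro": 256, "max": 512, "ultra": 512}


def threads_tiled_seed(tier, out_features):
    """Return the heuristic starting point before benchmarking."""
    return {
        "threads_per_token": THREADS_PER_TOKEN_SEED[tier.lower()],
        # The tiled path is seeded only when out_features exceeds 256.
        "tiled": out_features > 256,
    }


print(threads_tiled_seed("Max", 4096))
# {'threads_per_token': 512, 'tiled': True}
```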

Tuna now benchmarks and persists THREADS_PER_TOKEN and SWIGLU_CHUNK_SIZE for the standard-Metal fused SwiGLU kernels. The table below is the heuristic seed used to order the benchmark candidates.

| Tier | Threads / Token | Chunk Size |
| --- | --- | --- |
| Base | 128 or 256 | 1024 or 2048 |
| Pro | 256 | 2048 or 4096 |
| Max | 512 | 2048 or 4096 |
| Ultra | 512 | 2048 or 4096 |
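
One way to picture "a seed used to order the benchmark candidates" is to try the seeded combinations first and the rest of the search space afterwards. The sketch below assumes a full search space of three thread counts and three chunk sizes, which is an illustrative guess; only the seed values come from the table.

```python
import itertools

# Seed values copied from the table above.
SWIGLU_SEEDS = {
    "base": ([128, 256], [1024, 2048]),
    "pro": ([256], [2048, 4096]),
    "max": ([512], [2048, 4096]),
    "ultra": ([512], [2048, 4096]),
}

# Assumed full search space (not stated in the source).
ALL_THREADS = [128, 256, 512]
ALL_CHUNKS = [1024, 2048, 4096]


def ordered_swiglu_candidates(tier):
    """Return (threads_per_token, chunk_size) pairs, seeded pairs first."""
    threads, chunks = SWIGLU_SEEDS[tier.lower()]
    seeded = list(itertools.product(threads, chunks))
    rest = [
        pair
        for pair in itertools.product(ALL_THREADS, ALL_CHUNKS)
        if pair not in seeded
    ]
    return seeded + rest


print(ordered_swiglu_candidates("pro")[:2])  # [(256, 2048), (256, 4096)]
```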

Tuna now benchmarks and persists fused linear cross-entropy threadgroup size and default chunk size per device, dtype, and problem shape. The table below is the heuristic seed used to order the benchmark candidates.

| Tier | Threads / Token | Chunk Size |
| --- | --- | --- |
| Base | 128 / 256 / 512 | 1024 or 2048 |
| Pro | 256 / 512 / 1024 | 2048 or 4096 |
| Max | 256 / 512 / 1024 | 4096 or 8192 |
| Ultra | 256 / 512 / 1024 | 4096 or 8192 |

CE_THREADS_PER_TOKEN still scales with vocabulary size and clamps to the hardware maximum. Base-tier devices seed at 128 / 256 / 512 across < 32k, 32k..128k, and > 128k vocabularies, while Pro / Max / Ultra seed at 256 / 512 / 1024. Wider hidden states start from a smaller chunk-size seed before benchmarking.
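
The vocabulary-band rule above can be written as a small function. This is a sketch under stated assumptions: the function name is hypothetical, "k" is taken to mean 1024, and the 32k..128k band is treated as inclusive at both ends.

```python
K = 1024  # assumption: "32k" and "128k" are multiples of 1024


def ce_threads_per_token_seed(tier, vocab_size, hw_max_threads):
    """Seed CE_THREADS_PER_TOKEN from tier and vocabulary size, then clamp."""
    tier = tier.lower()
    if tier not in ("base", "pro", "max", "ultra"):
        raise ValueError(f"unknown tier: {tier!r}")
    small, medium, large = (128, 256, 512) if tier == "base" else (256, 512, 1024)
    if vocab_size < 32 * K:
        seed = small
    elif vocab_size <= 128 * K:
        seed = medium
    else:
        seed = large
    # Never exceed what the hardware allows per threadgroup.
    return min(seed, hw_max_threads)


print(ce_threads_per_token_seed("base", 30_000, 1024))  # 128
print(ce_threads_per_token_seed("max", 200_000, 512))   # 512 (clamped from 1024)
```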

Tuna now benchmarks and persists threads_per_group and elements_per_thread for the standard-Metal fused-merge elementwise kernels. This path stays first-class on Apple7-Apple9 and is also reused on Apple10 for non-MPP merge workloads.

| Tier | Threads / Group | Elements / Thread |
| --- | --- | --- |
| Base | 128 or 256 | 2, 4, or 8 |
| Pro | 128, 256, or 512 | 4 or 8 |
| Max | 256 or 512 | 4 or 8 |
| Ultra | 256 or 512 | 4 or 8 |

The candidate ordering is tier-aware, but the final result is benchmarked per device and problem shape and stored in merge.json. Persistent merge and LoRA-forward cache keys now include device identity, so a cache shared across different Apple Silicon tiers and bins stays safe: results tuned on one device are not applied to another.
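
To illustrate why device identity in the key matters, here is a minimal sketch of a JSON-backed cache. The key layout, the JSON schema, and the device identity strings are all hypothetical; the source only states that results live in merge.json and that keys include device identity.

```python
import json
from pathlib import Path


def merge_cache_key(device_id, dtype, num_elements):
    """Build a cache key that is unique per device, dtype, and problem size."""
    return f"{device_id}|{dtype}|{num_elements}"


def load_cache(path):
    """Read the cache file, returning an empty dict if it does not exist."""
    p = Path(path)
    return json.loads(p.read_text()) if p.exists() else {}


def store_result(path, key, threads_per_group, elements_per_thread):
    """Record a benchmarked configuration under its device-aware key."""
    cache = load_cache(path)
    cache[key] = {
        "threads_per_group": threads_per_group,
        "elements_per_thread": elements_per_thread,
    }
    Path(path).write_text(json.dumps(cache, indent=2, sort_keys=True))
    return cache
```

Because two devices produce different keys for the same problem shape, entries from both can coexist in one file without either device picking up the other's tuning.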

| Tier | Multiplier |
| --- | --- |
| Base | |
| Pro | |
| Max | |
| Ultra | |

| Kernel | Description |
| --- | --- |
| FlashAttention | O(n) memory attention with fused softmax, tier-aware block sizes |
| Fused GDN | Gated Delta Network recurrence kernel with a single-pass state update |
| Fused LoRA | Combined forward pass for adapter layers (~2× speedup) |
| Fused Cross-Entropy | Chunked vocabulary loss computation |
| Fused Linear Cross-Entropy | Skips logits materialization entirely, with tuned chunk/thread specialization |
| Fused RoPE | Rotary position embeddings in-kernel |
| Fused SwiGLU | Fused gate + activation with benchmarked-and-persisted Tuna thread/chunk specialization |
| Fused RMSNorm + LoRA | Combined normalization and adapter projection with benchmarked-and-persisted Tuna thread/tiled specialization |
| Fused Sampler | JIT-compiled token sampling |
| Fused MLP | Combined gate/up/down projections |
| Async Scheduler | Double/triple-buffered GPU command scheduling |
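
The idea behind chunked cross-entropy without a full logits buffer can be shown on the CPU with a running log-sum-exp. The sketch below is a plain-Python illustration of that general technique for a single token, not the Metal kernel itself; the function name and argument layout are assumptions.

```python
import math


def chunked_cross_entropy(hidden, weight_rows, target, chunk_size):
    """Cross-entropy for one token, visiting the vocabulary chunk by chunk.

    Only `chunk_size` logits exist at any time; a running max and a rescaled
    running sum stand in for the full log-sum-exp over the vocabulary.
    """
    if not 0 <= target < len(weight_rows):
        raise ValueError("target index outside the vocabulary")
    running_max = -math.inf
    running_sum = 0.0
    target_logit = 0.0
    for start in range(0, len(weight_rows), chunk_size):
        chunk = weight_rows[start:start + chunk_size]
        logits = [sum(h * w for h, w in zip(hidden, row)) for row in chunk]
        if start <= target < start + len(chunk):
            target_logit = logits[target - start]
        new_max = max(running_max, max(logits))
        # Rescale the old sum to the new max before adding this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in logits
        )
        running_max = new_max
    return running_max + math.log(running_sum) - target_logit
```

The result is independent of `chunk_size` up to floating-point rounding, which is what lets a tuner pick the chunk size purely on speed.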