# Source: MolmoAct2 / modeling_molmoact2.py
# Uploaded by hqfang ("hqfang's picture")
# Commit message: Add files using upload-large-folder tool
# Revision: 75c5cf5 (verified)
"""Modeling code for MolmoAct2"""
import json
import math
import os
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_masks_for_generate
from transformers.modeling_flash_attention_utils import (
FlashAttentionKwargs,
_flash_attention_forward,
flash_attn_supports_top_left_mask,
)
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.auto import AutoModelForImageTextToText
from transformers.processing_utils import Unpack
from transformers.utils import (
ModelOutput,
TransformersKwargs,
can_return_tuple,
logging,
)
from .configuration_molmoact2 import (
MolmoAct2ActionExpertConfig,
MolmoAct2AdapterConfig,
MolmoAct2Config,
MolmoAct2TextConfig,
MolmoAct2VitConfig,
)
from .inference import (
ActionCudaGraphManager,
DepthDecodeCudaGraphManager,
_ActionFlowInputs,
_cache_max_len_int,
_cache_seq_len_int,
_iter_cache_key_values,
)
# Module-level logger obtained from the transformers logging helper.
logger = logging.get_logger(__name__)

# Literal marker strings for the action / state / depth / setup / control
# segments. NOTE(review): presumably these must match special tokens in the
# tokenizer vocabulary — confirm against the processor/tokenizer config.
ACTION_START_TOKEN = "<action_start>"
ACTION_END_TOKEN = "<action_end>"
ACTION_OUTPUT_TOKEN = "<action_output>"
STATE_START_TOKEN = "<state_start>"
STATE_END_TOKEN = "<state_end>"
STATE_TOKEN_PREFIX = "<state_"
DEPTH_START_TOKEN = "<depth_start>"
DEPTH_END_TOKEN = "<depth_end>"
DEPTH_OUTPUT_TOKEN = "<depth_output>"
DEPTH_TOKEN_PREFIX = "<depth_"
SETUP_START_TOKEN = "<setup_start>"
SETUP_END_TOKEN = "<setup_end>"
CONTROL_START_TOKEN = "<control_start>"
CONTROL_END_TOKEN = "<control_end>"

# Character sets for cleaning up a question/instruction string (their use is
# not visible in this part of the file).
# NOTE(review): "," appears twice in the punctuation set below. That is
# harmless if the string is only used as a set of characters, but one of the
# two may originally have been a different (e.g. full-width) character — confirm.
_QUESTION_TRAILING_SENTENCE_PUNCTUATION = ".,!?;:,…"
_QUESTION_TRAILING_CLOSERS = "\"'”’)]}"
_QUESTION_SURROUNDING_DELIMITERS = "\"'`“”‘’[](){}"

# Case-insensitive, pre-compiled patterns that match a leading
# "task:" / "instruction:" / "language instruction:" / "goal:" label, or a
# leading "the task is to" / "your task is to" phrase.
_QUESTION_PREFIX_PATTERNS = tuple(
    re.compile(pattern, flags=re.IGNORECASE)
    for pattern in (
        r"^(?:task|instruction|language[_ ]instruction|goal)\s*[:\-]\s*",
        r"^(?:the\s+task\s+is\s+to|your\s+task\s+is\s+to)\s+",
    )
)

# Depth-reasoning constants; their consumers are outside the visible part of
# this file. NOTE(review): units/meaning of the threshold are not shown here.
_DEPTH_REASONING_PATCH_SIZE = 32
_DEPTH_REASONING_THRESHOLD = 0.996
def _modulate(
    x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Return ``x * (1 + scale) + shift`` with shift/scale broadcast over dim 1.

    ``shift`` and ``scale`` each receive a new axis at position 1 before the
    arithmetic, so one shift/scale row is applied to every position along
    dim 1 of ``x``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
def _round_up_multiple(value: int, multiple_of: int) -> int:
    """Round ``value`` up to the next multiple of ``multiple_of``.

    Args:
        value: The number to round up.
        multiple_of: The granularity. When it is zero or negative, ``value``
            is returned unchanged.

    Returns:
        The smallest multiple of ``multiple_of`` that is ``>= value``.
    """
    if multiple_of <= 0:
        return value
    # Ceil-division via floor-division of the negated value. Unlike
    # ``math.ceil(value / multiple_of)`` this never goes through a float, so
    # it stays exact for integers beyond 2**53.
    return int(-(-value // multiple_of) * multiple_of)
def _init_linear(linear: nn.Linear, *, zero: bool = False, scale: float = 1.0) -> None:
    """Re-initialise an ``nn.Linear`` in place.

    Args:
        linear: Layer whose weight (and bias, if present) is overwritten.
        zero: Fill the weight with zeros instead of Xavier-uniform values.
        scale: Factor multiplied into the freshly initialised weight; skipped
            when it equals ``1.0``.
    """
    weight = linear.weight
    initialise = nn.init.zeros_ if zero else nn.init.xavier_uniform_
    initialise(weight)
    if scale != 1.0:
        # In-place scaling must not be recorded by autograd.
        with torch.no_grad():
            weight.mul_(scale)
    bias = linear.bias
    if bias is not None:
        nn.init.zeros_(bias)
@dataclass
class ActionExpertContext:
    """Tensors prepared once by ``ActionExpert.prepare_context`` and reused by
    ``ActionExpert.forward_with_context``."""

    # One (key, value) pair per action block, as built by
    # ``ActionExpert._prepare_kv_context``.
    kv_contexts: Sequence[Tuple[torch.Tensor, torch.Tensor]]
    # Additive mask handed to every block's cross-attention, or None.
    cross_mask: Optional[torch.Tensor]
    # Additive mask handed to every block's self-attention, or None.
    self_mask: Optional[torch.Tensor]
    # Action validity mask with a trailing singleton axis; multiplied into the
    # activations after the embedding, after each block and after the final
    # layer. None when no action mask was given.
    valid_action: Optional[torch.Tensor]
    # (cos, sin) tables from ``ActionExpertRotaryEmbedding.build_cache``, or
    # None when the first block has no rotary embedding.
    rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
@dataclass
class ActionExpertStepModulation:
    """Precomputed modulation tensors for one timestep, as produced by
    ``ActionExpert.prepare_modulation_cache``."""

    # Output of ``ActionExpert.time_embed`` for this timestep.
    conditioning: torch.Tensor
    # One 9-tuple of modulation chunks per action block.
    block_modulations: Sequence[Tuple[torch.Tensor, ...]]
    # The two modulation chunks consumed by the final layer as (shift, scale).
    final_modulation: Tuple[torch.Tensor, torch.Tensor]
class ActionExpertRMSNorm(nn.Module):
    """RMS normalisation over the last dimension, computed in float32.

    The statistics are evaluated with autocast disabled and the result is
    cast back to the input dtype. A learnable per-feature ``weight`` exists
    only when ``elementwise_affine`` is true; otherwise ``weight`` is ``None``.
    """

    def __init__(
        self,
        size: int,
        *,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.size = size
        self.eps = eps
        if not elementwise_affine:
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.ones(size, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` by the root-mean-square of its last dimension."""
        with torch.autocast(enabled=False, device_type=x.device.type):
            input_dtype = x.dtype
            upcast = x.to(torch.float32)
            mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
            normed = upcast * torch.rsqrt(mean_square + self.eps)
            normed = normed.to(input_dtype)
        if self.weight is None:
            return normed
        return normed * self.weight

    def reset_parameters(self) -> None:
        """Reset the optional scale parameter to ones."""
        if self.weight is None:
            return
        nn.init.ones_(self.weight)
class ActionExpertRotaryEmbedding(nn.Module):
    """Rotary position embedding that rotates the two halves of the head dim.

    The cos/sin tables are indexed by position along dim ``-2`` of the inputs
    and broadcast over the two leading dimensions.
    """

    def __init__(self, head_dim: int, base: float = 10000.0) -> None:
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE requires an even head_dim.")
        self.head_dim = head_dim
        self.base = base

    def build_cache(
        self,
        *,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, head_dim // 2)``.

        Angles are computed in float32 and only the final tables are cast to
        ``dtype``.
        """
        half_dim = self.head_dim // 2
        exponents = torch.arange(
            0, half_dim, device=device, dtype=torch.float32
        ) / max(half_dim, 1)
        inv_freq = 1.0 / (self.base**exponents)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        table_shape = (1, 1, seq_len, half_dim)
        cos = angles.cos().to(dtype=dtype).view(*table_shape)
        sin = angles.sin().to(dtype=dtype).view(*table_shape)
        return cos, sin

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        *,
        rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k``; tables are built from ``q`` when none are given."""
        if rope_cache is None:
            rope_cache = self.build_cache(
                seq_len=q.shape[-2], device=q.device, dtype=q.dtype
            )
        cos, sin = rope_cache
        split = self.head_dim // 2

        rotated = []
        for tensor in (q, k):
            first, second = tensor[..., :split], tensor[..., split:]
            rotated.append(
                torch.cat(
                    [first * cos - second * sin, first * sin + second * cos],
                    dim=-1,
                )
            )
        return rotated[0], rotated[1]
class ActionExpertSelfAttention(nn.Module):
    """Multi-head self-attention with optional per-head QK RMS-norm and RoPE.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_dropout: Dropout probability inside attention (training only).
        proj_dropout: Dropout probability applied after the output projection.
        qk_norm: Apply ``ActionExpertRMSNorm`` over the head dim of q and k.
        qk_norm_eps: Epsilon for the q/k norms.
        use_rope: Apply ``ActionExpertRotaryEmbedding`` to q and k.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        *,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        use_rope: bool = True,
    ) -> None:
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attn_dropout = attn_dropout
        self.q_norm = (
            ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None
        )
        self.k_norm = (
            ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None
        )
        self.rope = ActionExpertRotaryEmbedding(self.head_dim) if use_rope else None
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.out_drop = nn.Dropout(proj_dropout)

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalise q and k when both norms exist; otherwise return them as-is."""
        if self.q_norm is None or self.k_norm is None:
            return q, k
        return self.q_norm(q), self.k_norm(k)

    @staticmethod
    def _merge_causal_mask(
        attn_mask: torch.Tensor, tgt_len: int, src_len: int
    ) -> torch.Tensor:
        """Fold a top-left-aligned causal constraint into an explicit mask.

        Boolean masks (True = may attend) are AND-ed with the lower triangle;
        floating masks get the most negative finite value of their dtype at
        every future position. The result broadcasts to ``(..., tgt_len, src_len)``.
        """
        future = torch.ones(
            tgt_len, src_len, dtype=torch.bool, device=attn_mask.device
        ).triu(diagonal=1)
        if attn_mask.dtype == torch.bool:
            return attn_mask & ~future
        blocked = attn_mask.new_full((), torch.finfo(attn_mask.dtype).min)
        return torch.where(future, blocked, attn_mask)

    def _attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Run scaled-dot-product attention on ``(batch, seq, heads, head_dim)`` inputs.

        ``F.scaled_dot_product_attention`` raises when an explicit ``attn_mask``
        is combined with ``is_causal=True``. When both are requested the causal
        constraint is merged into the mask and the flag is cleared, so callers
        may pass both.
        """
        if is_causal and attn_mask is not None:
            attn_mask = self._merge_causal_mask(attn_mask, q.shape[1], k.shape[1])
            is_causal = False
        dropout_p = self.attn_dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        return out.transpose(1, 2).contiguous()

    def forward(
        self,
        x: torch.Tensor,
        *,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq, hidden_size)``.

        Args:
            x: Input activations.
            attn_mask: Optional mask forwarded to scaled-dot-product attention.
            is_causal: Request causal attention; may be combined with
                ``attn_mask``.
            rope_cache: Optional precomputed ``(cos, sin)`` tables for RoPE.

        Returns:
            Tensor of the same shape as ``x``.
        """
        bsz, seq_len, _ = x.shape
        qkv = self.qkv(x).view(bsz, seq_len, 3, self.num_heads, self.head_dim)
        # q/k are moved to (batch, heads, seq, head_dim) for norm + RoPE, whose
        # tables are indexed along dim -2.
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].contiguous()
        q, k = self._apply_qk_norm(q, k)
        if self.rope is not None:
            q, k = self.rope(q, k, rope_cache=rope_cache)
        # Back to (batch, seq, heads, head_dim), the layout _attention expects.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        out = self._attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)
        out = out.reshape(bsz, seq_len, self.hidden_size)
        return self.out_drop(self.out_proj(out))
class ActionExpertCrossAttention(nn.Module):
    """Multi-head cross-attention from action tokens to a context sequence.

    The context may be supplied either as raw features ``kv`` (projected here
    with ``kv_proj`` and, if enabled, key-normalised) or as ready-made
    ``kv_k`` / ``kv_v`` tensors, in which case only the query is normalised.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        *,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attn_dropout = attn_dropout
        if qk_norm:
            self.q_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps)
            self.k_norm = ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps)
        else:
            self.q_norm = None
            self.k_norm = None
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.kv_proj = nn.Linear(hidden_size, hidden_size * 2)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.out_drop = nn.Dropout(proj_dropout)

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalise q and k, but only when both norm modules are present."""
        norms_available = self.q_norm is not None and self.k_norm is not None
        if not norms_available:
            return q, k
        return self.q_norm(q), self.k_norm(k)

    def _as_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Bring a 3D or 4D key/value tensor into ``(batch, seq, heads, head_dim)``.

        A 4D input is returned untouched when dim 2 already equals
        ``num_heads``; when dim 1 does, dims 1 and 2 are swapped.
        """
        rank = x.dim()
        if rank == 3:
            bsz, seq_len, _ = x.shape
            return x.view(bsz, seq_len, self.num_heads, self.head_dim)
        if rank != 4:
            raise ValueError(f"Expected 3D/4D cross-attention KV, got {tuple(x.shape)}")
        if x.shape[2] == self.num_heads:
            return x
        if x.shape[1] == self.num_heads:
            return x.transpose(1, 2).contiguous()
        raise ValueError(f"Unexpected cross-attention KV shape {tuple(x.shape)}")

    def _attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Non-causal scaled-dot-product attention on ``(batch, seq, heads, dim)`` inputs."""
        heads_first = [tensor.transpose(1, 2) for tensor in (q, k, v)]
        attended = F.scaled_dot_product_attention(
            heads_first[0],
            heads_first[1],
            heads_first[2],
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=False,
        )
        return attended.transpose(1, 2).contiguous()

    def forward(
        self,
        x: torch.Tensor,
        *,
        kv: Optional[torch.Tensor] = None,
        kv_k: Optional[torch.Tensor] = None,
        kv_v: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``x`` to the given context.

        Raises:
            ValueError: If only one of ``kv_k``/``kv_v`` is given, if ``kv``
                is combined with them, or if no context is given at all.
        """
        if (kv_k is None) != (kv_v is None):
            raise ValueError("kv_k and kv_v must both be provided or both be None.")
        # After the check above kv_k and kv_v are either both set or both None.
        has_projected = kv_k is not None
        if kv is not None and has_projected:
            raise ValueError("Provide either kv or kv_k/kv_v, not both.")

        bsz, tgt_len, _ = x.shape
        q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim)

        if has_projected:
            k = self._as_heads(kv_k)
            v = self._as_heads(kv_v)
        elif kv is not None:
            src_len = kv.shape[1]
            projected = self.kv_proj(kv).view(
                bsz, src_len, 2, self.num_heads, self.head_dim
            )
            k, v = projected[:, :, 0], projected[:, :, 1]
        else:
            raise ValueError("cross-attention requires kv or kv_k/kv_v.")

        # The norms run with heads in dim 1; the layout is restored right after.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        if not has_projected:
            q, k = self._apply_qk_norm(q, k)
        elif self.q_norm is not None:
            # Externally supplied keys are used as given; only q is normalised.
            q = self.q_norm(q)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)

        attended = self._attention(q, k, v, attn_mask=attn_mask)
        attended = attended.reshape(bsz, tgt_len, self.hidden_size)
        return self.out_drop(self.out_proj(attended))
class ActionExpertMLP(nn.Module):
    """SiLU-gated feed-forward network.

    The inner width is ``int(hidden_size * mlp_ratio)`` rounded up with
    ``_round_up_multiple``. The same dropout module is applied to the gated
    activations and to the output.
    """

    def __init__(
        self,
        hidden_size: int,
        *,
        mlp_ratio: float,
        multiple_of: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        requested_width = int(hidden_size * mlp_ratio)
        inner_dim = _round_up_multiple(requested_width, multiple_of)
        self.up_proj = nn.Linear(hidden_size, inner_dim)
        self.gate_proj = nn.Linear(hidden_size, inner_dim)
        self.down_proj = nn.Linear(inner_dim, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``down_proj(dropout(silu(gate_proj(x)) * up_proj(x)))`` plus dropout."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        projected = self.down_proj(self.dropout(gated))
        return self.dropout(projected)
class ActionExpertModulation(nn.Module):
    """SiLU followed by a linear map from ``hidden_size`` to ``num_chunks * hidden_size``.

    Callers split the output into ``num_chunks`` equal pieces along dim 1.
    """

    def __init__(self, hidden_size: int, num_chunks: int) -> None:
        super().__init__()
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, num_chunks * hidden_size)

    def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
        """Return the concatenated modulation vector for ``conditioning``."""
        activated = self.act(conditioning)
        return self.linear(activated)
class ActionExpertBlock(nn.Module):
    """One action-expert layer: self-attention, cross-attention and MLP.

    Each of the three sub-layers is wrapped as
    ``x + gate * sublayer(modulate(norm(x), shift, scale))``, where the nine
    shift/scale/gate tensors come from the conditioning vector (or from a
    precomputed ``modulation`` tuple).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        *,
        mlp_ratio: float,
        ffn_multiple_of: int,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        qk_norm: bool = True,
        qk_norm_eps: float = 1e-6,
        rope: bool = True,
    ) -> None:
        super().__init__()
        self.self_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6)
        self.cross_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6)
        self.ff_norm = ActionExpertRMSNorm(hidden_size, eps=1e-6)
        # Options shared by both attention modules.
        shared_attn_options = dict(
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
            qk_norm=qk_norm,
            qk_norm_eps=qk_norm_eps,
        )
        self.self_attn = ActionExpertSelfAttention(
            hidden_size, num_heads, use_rope=rope, **shared_attn_options
        )
        self.cross_attn = ActionExpertCrossAttention(
            hidden_size, num_heads, **shared_attn_options
        )
        self.mlp = ActionExpertMLP(
            hidden_size,
            mlp_ratio=mlp_ratio,
            multiple_of=ffn_multiple_of,
            dropout=dropout,
        )
        self.modulation = ActionExpertModulation(hidden_size, 9)

    def forward(
        self,
        x: torch.Tensor,
        conditioning: torch.Tensor,
        *,
        cross_kv: Tuple[torch.Tensor, torch.Tensor],
        self_attn_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        modulation: Optional[Tuple[torch.Tensor, ...]] = None,
        rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Apply the block to ``x``.

        Args:
            x: Action-token activations.
            conditioning: Conditioning vector; only used when ``modulation``
                is ``None``.
            cross_kv: ``(key, value)`` tensors handed to cross-attention as
                ``kv_k`` / ``kv_v``.
            self_attn_mask: Mask for self-attention.
            attn_mask: Mask for cross-attention.
            is_causal: Forwarded to self-attention.
            modulation: Optional precomputed 9-tuple of modulation chunks.
            rope_cache: Optional precomputed RoPE tables for self-attention.
        """
        if modulation is None:
            modulation = self.modulation(conditioning).chunk(9, dim=1)
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mca,
            scale_mca,
            gate_mca,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = modulation

        # Self-attention branch.
        self_in = _modulate(self.self_norm(x), shift_msa, scale_msa)
        self_out = self.self_attn(
            self_in,
            attn_mask=self_attn_mask,
            is_causal=is_causal,
            rope_cache=rope_cache,
        )
        x = x + gate_msa.unsqueeze(1) * self_out

        # Cross-attention branch over the supplied key/value context.
        key_ctx, value_ctx = cross_kv[0], cross_kv[1]
        cross_in = _modulate(self.cross_norm(x), shift_mca, scale_mca)
        cross_out = self.cross_attn(
            cross_in,
            kv_k=key_ctx,
            kv_v=value_ctx,
            attn_mask=attn_mask,
        )
        x = x + gate_mca.unsqueeze(1) * cross_out

        # Feed-forward branch.
        mlp_in = _modulate(self.ff_norm(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x
class ActionExpertFinalLayer(nn.Module):
    """Output head: RMS-norm, shift/scale modulation, then a linear projection."""

    def __init__(self, hidden_size: int, output_dim: int) -> None:
        super().__init__()
        self.norm = ActionExpertRMSNorm(hidden_size, eps=1e-6)
        self.modulation = ActionExpertModulation(hidden_size, 2)
        self.linear = nn.Linear(hidden_size, output_dim)

    def forward(
        self,
        x: torch.Tensor,
        conditioning: torch.Tensor,
        *,
        modulation: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Project ``x`` to ``output_dim`` features.

        ``modulation`` is an optional precomputed ``(shift, scale)`` pair; when
        omitted it is derived from ``conditioning``.
        """
        if modulation is None:
            modulation = self.modulation(conditioning).chunk(2, dim=1)
        shift, scale = modulation
        normalised = self.norm(x)
        return self.linear(_modulate(normalised, shift, scale))
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of a per-sample timestep.

    The output is ``[sin(t * f), cos(t * f)]`` over ``dim // 2`` frequencies
    that decay geometrically from 1 to 1/10000, zero-padded by one column
    when ``dim`` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed ``timesteps``; inputs with more than one dim use their first entry per row."""
        if timesteps.dim() > 1:
            # Flatten everything after the batch axis and keep the first value.
            timesteps = timesteps.view(timesteps.shape[0], -1)[:, 0]
        half_dim = self.dim // 2
        decay = -math.log(10000.0) / max(half_dim - 1, 1)
        indices = torch.arange(
            half_dim, device=timesteps.device, dtype=timesteps.dtype
        )
        freq = torch.exp(indices * decay)
        args = timesteps[:, None] * freq[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
class ActionExpert(nn.Module):
    """Modern MolmoAct2 action expert, embedded for HF remote-code inference.

    A stack of ``ActionExpertBlock`` layers over noisy action tokens. Block
    ``i`` cross-attends to the ``i``-th ``(key, value)`` pair taken from the
    LLM, projected into the expert's width; an optional encoded robot state
    is appended to that context. A timestep embedding drives the per-block
    modulation.

    Args:
        config: Action-expert configuration.
        llm_dim: LLM hidden width (stored only).
        llm_kv_dim: Feature width of the incoming LLM key/value tensors.
        llm_num_layers: Number of LLM layers; must equal ``config.num_layers``.
        device: Optional device for the directly created linear/norm layers.

    Raises:
        ValueError: If ``config.implementation`` is not ``"new"`` or the layer
            counts differ.
    """

    def __init__(
        self,
        config: MolmoAct2ActionExpertConfig,
        *,
        llm_dim: int,
        llm_kv_dim: int,
        llm_num_layers: int,
        device=None,
    ):
        super().__init__()
        if config.implementation != "new":
            raise ValueError("Only action_expert.implementation='new' is supported.")
        if config.num_layers != llm_num_layers:
            raise ValueError(
                "MolmoAct2 HF action expert supports only per-layer conditioning with one "
                f"action block per LLM layer (action={config.num_layers}, llm={llm_num_layers})."
            )
        self.config = config
        self.hidden_size = config.hidden_size
        self.llm_dim = llm_dim
        self.llm_kv_dim = llm_kv_dim
        self.action_head_dim = config.hidden_size // config.num_heads
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(config.timestep_embed_dim),
            nn.Linear(config.timestep_embed_dim, config.hidden_size, device=device),
            nn.SiLU(),
            nn.Linear(config.hidden_size, config.hidden_size, device=device),
        )
        self.action_embed = nn.Linear(
            config.max_action_dim, config.hidden_size, device=device
        )
        self.state_encoder = nn.Linear(
            config.hidden_size, config.hidden_size, device=device
        )
        self.state_norm = ActionExpertRMSNorm(
            config.hidden_size, eps=1e-6, device=device
        )
        self.context_k_proj = nn.Linear(
            self.llm_kv_dim, config.hidden_size, bias=False, device=device
        )
        self.context_v_proj = nn.Linear(
            self.llm_kv_dim, config.hidden_size, bias=False, device=device
        )
        self.context_norm = (
            ActionExpertRMSNorm(config.hidden_size, eps=1e-6)
            if config.context_layer_norm
            else nn.Identity()
        )
        # Single-entry memo used by get_or_prepare_modulation_cache.
        self._modulation_cache_key: Optional[Tuple[Any, ...]] = None
        self._modulation_cache_value: Optional[Sequence[ActionExpertStepModulation]] = (
            None
        )
        self.blocks = nn.ModuleList(
            [
                ActionExpertBlock(
                    config.hidden_size,
                    config.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    ffn_multiple_of=config.ffn_multiple_of,
                    attn_dropout=config.attn_dropout,
                    dropout=config.dropout,
                    qk_norm=config.qk_norm,
                    qk_norm_eps=config.qk_norm_eps,
                    rope=config.rope,
                )
                for _ in range(config.num_layers)
            ]
        )
        # The blocks always receive pre-projected keys/values (kv_k / kv_v),
        # so the per-block kv_proj is not used on this path; it is frozen.
        for block in self.blocks:
            block.cross_attn.kv_proj.weight.requires_grad = False
            if block.cross_attn.kv_proj.bias is not None:
                block.cross_attn.kv_proj.bias.requires_grad = False
        self.final_layer = ActionExpertFinalLayer(
            config.hidden_size, config.max_action_dim
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise every sub-module.

        Residual output projections are scaled by ``(2 * num_layers) ** -0.5``;
        all modulation layers and the final projection start at zero.
        """
        for module in self.time_embed.modules():
            if isinstance(module, nn.Linear):
                _init_linear(module)
        _init_linear(self.action_embed)
        _init_linear(self.state_encoder)
        self.state_norm.reset_parameters()
        _init_linear(self.context_k_proj)
        _init_linear(self.context_v_proj)
        if isinstance(self.context_norm, ActionExpertRMSNorm):
            self.context_norm.reset_parameters()
        residual_scale = (2 * max(self.config.num_layers, 1)) ** -0.5
        for block in self.blocks:
            _init_linear(block.self_attn.qkv)
            _init_linear(block.self_attn.out_proj, scale=residual_scale)
            _init_linear(block.cross_attn.q_proj)
            _init_linear(block.cross_attn.kv_proj)
            _init_linear(block.cross_attn.out_proj, scale=residual_scale)
            _init_linear(block.mlp.up_proj)
            _init_linear(block.mlp.gate_proj)
            _init_linear(block.mlp.down_proj, scale=residual_scale)
            _init_linear(block.modulation.linear, zero=True)
            block.self_norm.reset_parameters()
            block.cross_norm.reset_parameters()
            block.ff_norm.reset_parameters()
            if block.self_attn.q_norm is not None:
                block.self_attn.q_norm.reset_parameters()
            if block.self_attn.k_norm is not None:
                block.self_attn.k_norm.reset_parameters()
            if block.cross_attn.q_norm is not None:
                block.cross_attn.q_norm.reset_parameters()
            if block.cross_attn.k_norm is not None:
                block.cross_attn.k_norm.reset_parameters()
        self.final_layer.norm.reset_parameters()
        _init_linear(self.final_layer.modulation.linear, zero=True)
        _init_linear(self.final_layer.linear, zero=True)

    def _reshape_hidden_to_heads(self, x: torch.Tensor) -> torch.Tensor:
        """View ``(batch, seq, hidden)`` as ``(batch, seq, num_heads, head_dim)``."""
        return x.view(
            x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim
        )

    def _encode_states(self, states: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Encode state features, or return ``None`` when no state is given.

        2D inputs get a sequence axis at dim 1. The feature axis is
        zero-padded or truncated to ``hidden_size`` before the encoder and norm.
        """
        if states is None:
            return None
        if states.dim() == 2:
            states = states.unsqueeze(1)
        if states.shape[-1] != self.hidden_size:
            feat_dim = states.shape[-1]
            if feat_dim < self.hidden_size:
                states = F.pad(states, (0, self.hidden_size - feat_dim))
            else:
                states = states[..., : self.hidden_size]
        return self.state_norm(self.state_encoder(states))

    def _project_kv_tensor(self, x: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
        """Project an LLM key/value tensor, apply the context norm, split into heads."""
        flat = self.context_norm(proj(x))
        return self._reshape_hidden_to_heads(flat)

    def _prepare_kv_context(
        self,
        encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]],
        encoded_states: Optional[torch.Tensor],
    ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]:
        """Build the per-block cross-attention keys and values.

        The encoded state (if any) is appended along the sequence axis to both
        keys and values. Keys are then normalised with the block's own
        ``k_norm`` here, which is why cross-attention does not normalise
        externally supplied keys again.

        Raises:
            ValueError: If the number of KV pairs differs from the number of
                blocks.
        """
        if len(encoder_kv_states) != len(self.blocks):
            raise ValueError(
                f"Expected {len(self.blocks)} KV layers for per-layer conditioning, "
                f"got {len(encoder_kv_states)}."
            )
        kv_contexts = []
        state_heads = (
            self._reshape_hidden_to_heads(encoded_states)
            if encoded_states is not None
            else None
        )
        for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
            k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
            v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
            if state_heads is not None:
                k_ctx = torch.cat([k_ctx, state_heads], dim=1)
                v_ctx = torch.cat([v_ctx, state_heads], dim=1)
            k_norm = block.cross_attn.k_norm
            if k_norm is not None:
                k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
            kv_contexts.append((k_ctx, v_ctx))
        return kv_contexts

    @staticmethod
    def _build_cross_attention_mask(
        encoder_attention_mask: Optional[torch.Tensor],
        encoded_states: Optional[torch.Tensor],
        batch_size: int,
        dtype: torch.dtype,
    ) -> Optional[torch.Tensor]:
        """Turn a 1/0 encoder mask into an additive ``(batch, 1, 1, src)`` mask.

        Appended state positions are always attendable. Returns ``None`` when
        no encoder mask is given.
        """
        state_seq_len = 0 if encoded_states is None else encoded_states.shape[1]
        if encoder_attention_mask is None:
            return None
        mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype)
        if state_seq_len > 0:
            ones = torch.ones(
                batch_size,
                1,
                1,
                state_seq_len,
                device=mask.device,
                dtype=mask.dtype,
            )
            mask = torch.cat([mask, ones], dim=-1)
        return (1.0 - mask) * torch.finfo(dtype).min

    def _build_self_attention_mask(
        self,
        action_attention_mask: Optional[torch.Tensor],
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Optional[torch.Tensor]:
        """Build the additive self-attention mask.

        Combines a key-padding mask (from ``action_attention_mask``) with an
        upper-triangular causal mask when ``config.causal_attn`` is set.
        Returns ``None`` when neither applies.
        """
        mask = None
        if action_attention_mask is not None:
            valid = action_attention_mask.to(device=device, dtype=torch.bool)
            key_mask = (~valid)[:, None, None, :].to(dtype=dtype)
            mask = key_mask * torch.finfo(dtype).min
        if self.config.causal_attn:
            causal = torch.ones(seq_len, seq_len, device=device, dtype=torch.bool).triu(
                diagonal=1
            )
            causal = (
                causal.unsqueeze(0).unsqueeze(0).to(dtype=dtype)
                * torch.finfo(dtype).min
            )
            mask = causal if mask is None else mask + causal
        return mask

    def prepare_context(
        self,
        *,
        encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]],
        encoder_attention_mask: Optional[torch.Tensor] = None,
        action_attention_mask: Optional[torch.Tensor] = None,
        state_embeddings: Optional[torch.Tensor] = None,
        batch_size: int,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> ActionExpertContext:
        """Precompute everything that does not depend on the noisy actions.

        Returns:
            An ``ActionExpertContext`` holding the per-block keys/values, both
            attention masks, the action validity mask and the RoPE tables.
        """
        encoded_states = self._encode_states(state_embeddings)
        valid_action = None
        if action_attention_mask is not None:
            valid_action = action_attention_mask.to(
                device=device, dtype=dtype
            ).unsqueeze(-1)
        rope_cache = None
        # The first block's rotary module builds the table shared by all blocks.
        if len(self.blocks) > 0 and self.blocks[0].self_attn.rope is not None:
            rope_cache = self.blocks[0].self_attn.rope.build_cache(
                seq_len=seq_len,
                device=device,
                dtype=dtype,
            )
        kv_contexts = self._prepare_kv_context(encoder_kv_states, encoded_states)
        cross_mask = self._build_cross_attention_mask(
            encoder_attention_mask,
            encoded_states,
            batch_size,
            dtype,
        )
        self_mask = self._build_self_attention_mask(
            action_attention_mask, seq_len, device, dtype
        )
        return ActionExpertContext(
            kv_contexts=kv_contexts,
            cross_mask=cross_mask,
            self_mask=self_mask,
            valid_action=valid_action,
            rope_cache=rope_cache,
        )

    def prepare_modulation_cache(
        self,
        timesteps: Sequence[torch.Tensor],
    ) -> Sequence[ActionExpertStepModulation]:
        """Compute the conditioning and all modulation chunks for each timestep."""
        cache = []
        for step_t in timesteps:
            conditioning = self.time_embed(step_t)
            block_modulations = [
                tuple(block.modulation(conditioning).chunk(9, dim=1))
                for block in self.blocks
            ]
            final_modulation = tuple(
                self.final_layer.modulation(conditioning).chunk(2, dim=1)
            )
            cache.append(
                ActionExpertStepModulation(
                    conditioning=conditioning,
                    block_modulations=block_modulations,
                    final_modulation=final_modulation,
                )
            )
        return cache

    def get_or_prepare_modulation_cache(
        self,
        timesteps: Sequence[torch.Tensor],
        *,
        cache_key: Optional[Tuple[Any, ...]] = None,
    ) -> Sequence[ActionExpertStepModulation]:
        """Return the modulation cache, memoised on ``cache_key``.

        Nothing is memoised in training mode or when ``cache_key`` is ``None``.
        Only the most recent key is retained.
        """
        if self.training or cache_key is None:
            return self.prepare_modulation_cache(timesteps)
        if (
            self._modulation_cache_key == cache_key
            and self._modulation_cache_value is not None
        ):
            return self._modulation_cache_value
        cached = self.prepare_modulation_cache(timesteps)
        self._modulation_cache_key = cache_key
        self._modulation_cache_value = cached
        return cached

    def forward_with_context(
        self,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
        *,
        context: ActionExpertContext,
        modulation: Optional[ActionExpertStepModulation] = None,
    ) -> torch.Tensor:
        """Run the block stack against a prepared context.

        Args:
            actions: ``(batch, seq, max_action_dim)`` action tokens.
            timesteps: Timesteps; only used when ``modulation`` is ``None``.
            context: Output of ``prepare_context``.
            modulation: Optional precomputed modulation for this step.

        Raises:
            ValueError: If the sequence length exceeds ``config.max_horizon``.
        """
        _, seq_len, _ = actions.shape
        if seq_len > self.config.max_horizon:
            raise ValueError(
                f"Action sequence length {seq_len} exceeds configured max_horizon={self.config.max_horizon}"
            )
        if modulation is None:
            conditioning = self.time_embed(timesteps)
            block_modulations: Sequence[Optional[Tuple[torch.Tensor, ...]]] = [
                None
            ] * len(self.blocks)
            final_modulation = None
        else:
            conditioning = modulation.conditioning
            block_modulations = modulation.block_modulations
            final_modulation = modulation.final_modulation

        # scaled_dot_product_attention rejects is_causal=True together with an
        # explicit mask. The mask from _build_self_attention_mask already
        # contains the causal triangle whenever config.causal_attn is set, so
        # the implicit flag is only requested when no mask is present.
        use_implicit_causal = (
            bool(self.config.causal_attn) and context.self_mask is None
        )

        x = self.action_embed(actions)
        if context.valid_action is not None:
            x = x * context.valid_action
        for block, kv_context, block_modulation in zip(
            self.blocks, context.kv_contexts, block_modulations
        ):
            x = block(
                x,
                conditioning,
                cross_kv=kv_context,
                self_attn_mask=context.self_mask,
                attn_mask=context.cross_mask,
                is_causal=use_implicit_causal,
                modulation=block_modulation,
                rope_cache=context.rope_cache,
            )
            # Zero out padded action positions after every block.
            if context.valid_action is not None:
                x = x * context.valid_action
        out = self.final_layer(x, conditioning, modulation=final_modulation)
        if context.valid_action is not None:
            out = out * context.valid_action
        return out

    def forward(
        self,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
        *,
        encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]],
        encoder_attention_mask: Optional[torch.Tensor] = None,
        action_attention_mask: Optional[torch.Tensor] = None,
        state_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Prepare the context from the inputs and run ``forward_with_context``."""
        bsz, seq_len, _ = actions.shape
        context = self.prepare_context(
            encoder_kv_states=encoder_kv_states,
            encoder_attention_mask=encoder_attention_mask,
            action_attention_mask=action_attention_mask,
            state_embeddings=state_embeddings,
            batch_size=bsz,
            seq_len=seq_len,
            device=actions.device,
            dtype=actions.dtype,
        )
        return self.forward_with_context(actions, timesteps, context=context)
def _to_numpy(value: Any) -> np.ndarray:
    """Convert ``value`` to a NumPy array.

    Arrays are returned as-is, tensors are detached and moved to the CPU,
    anything else goes through ``np.asarray``.
    """
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return value if isinstance(value, np.ndarray) else np.asarray(value)
def _to_array(value: Any) -> Optional[np.ndarray]:
    """Convert ``value`` to a float32 NumPy array; ``None`` passes through."""
    if value is None:
        return None
    if not torch.is_tensor(value):
        return np.asarray(value, dtype=np.float32)
    host_array = value.detach().cpu().numpy()
    return host_array.astype(np.float32, copy=False)
def _to_mask(value: Any, fallback_like: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Convert ``value`` to a boolean array, broadcast to ``fallback_like``'s shape.

    ``None`` passes through. Broadcasting only happens when a reference array
    is given and the shapes differ.
    """
    if value is None:
        return None
    mask = np.asarray(value, dtype=np.bool_)
    if fallback_like is None or mask.shape == fallback_like.shape:
        return mask
    return np.broadcast_to(mask, fallback_like.shape)
def _feature_dim_from_stats(stats: Optional[Mapping[str, Any]]) -> Optional[int]:
    """Infer the feature dimension from the first usable entry of ``stats``.

    Entries are checked in a fixed key order; the size of the last axis of
    the first non-``None``, non-scalar entry is returned. ``None`` is returned
    when ``stats`` is not a mapping or no entry reveals a dimension.
    """
    if not isinstance(stats, Mapping):
        return None
    candidate_keys = (
        "mean",
        "std",
        "min",
        "max",
        "q01",
        "q99",
        "q10",
        "q90",
        "mask",
        "names",
    )
    for key in candidate_keys:
        value = stats.get(key)
        if value is None:
            continue
        shape = np.asarray(value).shape
        if shape:
            return int(shape[-1])
        # Fallback for sequences that NumPy reports as 0-d.
        is_text = isinstance(value, (str, bytes))
        if isinstance(value, Sequence) and not is_text:
            return int(len(value))
    return None
class _FeatureNormalizer:
    """Normalise / unnormalise feature vectors from precomputed statistics.

    Supported modes are ``"none"``, ``"mean_std"``, ``"min_max"``,
    ``"q01_q99"`` and ``"q10_q90"``. The range modes map values into
    ``[-1, 1]`` and clip. ``mask`` marks the dimensions that are transformed
    (others pass through); ``zero_mask`` marks dimensions whose min equals
    their max, which normalise to ``0``.
    """

    # Modes whose normalised values live in [-1, 1].
    _RANGE_MODES = frozenset({"min_max", "q01_q99", "q10_q90"})
    # Quantile modes and the stat keys holding their low/high bounds.
    _QUANTILE_KEYS = {"q01_q99": ("q01", "q99"), "q10_q90": ("q10", "q90")}
    # Keys probed, in order, for a reference shape when mode is "none".
    _SHAPE_KEYS = (
        "mean",
        "std",
        "min",
        "max",
        "q01",
        "q99",
        "q10",
        "q90",
        "mask",
    )

    def __init__(
        self,
        *,
        mode: str,
        mean: Optional[np.ndarray] = None,
        std: Optional[np.ndarray] = None,
        min_val: Optional[np.ndarray] = None,
        max_val: Optional[np.ndarray] = None,
        q_low: Optional[np.ndarray] = None,
        q_high: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        zero_mask: Optional[np.ndarray] = None,
    ):
        self.mode = mode
        self.mean, self.std = mean, std
        self.min_val, self.max_val = min_val, max_val
        self.q_low, self.q_high = q_low, q_high
        self.mask = mask
        self.zero_mask = zero_mask

    @classmethod
    def from_stats(
        cls, stats: Optional[Mapping[str, Any]], mode: str
    ) -> Optional["_FeatureNormalizer"]:
        """Build a normaliser for ``mode`` from a statistics mapping.

        Returns ``None`` when ``stats`` is ``None``.

        Raises:
            ValueError: If the statistics required by ``mode`` are missing or
                ``mode`` is not supported.
        """
        if stats is None:
            return None
        raw_mask = stats.get("mask") if isinstance(stats, Mapping) else None

        if mode == "none":
            # Any available statistic serves as the shape reference for the mask.
            reference = None
            for key in cls._SHAPE_KEYS:
                reference = _to_array(stats.get(key))
                if reference is not None:
                    break
            return cls(mode=mode, mask=_to_mask(raw_mask, reference))

        if mode == "mean_std":
            mean = _to_array(stats.get("mean"))
            std = _to_array(stats.get("std"))
            if mean is None or std is None:
                raise ValueError("norm_mode='mean_std' requires mean and std stats.")
            return cls(mode=mode, mean=mean, std=std, mask=_to_mask(raw_mask, mean))

        if mode == "min_max":
            min_val = _to_array(stats.get("min"))
            max_val = _to_array(stats.get("max"))
            if min_val is None or max_val is None:
                raise ValueError("norm_mode='min_max' requires min and max stats.")
            return cls(
                mode=mode,
                min_val=min_val,
                max_val=max_val,
                mask=_to_mask(raw_mask, min_val),
                zero_mask=(min_val == max_val),
            )

        if mode in cls._QUANTILE_KEYS:
            low_key, high_key = cls._QUANTILE_KEYS[mode]
            q_low = _to_array(stats.get(low_key))
            q_high = _to_array(stats.get(high_key))
            if q_low is None or q_high is None:
                raise ValueError(
                    f"norm_mode={mode!r} requires {low_key} and {high_key} stats."
                )
            min_val = _to_array(stats.get("min"))
            max_val = _to_array(stats.get("max"))
            reference = q_low if min_val is None else min_val
            if min_val is None or max_val is None:
                zero_mask = None
            else:
                zero_mask = min_val == max_val
            return cls(
                mode=mode,
                min_val=min_val,
                max_val=max_val,
                q_low=q_low,
                q_high=q_high,
                mask=_to_mask(raw_mask, reference),
                zero_mask=zero_mask,
            )

        raise ValueError(f"Unsupported robot normalization mode {mode!r}.")

    def normalize(self, x: Any) -> Any:
        """Normalise ``x``; tensors come back as tensors on the same device/dtype."""
        arr = _to_array(x)
        if arr is None:
            return None
        eps = 1e-6
        mode = self.mode
        if mode == "mean_std":
            normed = (arr - self.mean) / np.maximum(self.std, eps)
        elif mode == "min_max":
            span = np.maximum(self.max_val - self.min_val, eps)
            normed = 2.0 * (arr - self.min_val) / span - 1.0
        elif mode in self._QUANTILE_KEYS:
            span = np.maximum(self.q_high - self.q_low, eps)
            normed = 2.0 * (arr - self.q_low) / span - 1.0
        else:
            # "none" and any unrecognised mode leave the values untouched.
            normed = arr
        if mode in self._RANGE_MODES:
            normed = np.clip(normed, -1.0, 1.0)
        if self.mask is not None:
            # Dimensions outside the mask keep their raw input value.
            normed = np.where(self.mask, normed, arr)
        if self.zero_mask is not None:
            normed = np.where(self.zero_mask, 0.0, normed)
        if not torch.is_tensor(x):
            return normed
        return torch.as_tensor(normed, device=x.device, dtype=x.dtype)

    def unnormalize(self, x: Any) -> Any:
        """Map normalised values back to the original scale."""
        arr = _to_array(x)
        if arr is None:
            return None
        mode = self.mode
        if mode in self._RANGE_MODES:
            arr = np.clip(arr, -1.0, 1.0)
        if mode == "mean_std":
            out = arr * self.std + self.mean
        elif mode == "min_max":
            out = (arr + 1.0) * (self.max_val - self.min_val) / 2.0 + self.min_val
        elif mode in self._QUANTILE_KEYS:
            out = (arr + 1.0) * (self.q_high - self.q_low) / 2.0 + self.q_low
        else:
            out = arr
        if self.mask is not None:
            # NOTE(review): in the range modes ``arr`` has already been clipped
            # to [-1, 1], so pass-through (mask == False) dimensions are
            # returned clipped here, whereas ``normalize`` returns them
            # unmodified. Confirm this asymmetry is intended.
            out = np.where(self.mask, out, arr)
        if not torch.is_tensor(x):
            return out
        return torch.as_tensor(out, device=x.device, dtype=x.dtype)
class _RobotStats:
    """Per-tag robot metadata and the normalisers derived from it.

    ``payload["metadata_by_tag"]`` maps a normalisation tag to a metadata
    dict. Entries with ``action_stats`` / ``state_stats`` get a
    ``_FeatureNormalizer`` built with ``payload["norm_mode"]`` (default
    ``"min_max"``).
    """

    def __init__(self, payload: Mapping[str, Any]):
        self.norm_mode = str(payload.get("norm_mode", "min_max"))
        raw_by_tag = dict(payload.get("metadata_by_tag") or {})
        self.metadata_by_tag: Dict[str, Dict[str, Any]] = {
            str(tag): dict(metadata or {}) for tag, metadata in raw_by_tag.items()
        }
        self.action_normalizers = {}
        self.state_normalizers = {}
        for tag, metadata in self.metadata_by_tag.items():
            action_stats = metadata.get("action_stats")
            if action_stats is not None:
                self.action_normalizers[tag] = _FeatureNormalizer.from_stats(
                    action_stats,
                    self.norm_mode,
                )
            state_stats = metadata.get("state_stats")
            if state_stats is not None:
                self.state_normalizers[tag] = _FeatureNormalizer.from_stats(
                    state_stats,
                    self.norm_mode,
                )

    def validate_tag(self, norm_tag: Optional[str]) -> str:
        """Return the stripped tag, raising ``ValueError`` if empty or unknown."""
        tag = str(norm_tag or "").strip()
        if not tag:
            raise ValueError("MolmoAct2 `predict_action` requires `norm_tag`.")
        if tag in self.metadata_by_tag:
            return tag
        allowed = ", ".join(sorted(self.metadata_by_tag))
        raise ValueError(
            f"Unknown MolmoAct2 normalization tag {tag!r}. Allowed tags: {allowed}."
        )

    def get_metadata(self, norm_tag: Optional[str]) -> Dict[str, Any]:
        """Return a copy of the metadata for ``norm_tag`` (empty if unknown or None)."""
        if norm_tag is None:
            return {}
        stored = self.metadata_by_tag.get(str(norm_tag), {})
        return dict(stored or {})

    def normalize_state(self, state: Any, norm_tag: str) -> Any:
        """Normalise ``state`` for ``norm_tag``; unchanged if no normaliser exists."""
        normalizer = self.state_normalizers.get(str(norm_tag))
        if normalizer is None:
            return state
        return normalizer.normalize(state)

    def unnormalize_action(self, action: Any, norm_tag: str) -> Any:
        """Unnormalise ``action`` for ``norm_tag``; unchanged if no normaliser exists."""
        normalizer = self.action_normalizers.get(str(norm_tag))
        if normalizer is None:
            return action
        return normalizer.unnormalize(action)

    def get_action_dim(self, norm_tag: str) -> Optional[int]:
        """Feature dimension inferred from the tag's ``action_stats``, if any."""
        return _feature_dim_from_stats(self.get_metadata(norm_tag).get("action_stats"))

    def get_state_dim(self, norm_tag: str) -> Optional[int]:
        """Feature dimension inferred from the tag's ``state_stats``, if any."""
        return _feature_dim_from_stats(self.get_metadata(norm_tag).get("state_stats"))

    def get_action_horizon(self, norm_tag: str) -> Optional[int]:
        """The tag's ``action_horizon`` metadata as a positive int, or ``None``."""
        return self._get_positive_int(norm_tag, "action_horizon")

    def get_n_action_steps(self, norm_tag: str) -> Optional[int]:
        """The tag's ``n_action_steps`` metadata as a positive int, or ``None``."""
        return self._get_positive_int(norm_tag, "n_action_steps")

    def _get_positive_int(self, norm_tag: str, key: str) -> Optional[int]:
        """Read ``key`` from the tag's metadata as an int ``>= 1``.

        Returns ``None`` when the key is absent; raises ``ValueError`` when the
        value is below 1.
        """
        raw = self.get_metadata(norm_tag).get(key)
        if raw is None:
            return None
        number = int(raw)
        if number < 1:
            raise ValueError(
                f"Robot metadata for norm_tag={norm_tag!r} must define {key} >= 1."
            )
        return number
def _normalize_image_for_cache(image: Any) -> np.ndarray:
arr = np.asarray(image)
if arr.ndim == 2:
arr = np.stack([arr] * 3, axis=-1)
if arr.ndim == 3 and arr.shape[0] in {1, 3, 4} and arr.shape[-1] not in {1, 3, 4}:
arr = np.moveaxis(arr, 0, -1)
if arr.ndim == 3 and arr.shape[-1] == 1:
arr = np.repeat(arr, 3, axis=-1)
if arr.dtype in (np.float32, np.float64):
if arr.size > 0 and float(arr.max()) <= 1.0:
arr = arr * 255.0
arr = np.clip(arr, 0, 255).astype(np.uint8)
elif arr.dtype != np.uint8:
arr = np.clip(arr, 0, 255).astype(np.uint8)
return arr
def _extract_first_image(images: Any) -> Optional[np.ndarray]:
if images is None:
return None
if isinstance(images, (list, tuple)):
if not images:
return None
return _normalize_image_for_cache(images[0])
arr = _to_numpy(images)
if arr.ndim == 4:
return _normalize_image_for_cache(arr[0])
return _normalize_image_for_cache(arr)
def _resize_depth_reasoning_image(image: np.ndarray, target_size: int) -> np.ndarray:
from PIL import Image
if image.shape[0] == target_size and image.shape[1] == target_size:
return image
pil_image = Image.fromarray(np.asarray(image, dtype=np.uint8))
return np.asarray(pil_image.resize((target_size, target_size), Image.BILINEAR))
def _compute_depth_update_mask(
current_image: np.ndarray,
previous_image: np.ndarray,
*,
num_depth_codes: int,
) -> np.ndarray:
grid_side = int(math.isqrt(int(num_depth_codes)))
if grid_side * grid_side != int(num_depth_codes):
raise ValueError(
f"enable_adaptive_depth=True requires a square depth grid, got num_depth_codes={int(num_depth_codes)}."
)
target_size = grid_side * _DEPTH_REASONING_PATCH_SIZE
current_resized = _resize_depth_reasoning_image(current_image, target_size).astype(
np.float32
)
previous_resized = _resize_depth_reasoning_image(
previous_image, target_size
).astype(np.float32)
current_patches = (
current_resized.reshape(
grid_side,
_DEPTH_REASONING_PATCH_SIZE,
grid_side,
_DEPTH_REASONING_PATCH_SIZE,
3,
)
.transpose(0, 2, 1, 3, 4)
.reshape(grid_side, grid_side, -1)
)
previous_patches = (
previous_resized.reshape(
grid_side,
_DEPTH_REASONING_PATCH_SIZE,
grid_side,
_DEPTH_REASONING_PATCH_SIZE,
3,
)
.transpose(0, 2, 1, 3, 4)
.reshape(grid_side, grid_side, -1)
)
dot = np.sum(current_patches * previous_patches, axis=-1)
norm_current = np.linalg.norm(current_patches, axis=-1)
norm_previous = np.linalg.norm(previous_patches, axis=-1)
denom = norm_current * norm_previous
similarity = np.where(denom < 1e-8, 1.0, dot / (denom + 1e-12))
return np.asarray(similarity < _DEPTH_REASONING_THRESHOLD, dtype=np.bool_).reshape(
-1
)
def _build_depth_update_spans(
update_mask: Sequence[bool],
) -> List[Tuple[int, int, bool]]:
flat_mask = np.asarray(update_mask, dtype=np.bool_).reshape(-1)
if flat_mask.size == 0:
return []
spans: List[Tuple[int, int, bool]] = []
start = 0
current_value = bool(flat_mask[0])
for idx in range(1, int(flat_mask.shape[0])):
next_value = bool(flat_mask[idx])
if next_value == current_value:
continue
spans.append((start, idx, current_value))
start = idx
current_value = next_value
spans.append((start, int(flat_mask.shape[0]), current_value))
return spans
def _wrap_setup_text(setup_type: str, add_setup_tokens: bool = False) -> str:
    """Optionally surround the setup description with the setup tokens.

    The text is returned unchanged when it is empty, already starts with
    ``SETUP_START_TOKEN`` and ends with ``SETUP_END_TOKEN``, or when
    ``add_setup_tokens`` is false. Falsy input is treated as ``""``.
    """
    text = str(setup_type or "")
    already_wrapped = text.startswith(SETUP_START_TOKEN) and text.endswith(
        SETUP_END_TOKEN
    )
    if already_wrapped or not text or not add_setup_tokens:
        return text
    return f"{SETUP_START_TOKEN}{text}{SETUP_END_TOKEN}"
def _wrap_control_text(control_mode: str, add_control_tokens: bool = False) -> str:
    """Optionally surround the control-mode description with the control tokens.

    The text is returned unchanged when it is empty, already starts with
    ``CONTROL_START_TOKEN`` and ends with ``CONTROL_END_TOKEN``, or when
    ``add_control_tokens`` is false. Falsy input is treated as ``""``.
    """
    text = str(control_mode or "")
    already_wrapped = text.startswith(CONTROL_START_TOKEN) and text.endswith(
        CONTROL_END_TOKEN
    )
    if already_wrapped or not text or not add_control_tokens:
        return text
    return f"{CONTROL_START_TOKEN}{text}{CONTROL_END_TOKEN}"
def _discretize_normalized_state(
state: np.ndarray, num_state_tokens: int
) -> np.ndarray:
arr = np.asarray(state, dtype=np.float32)
arr = np.nan_to_num(arr, nan=0.0, posinf=1.0, neginf=-1.0)
arr = np.clip(arr, -1.0, 1.0)
scaled = (arr + 1.0) / 2.0 * float(num_state_tokens - 1)
return np.clip(np.rint(scaled).astype(np.int64), 0, int(num_state_tokens) - 1)
def _build_discrete_state_string(
state: Optional[np.ndarray], num_state_tokens: int
) -> str:
if state is None:
return ""
token_ids = _discretize_normalized_state(state, num_state_tokens).reshape(-1)
return f"{STATE_START_TOKEN}{''.join(f'{STATE_TOKEN_PREFIX}{int(token_id)}>' for token_id in token_ids)}{STATE_END_TOKEN}"
# Canonicalize a free-form question/instruction string.
# The statements below collapse whitespace, strip the configured surrounding
# delimiters, remove configured prefix patterns, trim trailing punctuation and
# closers, join multiple sentences with "; ", and lowercase the text.
# Blank input returns "".
# NOTE(review): indentation is not visible in this view, so exactly which of
# the stripping statements run inside the `while` loop (versus once after it)
# should be confirmed against the original file before refactoring.
def _normalize_question_text(text: str) -> str:
# Collapse every whitespace run to a single space and trim both ends.
normalized = re.sub(r"\s+", " ", text).strip()
if not normalized:
return ""
# `previous` tracks the last value seen so the loop stops at a fixed point
# (or when the text becomes empty).
previous = None
while normalized and normalized != previous:
previous = normalized
# Peel off surrounding delimiter characters (set defined elsewhere in the file).
normalized = normalized.strip().strip(_QUESTION_SURROUNDING_DELIMITERS).strip()
# Remove at most one match of each prefix pattern per pass.
for pattern in _QUESTION_PREFIX_PATTERNS:
normalized = pattern.sub("", normalized, count=1).strip()
# Trailing cleanup: sentence punctuation, then closing characters, then
# sentence punctuation again (in case a closer was hiding more punctuation).
normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip()
normalized = normalized.rstrip(_QUESTION_TRAILING_CLOSERS).rstrip()
normalized = normalized.rstrip(_QUESTION_TRAILING_SENTENCE_PUNCTUATION).rstrip()
# Split on runs of ., ! or ? and drop empty pieces.
sentence_chunks = [
chunk.strip() for chunk in re.split(r"[.!?]+", normalized) if chunk.strip()
]
# More than one sentence: rejoin them with "; " so no sentence terminators remain.
if len(sentence_chunks) > 1:
normalized = "; ".join(sentence_chunks)
normalized = normalized.lower()
return normalized
def _build_robot_text(
    *,
    task: str,
    style: str,
    discrete_state_string: str,
    setup_type: str,
    control_mode: str,
    add_setup_tokens: bool,
    add_control_tokens: bool,
    num_images: int,
) -> str:
    """Assemble the full chat-formatted robot prompt.

    The user turn states the task, setup, optional discrete state and the
    expected control mode. With ``style == "robot_depth_action"`` the prompt
    asks for a depth map followed by an action and the assistant turn is
    primed with the depth and action output tokens; any other style asks for
    an action only. One ``<|image|>`` placeholder per image is prepended,
    labelled ``Image N`` when there is more than one image.
    """
    setup_text = _wrap_setup_text(setup_type, add_setup_tokens=add_setup_tokens)
    control_text = _wrap_control_text(
        control_mode, add_control_tokens=add_control_tokens
    )

    state_clause = ""
    if discrete_state_string:
        state_clause = f" The current state of the robot is {discrete_state_string}."

    # Everything up to "Given these, " is shared by both prompt styles.
    preamble = (
        f"The task is to {task}. The setup is {setup_text}.{state_clause} "
        f"The expected control mode is {control_text}. Given these, "
    )
    if style == "robot_depth_action":
        request = (
            "first predict the depth map of the main image "
            "and then predict the action the robot should take to complete the task?"
        )
        trigger = f"{DEPTH_OUTPUT_TOKEN}{ACTION_OUTPUT_TOKEN}"
    else:
        request = "what action should the robot take to complete the task?"
        trigger = ACTION_OUTPUT_TOKEN
    prompt = preamble + request

    if num_images <= 0:
        image_prefix = ""
    elif num_images == 1:
        image_prefix = "<|image|>"
    else:
        image_prefix = "".join(
            f"Image {number}<|image|>" for number in range(1, num_images + 1)
        )

    return f"{image_prefix}<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{trigger}"
def _flatten_generated_token_ids(token_ids: torch.Tensor) -> List[int]:
if token_ids.ndim == 3:
return [int(x) for x in token_ids[0, 0].detach().cpu().tolist()]
if token_ids.ndim == 2:
return [int(x) for x in token_ids[0].detach().cpu().tolist()]
if token_ids.ndim == 1:
return [int(x) for x in token_ids.detach().cpu().tolist()]
raise ValueError(
f"Unexpected generated token tensor shape {tuple(token_ids.shape)}"
)
def _extract_discrete_token_bins(
generated_ids: List[int],
start_token_id: int,
end_token_id: int,
token_id_to_bin: Dict[int, int],
) -> List[int]:
start_idx = None
end_idx = None
for idx, token_id in enumerate(generated_ids):
if token_id == start_token_id:
start_idx = idx
break
if start_idx is not None:
for idx in range(start_idx + 1, len(generated_ids)):
if generated_ids[idx] == end_token_id:
end_idx = idx
break
span_start = 0 if start_idx is None else start_idx + 1
span_end = len(generated_ids) if end_idx is None else end_idx
return [
int(token_id_to_bin[token_id])
for token_id in generated_ids[span_start:span_end]
if token_id in token_id_to_bin
]
@dataclass
class MolmoAct2ActionOutput(ModelOutput):
    """Output container for MolmoAct2 action prediction.

    Every field is optional and defaults to ``None``; field order is part of
    the ``ModelOutput`` contract and must not change.
    """

    # Predicted actions.
    actions: Optional[torch.FloatTensor] = None
    # Token ids produced during generation.
    generated_token_ids: Optional[torch.LongTensor] = None
    # Discrete depth bins, when depth was predicted.
    depth_bins: Optional[torch.LongTensor] = None
    # Opaque depth cache dictionary (schema defined by the producer).
    depth_cache: Optional[Dict[str, Any]] = None
@dataclass
class _DepthPrefix:
    """Internal bundle of the tensors and cache state attached to a depth prefix.

    All fields are required and positional in the order listed below.
    """

    # Token ids of the depth prefix.
    token_ids: torch.Tensor
    # Depth bins associated with the prefix.
    depth_bins: torch.Tensor
    # Complete input id sequence.
    full_input_ids: torch.Tensor
    # Attention mask, if one is used.
    attention_mask: Optional[torch.Tensor]
    # Sequence of (tensor, tensor) pairs of encoder key/value states.
    encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]]
    # Model output to continue from (type left open by the original annotation).
    next_output: Any
    # KV cache, if any.
    past_key_values: Optional[Cache]
@dataclass
class MolmoAct2CausalLMOutputWithPast(ModelOutput):
    """
    Output type for the MolmoAct2 causal (autoregressive) language model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Next-token-prediction language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Language modeling head scores for every vocabulary token, before SoftMax.
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            A [`~cache_utils.Cache`] instance; see the
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
            Holds the pre-computed self-attention keys and values that can be passed back
            (as the `past_key_values` input) to speed up sequential decoding.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`:
            the vision encoder output after projecting the last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class MolmoAct2ModelOutputWithPast(BaseModelOutputWithPast):
    """
    Output type for the MolmoAct2 base model, carrying hidden states and attentions.

    Args:
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_num_patches, hidden_size)`:
            the image hidden states produced by the vision backbone.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
class ViTMLP(nn.Module):
    """Two-layer feed-forward network: ``dim -> hidden_dim -> dim`` with biases.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the intermediate layer.
        hidden_act: Key into ``ACT2FN`` selecting the activation.
        device: Device on which the linear layers are created.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        hidden_act: str,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
        self.act = ACT2FN[hidden_act]
        self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``w2(act(w1(x)))``."""
        hidden = self.w1(x)
        activated = self.act(hidden)
        return self.w2(activated)
class ViTMultiHeadDotProductAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_key_value_heads: int,
head_dim: int,
use_bias: bool = True,
input_dim: Optional[int] = None,
float32_attention: bool = True,
attention_dropout: float = 0.0,
residual_dropout: float = 0.0,
device: Union[str, torch.device] = None,
attn_implementation: str = "eager",
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attn_implementation = attn_implementation
self.is_causal = False
input_dim = input_dim or hidden_size
self.wq = nn.Linear(
input_dim,
self.num_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wk = nn.Linear(
input_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wv = nn.Linear(
input_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wo = nn.Linear(
self.num_heads * self.head_dim,
self.hidden_size,
)
self.float32_attention = float32_attention
self.attention_dropout = attention_dropout
self.residual_dropout = nn.Dropout(residual_dropout)
self.sdpa_backend_list = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
return hidden_states.reshape(
hidden_states.shape[:2] + (num_heads, self.head_dim)
)
def _merge_heads(self, hidden_states) -> torch.Tensor:
return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
def forward(
self,
inputs_q: torch.Tensor,
inputs_kv: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
inputs_v = inputs_kv
else:
inputs_k = inputs_q
inputs_v = inputs_q
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
xq = self._split_heads(xq, self.num_heads)
xk = self._split_heads(xk, self.num_key_value_heads)
xv = self._split_heads(xv, self.num_key_value_heads)
if self.num_heads != self.num_key_value_heads:
xk = xk.repeat_interleave(
self.num_key_value_groups, dim=2, output_size=self.num_heads
)
xv = xv.repeat_interleave(
self.num_key_value_groups, dim=2, output_size=self.num_heads
)
og_dtype = xq.dtype
if self.float32_attention:
xq = xq.to(torch.float)
xk = xk.to(torch.float)
dropout_p = 0.0 if not self.training else self.attention_dropout
if self.attn_implementation == "eager":
attn_weights = torch.einsum(
"...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk
)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
xq.dtype
)
attn_weights = F.dropout(attn_weights, p=dropout_p, training=self.training)
attn_output = torch.einsum(
"...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv
)
elif self.attn_implementation == "sdpa":
if self.float32_attention:
xv = xv.to(torch.float32)
query = xq.transpose(1, 2).contiguous()
key = xk.transpose(1, 2).contiguous()
value = xv.transpose(1, 2).contiguous()
if inputs_kv is not None:
with sdpa_kernel(self.sdpa_backend_list):
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask,
is_causal=False,
dropout_p=dropout_p,
).transpose(1, 2)
else:
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask,
is_causal=False,
dropout_p=dropout_p,
).transpose(1, 2)
elif self.attn_implementation == "flash_attention_2":
if xq.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
else:
target_dtype = self.wq.weight.dtype
attn_output = _flash_attention_forward(
xq,
xk,
xv,
attention_mask=attn_mask,
query_length=inputs_q.shape[1],
is_causal=False,
dropout=dropout_p,
softmax_scale=xq.shape[-1] ** -0.5,
use_top_left_mask=flash_attn_supports_top_left_mask(),
target_dtype=target_dtype,
implementation=self.attn_implementation,
)
else:
raise ValueError(
f"Attention implementation {self.attn_implementation} not supported"
)
attn_output = attn_output.to(og_dtype)
attn_output = self._merge_heads(attn_output)
attn_output = self.wo(attn_output)
attn_output = self.residual_dropout(attn_output)
return attn_output
class MolmoAct2VisionBlock(nn.Module):
    """One pre-norm ViT layer: self-attention then MLP, each with a residual."""

    def __init__(
        self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps

        # Sub-modules are created in the original order so parameter
        # registration (and random initialization order) is unchanged.
        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            float32_attention=config.float32_attention,
            attention_dropout=config.attention_dropout,
            residual_dropout=config.residual_dropout,
            device=device,
            attn_implementation=config._attn_implementation,
        )
        self.feed_forward = ViTMLP(
            width,
            config.intermediate_size,
            config.hidden_act,
            device=device,
        )
        self.attention_norm = nn.LayerNorm(width, eps=norm_eps, device=device)
        self.ffn_norm = nn.LayerNorm(width, eps=norm_eps, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalized self-attention and MLP with residual connections."""
        attended = x + self.attention(self.attention_norm(x))
        return attended + self.feed_forward(self.ffn_norm(attended))
class MolmoAct2VisionBlockCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` vision blocks applied in sequence."""

    def __init__(
        self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None
    ):
        super().__init__()
        self.config = config
        # Backward-compatible alias: the attribute was originally misspelled.
        self.conifg = config
        self.resblocks = nn.ModuleList(
            [
                MolmoAct2VisionBlock(config, device)
                for _ in range(config.num_hidden_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run every block in order and return the output of each one.

        Returns:
            A list with one tensor per block; entry ``i`` is the hidden state
            after block ``i``.
        """
        hidden_states = []
        for block in self.resblocks:
            x = block(x)
            hidden_states.append(x)
        return hidden_states
class MolmoAct2VisionTransformer(nn.Module):
    """ViT encoder: linear patch embedding, learned positions, block stack.

    No class/prefix token is used (``num_prefix_tokens`` is 0).
    """

    def __init__(
        self, config: MolmoAct2VitConfig, device: Union[str, torch.device] = None
    ):
        super().__init__()
        self.config = config

        # positional embeddings
        self.scale = config.hidden_size**-0.5
        self.num_prefix_tokens: int = 0  # no class embeddings
        self.positional_embedding = nn.Parameter(
            torch.zeros(config.image_num_pos, config.hidden_size, device=device),
        )

        image_patch_size = config.image_patch_size
        # Each patch is a flattened (patch_size x patch_size x 3) pixel vector.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.hidden_size,
            bias=True,
            device=device,
        )

        self.transformer = MolmoAct2VisionBlockCollection(config, device)

    def add_pos_emb(self, x: torch.Tensor, patch_num: Tuple[int, int]) -> torch.Tensor:
        """Add the (possibly resized) positional embedding to ``x``.

        Args:
            x: Patch embeddings of shape ``(batch, num_patch, hidden_size)``.
            patch_num: ``(rows, cols)`` of the patch grid.

        Returns:
            ``x`` plus the positional embedding, in the dtype of ``x``.
        """
        pos_emb = self.positional_embedding
        # The stored table is a flattened square grid.
        grid_side = math.isqrt(pos_emb.shape[0])
        pos_emb = pos_emb.reshape((grid_side, grid_side, pos_emb.shape[1]))

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            # antialias: default True in jax.image.resize
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + pos_emb[None, :, :].to(x.dtype)
        return x

    def forward(
        self, x: torch.Tensor, patch_num: Optional[Tuple[int, int]] = None
    ) -> list[torch.Tensor]:
        """Encode flattened patches.

        Args:
            x: Tensor of shape ``(batch_size, num_patch, n_pixels)``.
            patch_num: ``(rows, cols)`` of the patch grid; defaults to
                ``config.image_num_patch``.

        Returns:
            The hidden state after every transformer block.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if patch_num is None:
            patch_num = self.config.image_num_patch

        # Explicit rank check (previously an unused `B, N, D = x.shape` unpack).
        if x.ndim != 3:
            raise ValueError(
                f"Expected x of shape (batch_size, num_patch, n_pixels), got {tuple(x.shape)}"
            )

        x = self.patch_embedding(x)

        # class embeddings and positional embeddings
        x = self.add_pos_emb(x, patch_num)

        hidden_states = self.transformer(x)
        return hidden_states
class ImageProjectorMLP(nn.Module):
    """Bias-free gated MLP: ``w2(act(w1(x)) * w3(x))``.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of both the gate (``w1``) and value (``w3``) branches.
        output_dim: Width of the projected output.
        hidden_act: Key into ``ACT2FN`` selecting the gate activation.
        device: Device on which the linear layers are created.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act: str,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        # Registration order (w1, w2, w3) is kept as in the original.
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` through the gated MLP."""
        gate = self.act(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)
# Vision backbone: runs the ViT, concatenates the hidden states of the layers
# listed in `adapter_config.vit_layers`, pools groups of patches with an
# attention layer and projects the pooled features with a gated MLP.
class MolmoAct2VisionBackbone(nn.Module):
def __init__(
self, vit_config: MolmoAct2VitConfig, adapter_config: MolmoAct2AdapterConfig
):
super().__init__()
self.vit_config = vit_config
self.adapter_config = adapter_config
# Resolve the requested layer indices; negative values count from the end
# of the full ViT depth.
self.vit_layers = []
for layer in adapter_config.vit_layers:
if layer >= 0:
self.vit_layers.append(layer)
else:
self.vit_layers.append(layer + vit_config.num_hidden_layers)
# Only build as many ViT layers as the deepest requested index needs; the
# config is deep-copied so the caller's config is not modified.
last_layer_needed = max(self.vit_layers) + 1
if last_layer_needed < vit_config.num_hidden_layers:
new_vit_config = deepcopy(vit_config)
new_vit_config.num_hidden_layers = last_layer_needed
self.image_vit = MolmoAct2VisionTransformer(new_vit_config)
else:
self.image_vit = MolmoAct2VisionTransformer(vit_config)
self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens
# The selected layers are concatenated on the feature axis in `encode_image`,
# so the pooling input width is hidden_size times the number of layers.
pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
self.image_pooling_2d = ViTMultiHeadDotProductAttention(
hidden_size=adapter_config.hidden_size,
num_heads=adapter_config.num_attention_heads,
num_key_value_heads=adapter_config.num_key_value_heads,
head_dim=adapter_config.head_dim,
input_dim=pool_dim,
float32_attention=adapter_config.float32_attention,
attention_dropout=adapter_config.attention_dropout,
residual_dropout=adapter_config.residual_dropout,
attn_implementation=adapter_config._attn_implementation,
)
self.image_projector = ImageProjectorMLP(
adapter_config.hidden_size,
adapter_config.intermediate_size,
adapter_config.text_hidden_size,
adapter_config.hidden_act,
)
self.image_feature_dropout = nn.Dropout(adapter_config.image_feature_dropout)
# Run the ViT on every crop and return the concatenated features of the
# selected layers, reshaped back to (batch_size, num_crops, num_patch, -1).
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
"""
: param images: (batch_size, num_crops, num_patch, n_pixels)
"""
B, T, N, D = images.shape
# Fold crops into the batch axis for the ViT.
images = images.view(B * T, N, D)
# `image_vit` returns one hidden state per transformer block.
image_features = self.image_vit(images)
features = []
for layer in self.vit_layers:
features.append(image_features[layer])
image_features = torch.cat(features, dim=-1)
# Drops the first token when prefix tokens exist (the ViT in this file sets
# num_prefix_tokens to 0, so this branch is inactive with it).
if self.num_prefix_tokens > 0:
image_features = image_features[:, 1:]
image_features = image_features.view(B, T, N, -1)
return image_features
# dtype/device are read from the ViT patch-embedding weight.
@property
def dtype(self) -> torch.dtype:
return self.image_vit.patch_embedding.weight.dtype
@property
def device(self) -> torch.device:
return self.image_vit.patch_embedding.weight.device
# Encode `images`, pool the patches selected by `pooled_patches_idx`
# (negative indices mark padding) and project the pooled features.
# NOTE(review): the return annotation declares a 2-tuple, but the code returns
# a single 2-D tensor containing only the rows of tokens that have at least
# one valid patch.
def forward(
self,
images: torch.Tensor,
pooled_patches_idx: torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
batch_size, num_image = images.shape[:2]
images = images.to(device=self.device)
# uint8 pixels are scaled to [0, 1] and then shifted to [-1, 1].
if images.dtype == torch.uint8:
images = images.to(dtype=torch.float32) / 255.0
images = images * 2.0 - 1.0
elif torch.is_floating_point(images):
# Native MolmoAct2 eval keeps resized SigLIP pixels as uint8 and normalizes
# on device. Canonicalize HF processor floats to that exact grid.
images = torch.round(((images.to(dtype=torch.float32) + 1.0) * 0.5) * 255.0)
images = torch.clamp(images, 0.0, 255.0) / 255.0
images = images * 2.0 - 1.0
images = images.to(dtype=self.dtype)
image_features = self.encode_image(images)
image_features = self.image_feature_dropout(image_features)
dim = image_features.shape[-1]
# `valid` marks real (non-negative) patch indices; `valid_token` marks pooled
# tokens that have at least one real patch.
valid = pooled_patches_idx >= 0
valid_token = torch.any(valid, -1)
# Use `pooled_patches_idx` to arrange the features for image pooling
batch_idx = torch.arange(
pooled_patches_idx.shape[0],
dtype=torch.long,
device=pooled_patches_idx.device,
)
batch_idx = torch.tile(
batch_idx.view(batch_size, 1, 1),
[1, pooled_patches_idx.shape[1], pooled_patches_idx.shape[2]],
)
# Now [batch, num_high_res_features, pool_dim, dim]
# Negative (padding) indices are clipped to 0 for the gather and zeroed out
# by the `valid` multiplication right after.
to_pool = image_features.reshape(batch_size, -1, dim)[
batch_idx, torch.clip(pooled_patches_idx, 0)
]
to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
to_pool = to_pool.reshape([-1, pooled_patches_idx.shape[-1], dim])
if self.adapter_config.pooling_attention_mask:
# Mask padded patches in the attention and average only over valid ones;
# a zero count is replaced by 1 to avoid dividing by zero.
attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
denom = torch.where(denom == 0, 1, denom)
query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(
to_pool.dtype
)
else:
# No mask: the query is the plain mean over all pooled slots.
attn_mask = None
query = to_pool.mean(-2, keepdim=True)
# One query per pooled token attends over that token's patches.
pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
pooled_features = pooled_features.reshape(
[batch_size, -1, pooled_features.shape[-1]]
)
# MLP layer to map the feature.
pooled_features = self.image_projector(pooled_features)
# Flatten batch and token axes and keep only tokens with a valid patch.
return pooled_features.view(-1, pooled_features.shape[-1])[
valid_token.flatten()
]
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine component of the rotary embedding.
        sin (`torch.Tensor`): Sine component of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so that they broadcast against `q` and `k`.
            With `cos`/`sin` shaped [batch_size, seq_len, head_dim], use 1 when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
# Rotary position embedding with a precomputed sin/cos table.
# The table is kept in the non-persistent buffers `_pos_sin_cache` /
# `_pos_cos_cache` and rebuilt whenever it is missing, on another device, or
# shorter than the requested sequence length.
class MolmoAct2RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(
self,
config: MolmoAct2TextConfig,
device: Union[str, torch.device] = None,
rope_type: Optional[str] = None,
):
super().__init__()
# RoPE type resolution order: explicit argument, then config.rope_scaling
# ("rope_type", falling back to the legacy "type" key), then "default".
if rope_type is not None:
self.rope_type = rope_type
elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
# BC: "rope_type" was originally "type"
self.rope_type = config.rope_scaling.get(
"rope_type", config.rope_scaling.get("type")
)
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
# "default" uses the local initializer below; every other type is looked up
# in transformers' ROPE_INIT_FUNCTIONS.
if self.rope_type == "default":
self.rope_init_fn = self._default_rope_init
else:
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=True)
self.original_inv_freq = self.inv_freq
# Empty placeholders; the real tables are built lazily.
self.register_buffer("_pos_sin_cache", torch.empty(0), persistent=False)
self.register_buffer("_pos_cos_cache", torch.empty(0), persistent=False)
# Standard RoPE frequencies 1 / theta^(2i / head_dim) with attention scaling 1.0.
# Extra keyword arguments are accepted and ignored.
@staticmethod
def _default_rope_init(
config: MolmoAct2TextConfig, device: Union[str, torch.device] = None, **_
) -> tuple[torch.Tensor, float]:
inv_freq = 1.0 / (
config.rope_theta
** (
torch.arange(0, config.head_dim, 2, dtype=torch.float32, device=device)
/ config.head_dim
)
)
return inv_freq, 1.0
# Table length to build: config.max_position_embeddings when it is set
# (truthy), otherwise the largest position id + 1, otherwise x.shape[-2].
def _target_cache_seq_len(
self, x: torch.Tensor, position_ids: Optional[torch.Tensor]
) -> int:
if self.config.max_position_embeddings:
return int(self.config.max_position_embeddings)
if position_ids is not None:
return int(position_ids.max().item()) + 1
return int(x.shape[-2])
# True when both tables exist on `device` and cover at least `seq_len` positions.
def _rope_cache_ready(self, device: torch.device, seq_len: int) -> bool:
return (
self._pos_sin_cache.numel() > 0
and self._pos_sin_cache.device == device
and self._pos_cos_cache.device == device
and self._pos_sin_cache.shape[-2] >= seq_len
and self._pos_cos_cache.shape[-2] >= seq_len
)
# Recompute `inv_freq` via `rope_init_fn` when it looks unusable, and reset
# both tables to empty tensors on `device`.
def _refresh_inv_freq_if_needed(self, device: torch.device) -> None:
device = torch.device(device)
expected = int(self.config.head_dim) // 2
# Structural reasons to refresh: missing buffer, no table built yet, meta
# device, wrong device, or unexpected length.
needs_refresh = (
self.inv_freq is None
or self._pos_sin_cache.numel() == 0
or self.inv_freq.device.type == "meta"
or self.inv_freq.device != device
or self.inv_freq.numel() != expected
)
if not needs_refresh:
# Value sanity checks. Despite its name, `inv_freq_cpu` is only detached;
# just element 0 is moved to CPU for the comparison with 1.0.
inv_freq_cpu = self.inv_freq.detach()
needs_refresh = (
not bool(torch.isfinite(inv_freq_cpu).all().item())
or bool((inv_freq_cpu <= 0).any().item())
or not bool(
torch.isclose(inv_freq_cpu[0].cpu(), torch.tensor(1.0)).item()
)
)
if needs_refresh:
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=True)
self.original_inv_freq = self.inv_freq
self._pos_sin_cache = torch.empty(0, device=device)
self._pos_cos_cache = torch.empty(0, device=device)
# Build sin/cos tables of shape (1, 1, seq_len, 2 * len(inv_freq)) in float32
# with autocast disabled, scaled by `attention_scaling`.
def _build_rope_cache(self, device: torch.device, seq_len: int) -> None:
# "mps" is mapped to "cpu" for the autocast device type.
device_type = device.type if device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
seq = torch.arange(seq_len, device=device, dtype=torch.float)
# Outer product of positions and inverse frequencies.
freqs = torch.einsum(
"i,j->ij", seq, self.inv_freq.to(device=device, dtype=torch.float)
)
# Frequencies are duplicated so the table width is 2 * len(inv_freq).
emb = torch.cat((freqs, freqs), dim=-1)
self._pos_sin_cache = emb.sin()[None, None, :, :] * self.attention_scaling
self._pos_cos_cache = emb.cos()[None, None, :, :] * self.attention_scaling
# Eagerly build the tables for `max_seq_len` (or config.max_position_embeddings).
# Does nothing unless rope_type is "default"; raises ValueError when no
# positive length can be determined.
@torch.no_grad()
def prepare_rope_cache(
self,
*,
device: Union[str, torch.device],
max_seq_len: Optional[int] = None,
) -> None:
if self.rope_type != "default":
return
device = torch.device(device)
seq_len = int(max_seq_len or self.config.max_position_embeddings or 0)
if seq_len <= 0:
raise ValueError(
"RoPE cache preparation requires a positive max sequence length."
)
if self._rope_cache_ready(device, seq_len):
return
self._refresh_inv_freq_if_needed(device)
self._build_rope_cache(device, seq_len)
# Slice the tables to `seq_len` and pick rows: the first x.shape[-2] rows when
# `position_ids` is None, otherwise the rows indexed by `position_ids`.
# Returns (cos, sin) cast to the dtype of `x`.
def _select_rope_cache(
self,
x: torch.Tensor,
position_ids: Optional[torch.Tensor],
seq_len: int,
) -> tuple[torch.Tensor, torch.Tensor]:
pos_sin = self._pos_sin_cache[:, :, :seq_len, :]
pos_cos = self._pos_cos_cache[:, :, :seq_len, :]
if position_ids is None:
sin = pos_sin[0, 0, : x.shape[-2], :]
cos = pos_cos[0, 0, : x.shape[-2], :]
else:
# Result shape is position_ids.shape + (table_width,).
sin = pos_sin[0, 0][position_ids].view(
position_ids.shape + (pos_sin.shape[-1],)
)
cos = pos_cos[0, 0][position_ids].view(
position_ids.shape + (pos_cos.shape[-1],)
)
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Return (cos, sin) for the given positions, (re)building the tables first
# when they are not ready for the device of `x` and the target length.
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(
self, x, position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
seq_len = self._target_cache_seq_len(x, position_ids)
if not self._rope_cache_ready(x.device, seq_len):
self._refresh_inv_freq_if_needed(x.device)
self._build_rope_cache(x.device, seq_len)
return self._select_rope_cache(x, position_ids, seq_len)
class MolmoAct2RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-feature scale.

    Args:
        size: Number of features on the last axis.
        eps: Constant added to the mean square before the reciprocal root.
        device: Device on which the weight is created.
    """

    def __init__(
        self,
        size: int,
        eps: float = 1e-6,
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size, device=device))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last axis in float32, then apply the scale."""
        input_dtype = x.dtype
        with torch.autocast(enabled=False, device_type=x.device.type):
            # Statistics are computed in float32 regardless of the input dtype.
            upcast = x.to(torch.float32)
            mean_square = upcast.pow(2).mean(-1, keepdim=True)
            normalized = (upcast * torch.rsqrt(mean_square + self.eps)).to(input_dtype)
        return self.weight * normalized

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis, equivalent to
    torch.repeat_interleave(x, dim=1, repeats=n_rep). Input is (batch, num_key_value_heads, seqlen, head_dim);
    output is (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Reference attention implemented with explicit matmul and softmax.

    Args:
        module: Owning attention module; ``num_key_value_groups`` and
            ``training`` are read from it.
        query, key, value: Attention inputs; key/value heads are repeated to
            match the query heads.
        attention_mask: Optional additive mask, cropped on its last axis to
            the key length before it is added to the scores.
        scaling: Multiplier applied to the raw attention scores.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has its head and
        sequence axes (1 and 2) swapped relative to ``query`` and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values)
    return context.transpose(1, 2).contiguous(), probs
class MolmoAct2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries, keys and values come from one fused projection (``att_proj``).
    Optional RMSNorm is applied to queries and keys, either over the flattened
    projection or, for ``qk_norm_type == "qwen3"``, per head.
    """

    def __init__(self, config: MolmoAct2TextConfig, layer_idx: int) -> None:
        """Build the fused QKV projection, optional q/k norms and output projection.

        Args:
            config: Text-model configuration.
            layer_idx: Index of this layer; passed to the KV cache on update.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.head_dim = config.head_dim
        self.scaling = self.head_dim**-0.5
        self.is_causal = True

        # Widths of the (query, key, value) slices of the fused projection.
        self.fused_dims = (
            config.num_attention_heads * config.head_dim,
            config.head_dim * config.num_key_value_heads,
            config.head_dim * config.num_key_value_heads,
        )
        self.att_proj = nn.Linear(
            config.hidden_size,
            sum(self.fused_dims),
            bias=config.qkv_bias,
        )

        # Layer norms.
        self.k_norm: Optional[MolmoAct2RMSNorm] = None
        self.q_norm: Optional[MolmoAct2RMSNorm] = None
        self.qk_norm_type: Optional[str] = None
        if config.use_qk_norm:
            # "qwen3" norms act on one head at a time; every other type acts on
            # the whole flattened projection.
            k_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3"
                else config.num_key_value_heads * config.head_dim
            )
            self.k_norm = MolmoAct2RMSNorm(k_norm_size, eps=config.layer_norm_eps)
            q_norm_size = (
                config.head_dim
                if config.qk_norm_type == "qwen3"
                else config.num_attention_heads * config.head_dim
            )
            self.q_norm = MolmoAct2RMSNorm(q_norm_size, eps=config.layer_norm_eps)
            self.qk_norm_type = config.qk_norm_type

        self.attention_dropout = config.attention_dropout
        self.attn_out = nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            bias=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations; the last axis is projected and
                the leading axes are restored on the output.
            position_embeddings: ``(cos, sin)`` rotary tables.
            attention_mask: Optional mask. A tensor (or ``None``) is handled by
                the direct SDPA path when the configured implementation is
                ``"sdpa"``; anything else goes through the attention interface.
            past_key_values: Optional cache updated with this layer's keys and
                values.
            cache_position: Forwarded to the cache update.
            **kwargs: Extra arguments for the attention interface (not used by
                the direct SDPA path).

        Returns:
            ``(attn_output, attn_weights)``; ``attn_weights`` is ``None`` on
            the direct SDPA path.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.att_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(self.fused_dims, dim=-1)
        value_states = value_states.view(hidden_shape)

        # Optionally apply layer norm to keys and queries.
        use_qk_norm = self.q_norm is not None and self.k_norm is not None
        if use_qk_norm and self.qk_norm_type != "qwen3":
            # Normalise the flattened projections before splitting into heads.
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        if use_qk_norm and self.qk_norm_type == "qwen3":
            # Per-head normalisation, applied after the split into heads.
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        # Swap axes 1 and 2 so the head axis precedes the sequence axis.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        dropout_p = 0.0 if not self.training else self.attention_dropout

        if self.config._attn_implementation == "sdpa" and (
            attention_mask is None or torch.is_tensor(attention_mask)
        ):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            # SDPA's `is_causal=True` builds a causal mask aligned to the
            # top-left corner. With a single query token and cached keys that
            # would leave only the first key visible, so the flag is requested
            # only when more than one query position is present; a single
            # query with no mask attends to every key.
            use_causal = (
                attention_mask is None
                and query_states.shape[-2] > 1
                and self.is_causal
            )
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=use_causal,
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_weights = None
        else:
            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                attention_interface = ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=dropout_p,
                scaling=self.scaling,
                **kwargs,
            )

        # Merge the heads back into one feature axis and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.attn_out(attn_output)
        return attn_output, attn_weights
class LanguageModelMLP(nn.Module):
    """Gated feed-forward block with a fused up/gate projection."""

    def __init__(
        self,
        input_dim: int,
        intermediate_size: int,
        hidden_act: str,
        device: Union[str, torch.device] = None,
    ):
        """Create the fused input projection, the output projection and the activation.

        ``ff_proj`` produces ``2 * intermediate_size`` features that ``forward``
        splits into a value half and a gate half.
        """
        super().__init__()
        fused_width = intermediate_size * 2
        self.ff_proj = nn.Linear(input_dim, fused_width, bias=False, device=device)
        self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``ff_out(act(gate) * value)``.

        ``value`` is the first half and ``gate`` the second half of the fused
        projection along the last axis.
        """
        projected = self.ff_proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.act(gate) * value
        return self.ff_out(gated)
class MolmoAct2DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder block: each sub-layer is normalised before it runs."""

    def __init__(
        self,
        config: MolmoAct2TextConfig,
        layer_idx: Optional[int] = None,
        device: Union[str, torch.device] = None,
    ):
        """Build the attention, MLP, the two RMSNorms and the residual dropout."""
        super().__init__()
        self.config = config
        self.self_attn = MolmoAct2Attention(config, layer_idx)
        self.attn_norm = MolmoAct2RMSNorm(
            config.hidden_size, eps=config.layer_norm_eps, device=device
        )
        self.dropout = nn.Dropout(config.residual_dropout)
        self.mlp = LanguageModelMLP(
            config.hidden_size,
            config.intermediate_size,
            config.hidden_act,
            device=device,
        )
        self.ff_norm = MolmoAct2RMSNorm(
            config.hidden_size, eps=config.layer_norm_eps, device=device
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[
        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run attention and MLP sub-layers, each wrapped in a residual connection.

        Returns:
            ``(hidden_states,)``, with the attention weights appended when
            ``output_attentions`` is truthy.
        """
        # Self Attention
        attn_input = self.attn_norm(hidden_states)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.dropout(attn_out)

        # Fully Connected
        mlp_out = self.mlp(self.ff_norm(hidden_states))
        hidden_states = hidden_states + self.dropout(mlp_out)

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class MolmoAct2PostNormDecoderLayer(MolmoAct2DecoderLayer):
    """Post-norm decoder block: each sub-layer's output is normalised before the residual add."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Run attention and MLP with the norms applied to the sub-layer outputs.

        Args:
            hidden_states: Input activations.
            position_embeddings: ``(cos, sin)`` rotary tables for attention.
            attention_mask, position_ids, past_key_values, output_attentions,
                use_cache, cache_position: Forwarded to the attention module.
            **kwargs: Extra attention arguments, forwarded to the attention
                module exactly as the pre-norm parent layer does.

        Returns:
            ``(hidden_states,)``, with the attention weights appended when
            ``output_attentions`` is truthy.
        """
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            # Previously dropped here, so extra attention arguments never
            # reached the attention module in post-norm models.
            **kwargs,
        )
        hidden_states = self.attn_norm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.ff_norm(hidden_states)
        hidden_states = residual + self.dropout(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        return outputs
class MolmoAct2Embedding(nn.Module):
    """Embedding lookup over two separately stored tables.

    ``embedding`` holds ``num_embeddings`` rows and ``new_embedding`` holds
    ``num_new_embeddings`` rows; ids index into their row-wise concatenation.
    """

    def __init__(
        self,
        num_embeddings: int,
        num_new_embeddings: int,
        features: int,
        device: Union[str, torch.device] = None,
    ):
        """Allocate both tables filled with zeros."""
        super().__init__()
        base_table = torch.zeros(num_embeddings, features, device=device)
        extra_table = torch.zeros(num_new_embeddings, features, device=device)
        self.embedding = nn.Parameter(base_table)
        self.new_embedding = nn.Parameter(extra_table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ids ``x`` using the concatenated ``[embedding; new_embedding]`` table."""
        full_table = torch.cat([self.embedding, self.new_embedding], dim=0)
        return F.embedding(x, full_table)
class MolmoAct2PreTrainedModel(PreTrainedModel):
    """Shared base class: framework capability flags and weight initialisation."""

    config: MolmoAct2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "MolmoAct2DecoderLayer",
        "MolmoAct2PostNormDecoderLayer",
        "MolmoAct2VisionBlock",
        "ViTMultiHeadDotProductAttention",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": MolmoAct2DecoderLayer,
        "attentions": MolmoAct2Attention,
    }

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its type.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; biases and the padding
        row are zeroed; RMSNorm/LayerNorm weights are set to one.
        """
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, MolmoAct2Embedding):
            # Both tables of the split embedding get the same initialisation.
            for table in (module.embedding, module.new_embedding):
                table.data.normal_(mean=0.0, std=std)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, MolmoAct2RMSNorm):
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
class MolmoAct2TextModel(MolmoAct2PreTrainedModel):
    """Decoder-only text transformer: token embeddings, decoder blocks and a final RMSNorm."""

    config: MolmoAct2TextConfig
    _no_split_modules = ["MolmoAct2DecoderLayer", "MolmoAct2PostNormDecoderLayer"]

    def __init__(self, config: MolmoAct2TextConfig):
        """Build the embeddings, decoder blocks, final norm and rotary module(s)."""
        super().__init__(config)

        # A split embedding is used when the config declares additional
        # vocabulary entries; otherwise a plain embedding table.
        if config.additional_vocab_size is None:
            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        else:
            self.wte = MolmoAct2Embedding(
                config.vocab_size,
                config.additional_vocab_size,
                config.hidden_size,
            )
        self.emb_drop = nn.Dropout(config.embedding_dropout)

        if config.norm_after:
            block_cls = MolmoAct2PostNormDecoderLayer
        else:
            block_cls = MolmoAct2DecoderLayer
        self.blocks = nn.ModuleList(
            [block_cls(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.ln_f = MolmoAct2RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

        # With per-layer rope scaling, two rotary modules are kept: one forced
        # to the "default" rope type and one using the configured type.
        if config.rope_scaling_layers is None:
            self.rotary_emb = MolmoAct2RotaryEmbedding(config)
        else:
            self.rotary_embs = nn.ModuleDict(
                {
                    "default": MolmoAct2RotaryEmbedding(config, rope_type="default"),
                    "scaling": MolmoAct2RotaryEmbedding(config),
                }
            )
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def prepare_rope_cache(
        self,
        *,
        device: Union[str, torch.device],
        max_seq_len: Optional[int] = None,
    ) -> None:
        """Ask every rotary module owned by this model to prepare its cache."""
        if self.config.rope_scaling_layers is None:
            rotary_modules = [self.rotary_emb]
        else:
            rotary_modules = self.rotary_embs.values()
        for rotary in rotary_modules:
            rotary.prepare_rope_cache(device=device, max_seq_len=max_seq_len)

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Replace the token embedding module."""
        self.wte = value

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Ids
        equal to ``-1`` are replaced by ``0`` before the embedding lookup.
        ``attention_mask`` is used as-is when it is a 4D tensor or a dict;
        otherwise a causal mask is built from it.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache

        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            # Multiplying by the (id != -1) indicator maps -1 ids to 0.
            keep = (input_ids != -1).to(input_ids.dtype)
            input_ids = input_ids * keep
            inputs_embeds = self.wte(input_ids)

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            if past_key_values is not None:
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                past_seen_tokens = 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        already_prepared = isinstance(attention_mask, dict) or (
            torch.is_tensor(attention_mask) and attention_mask.ndim == 4
        )
        if not already_prepared:
            causal_mask_mapping = create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        per_layer_rope = self.config.rope_scaling_layers is not None
        if per_layer_rope:
            position_embeddings_mapping = {
                "default": self.rotary_embs["default"](hidden_states, position_ids),
                "scaling": self.rotary_embs["scaling"](hidden_states, position_ids),
            }
        else:
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        active_blocks = self.blocks[: self.config.num_hidden_layers]
        for layer_idx, decoder_block in enumerate(active_blocks):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.config.rope_scaling_layers is None:
                layer_rope = position_embeddings
            elif layer_idx in self.config.rope_scaling_layers:
                layer_rope = position_embeddings_mapping["scaling"]
            else:
                layer_rope = position_embeddings_mapping["default"]

            layer_outputs = decoder_block(
                hidden_states,
                attention_mask=causal_mask_mapping,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=layer_rope,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.ln_f(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
# Adapted from transformers.models.gemma3.modeling_gemma3
def token_type_ids_mask_function(
    token_type_ids: Optional[torch.Tensor] = None,
) -> Optional[Callable]:
    """
    Build a mask function that is true where both the query and the key/value
    position are image tokens (``token_type_ids == 1``).

    Args:
        token_type_ids: 2D tensor indexed as ``[batch, position]``, or ``None``.

    Returns:
        ``None`` when ``token_type_ids`` is ``None`` (no additional mask);
        otherwise a callable ``(batch_idx, head_idx, q_idx, kv_idx) -> bool
        tensor``. Positions at or beyond ``token_type_ids.shape[1]`` are
        treated as non-image tokens on both the query and the key/value side.
    """
    # Do not return an additional mask in this case
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # If it's 1 for both query and key/value, we are in an image block
        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
        seq_len = token_type_ids.shape[1]
        q_in_range = q_idx < seq_len
        kv_in_range = kv_idx < seq_len

        # Clamp out-of-range indices to 0 for the lookup, then discard the
        # looked-up value again so such positions never count as image tokens.
        safe_q_idx = torch.where(q_in_range, q_idx, 0)
        safe_kv_idx = torch.where(kv_in_range, kv_idx, 0)
        token_type_ids_at_q_idx = torch.where(
            q_in_range, token_type_ids[batch_idx, safe_q_idx], 0
        )
        token_type_ids_at_kv_idx = torch.where(
            kv_in_range, token_type_ids[batch_idx, safe_kv_idx], 0
        )

        # This is bidirectional attention whenever we are dealing with image tokens
        return (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)

    return inner_mask
class MolmoAct2Model(MolmoAct2PreTrainedModel):
    """MolmoAct2 base model built around a ``MolmoAct2TextModel`` transformer."""

    # Replaces the "model" prefix declared on MolmoAct2PreTrainedModel with an
    # empty prefix.
    base_model_prefix = ""
    # Empty mapping: no checkpoint key conversions are declared on this class.
    _checkpoint_conversion_mapping = {}
    # Reference: fix gemma3 grad acc #37208
    accepts_loss_kwargs = False
    config: MolmoAct2Config
def __init__(self, config: MolmoAct2Config):
    """Build the text transformer and the optional vision / action components.

    Args:
        config: Model configuration. ``vit_config`` and ``adapter_config``
            must both be set for a vision backbone to be created;
            ``add_action_expert`` enables the action expert, and
            ``action_expert_depth_gate`` additionally enables the depth gate
            (one linear head, or one per action-expert layer when
            ``action_expert_depth_gate_per_layer`` is set).
    """
    super().__init__(config)
    self.transformer: MolmoAct2TextModel = MolmoAct2TextModel(config.text_config)

    self.vision_backbone: Optional[MolmoAct2VisionBackbone] = None
    if config.vit_config is not None and config.adapter_config is not None:
        self.vision_backbone = MolmoAct2VisionBackbone(
            config.vit_config, config.adapter_config
        )

    # Width of one flattened key/value vector of the language model.
    llm_kv_dim = (
        config.text_config.num_key_value_heads * config.text_config.head_dim
    )
    if config.add_action_expert:
        self.action_expert = ActionExpert(
            config.action_expert_config,
            llm_dim=config.hidden_size,
            llm_kv_dim=llm_kv_dim,
            llm_num_layers=config.num_hidden_layers,
        )
    else:
        self.action_expert = None

    if config.add_action_expert and config.action_expert_depth_gate:
        if config.action_expert_depth_gate_per_layer:
            self.action_expert_depth_gate = nn.ModuleList(
                nn.Linear(llm_kv_dim, 1)
                for _ in range(config.action_expert_config.num_layers)
            )
        else:
            self.action_expert_depth_gate = nn.Linear(llm_kv_dim, 1)
    else:
        self.action_expert_depth_gate = None

    self._depth_gate_token_ids = self._resolve_depth_gate_token_ids()
    self.action_cuda_graph_manager: Optional[ActionCudaGraphManager] = None

    # Initialize weights and apply final processing
    self.post_init()

    # The depth-gate initialisation (zero weight, configured bias) must run
    # after post_init(): the generic weight initialisation triggered there
    # treats the gate heads as ordinary nn.Linear modules and would overwrite
    # it. The call is a no-op when no depth gate exists.
    self.reset_action_expert_depth_gate_parameters()
def get_input_embeddings(self) -> torch.nn.Module:
    """Return the token embedding module of the wrapped text transformer."""
    text_model = self.transformer
    return text_model.wte
def set_input_embeddings(self, value: torch.nn.Module) -> None:
    """Replace the token embedding module of the wrapped text transformer."""
    text_model = self.transformer
    text_model.wte = value
def set_decoder(self, decoder):
    """Install ``decoder`` as this model's text transformer."""
    setattr(self, "transformer", decoder)
def get_decoder(self):
    """Return this model's text transformer."""
    decoder = self.transformer
    return decoder
@property
def device(self) -> torch.device:
    """Device of the text transformer's final-norm weight."""
    final_norm_weight = self.transformer.ln_f.weight
    return final_norm_weight.device
def reset_action_expert_depth_gate_parameters(self) -> None:
    """Reset every depth-gate head to zero weight and the configured bias.

    Does nothing when the model has no depth gate. The bias value is
    ``config.action_expert_depth_gate_init_bias`` converted to ``float``.
    """
    gate_module = self.action_expert_depth_gate
    if gate_module is None:
        return

    # A per-layer gate is stored as a ModuleList; a shared gate is one module.
    if isinstance(gate_module, nn.ModuleList):
        heads = gate_module
    else:
        heads = [gate_module]

    for head in heads:
        nn.init.zeros_(head.weight)
        bias_value = float(self.config.action_expert_depth_gate_init_bias)
        nn.init.constant_(head.bias, bias_value)
def _resolve_depth_gate_token_ids(self) -> Tuple[int, ...]:
if not self.config.action_expert_depth_gate:
return ()
token_ids = []
for token_id in (
self.config.depth_output_token_id,
self.config.depth_start_token_id,
self.config.depth_end_token_id,
):
if token_id is not None:
token_ids.append(int(token_id))
if (
self.config.depth_token_start_id is not None
and int(self.config.num_depth_tokens or 0) > 0
):
start = int(self.config.depth_token_start_id)
token_ids.extend(range(start, start + int(self.config.num_depth_tokens)))
return tuple(dict.fromkeys(token_ids))
def _require_action_expert(self) -> ActionExpert:
    """Return the action expert.

    Raises:
        RuntimeError: If this model was built without an action expert.
    """
    expert = self.action_expert
    if expert is not None:
        return expert
    raise RuntimeError(
        "This MolmoAct2 checkpoint does not include an action expert."
    )
def _cache_to_sequence(self, cache: torch.Tensor) -> torch.Tensor:
if cache.dim() != 4:
raise ValueError(
f"Expected KV cache tensor with 4 dims, got shape {tuple(cache.shape)}"
)
head_candidates = {
self.config.text_config.num_key_value_heads,
self.config.text_config.num_attention_heads,
}
if cache.shape[1] in head_candidates:
bsz, n_heads, seq_len, head_dim = cache.shape
return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim)
if cache.shape[2] in head_candidates:
bsz, seq_len, n_heads, head_dim = cache.shape
return cache.reshape(bsz, seq_len, n_heads * head_dim)
if cache.shape[1] <= cache.shape[2]:
bsz, n_heads, seq_len, head_dim = cache.shape
return cache.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * head_dim)
bsz, seq_len, n_heads, head_dim = cache.shape
return cache.reshape(bsz, seq_len, n_heads * head_dim)
def _extract_kv_states(
    self, past_key_values: Cache
) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]:
    """Convert a KV cache into per-layer flattened ``(key, value)`` sequences.

    Layers whose key or value is ``None`` are skipped. Entries longer than the
    cache's reported sequence length are truncated to that length on their
    second-to-last axis before being flattened with ``_cache_to_sequence``.

    Raises:
        RuntimeError: If ``past_key_values`` is ``None`` or the number of
            extracted layers differs from
            ``config.action_expert_config.num_layers``.
    """
    if past_key_values is None:
        raise RuntimeError(
            "Action generation requires past_key_values from the VLM forward pass."
        )

    valid_len = _cache_seq_len_int(past_key_values)
    kv_states = []
    for key, value in _iter_cache_key_values(past_key_values):
        if key is None or value is None:
            continue
        if key.shape[-2] > valid_len:
            # Drop the cache slots beyond the reported sequence length.
            key, value = key[..., :valid_len, :], value[..., :valid_len, :]
        flat_key = self._cache_to_sequence(key)
        flat_value = self._cache_to_sequence(value)
        kv_states.append((flat_key, flat_value))

    if len(kv_states) != self.config.action_expert_config.num_layers:
        raise RuntimeError(
            f"Expected {self.config.action_expert_config.num_layers} KV layers, got {len(kv_states)}."
        )
    return kv_states
@staticmethod
def _mask_discrete_output_span(
row_ids: torch.Tensor,
row_mask: torch.Tensor,
start_id: Optional[int],
end_id: Optional[int],
) -> None:
if start_id is None or end_id is None:
return
start_positions = (
(row_ids == start_id).nonzero(as_tuple=False).flatten().tolist()
)
if not start_positions:
return
end_positions = (row_ids == end_id).nonzero(as_tuple=False).flatten().tolist()
end_ptr = 0
for start_pos in start_positions:
while end_ptr < len(end_positions) and end_positions[end_ptr] < start_pos:
end_ptr += 1
if end_ptr >= len(end_positions):
row_mask[start_pos:] = False
break
end_pos = end_positions[end_ptr]
row_mask[start_pos : end_pos + 1] = False
end_ptr += 1
def _get_encoder_attention_mask(
self,
input_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
if attention_mask is not None:
mask = attention_mask.to(dtype=torch.bool).clone()
elif input_ids is not None:
mask = input_ids != -1
else:
return None
if self.config.action_format != "both" or input_ids is None:
return mask
eos_id = getattr(self.config, "eos_token_id", None)
if eos_id is not None:
mask &= input_ids != int(eos_id)
for batch_idx in range(input_ids.shape[0]):
self._mask_discrete_output_span(
input_ids[batch_idx],
mask[batch_idx],
self.config.action_start_token_id,
self.config.action_end_token_id,
)
return mask
def _get_depth_token_mask(
self,
input_ids: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
if (
not self.config.action_expert_depth_gate
or input_ids is None
or not self._depth_gate_token_ids
):
return None
depth_token_ids = torch.as_tensor(
self._depth_gate_token_ids,
device=input_ids.device,
dtype=input_ids.dtype,
)
depth_mask = (input_ids.unsqueeze(-1) == depth_token_ids).any(dim=-1)
if encoder_attention_mask is not None:
depth_mask = depth_mask & encoder_attention_mask.to(
device=input_ids.device, dtype=torch.bool
)
return depth_mask
@staticmethod
def _depth_gate_from_source(
gate_head: nn.Linear,
*,
source: torch.Tensor,
depth_mask: torch.Tensor,
encoder_attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
if source.ndim == 4:
source = source.reshape(source.shape[0], source.shape[1], -1)
if source.ndim != 3:
raise ValueError(
f"Depth gate expected a 3D sequence tensor, got {tuple(source.shape)}."
)
if encoder_attention_mask is not None:
valid_mask = encoder_attention_mask.to(
device=source.device, dtype=torch.bool
)
else:
valid_mask = torch.ones(
depth_mask.shape, device=source.device, dtype=torch.bool
)
depth_mask = depth_mask.to(device=source.device, dtype=torch.bool)
pool_mask = valid_mask & ~depth_mask
has_pool = pool_mask.any(dim=-1, keepdim=True)
pool_mask = torch.where(has_pool, pool_mask, valid_mask)
weights = pool_mask.to(dtype=source.dtype).unsqueeze(-1)
pooled = (source * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0)
gate_logits = gate_head(pooled.to(dtype=gate_head.weight.dtype))
return torch.sigmoid(gate_logits).to(dtype=source.dtype)
def _depth_gate_from_condition(
self,
*,
input_ids: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.Tensor],
layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]],
) -> Tuple[
Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], Optional[torch.Tensor]
]:
gate_head = self.action_expert_depth_gate
if gate_head is None:
return None, None
depth_mask = self._get_depth_token_mask(input_ids, encoder_attention_mask)
if depth_mask is None or layer_kv_states is None:
return None, depth_mask
sources = [value for _, value in layer_kv_states]
if isinstance(gate_head, nn.ModuleList):
if len(gate_head) != len(sources):
raise ValueError(
f"Depth gate layer count mismatch: gates={len(gate_head)}, sources={len(sources)}."
)
gates = [
self._depth_gate_from_source(
gate,
source=source,
depth_mask=depth_mask,
encoder_attention_mask=encoder_attention_mask,
)
for gate, source in zip(gate_head, sources)
]
return gates, depth_mask
gate = self._depth_gate_from_source(
gate_head,
source=sources[-1],
depth_mask=depth_mask,
encoder_attention_mask=encoder_attention_mask,
)
return gate, depth_mask
@staticmethod
def _depth_gate_for_layer(
gate: Union[torch.Tensor, Sequence[torch.Tensor]],
layer_idx: int,
*,
num_layers: int,
) -> torch.Tensor:
if isinstance(gate, torch.Tensor):
return gate
if len(gate) != num_layers:
raise ValueError(
f"Depth gate layer count mismatch: gates={len(gate)}, layers={num_layers}."
)
return gate[layer_idx]
def _apply_depth_gate_to_layer_kv_states(
self,
layer_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]],
depth_mask: Optional[torch.Tensor],
gate: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
) -> Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]]:
if layer_kv_states is None or depth_mask is None or gate is None:
return layer_kv_states
gated_kv = []
for layer_idx, (key, value) in enumerate(layer_kv_states):
layer_gate = self._depth_gate_for_layer(
gate, layer_idx, num_layers=len(layer_kv_states)
)
mask = depth_mask.to(device=key.device, dtype=torch.bool)
view_shape = [mask.shape[0], mask.shape[1]] + [1] * (key.ndim - 2)
scale = torch.ones(view_shape, device=key.device, dtype=key.dtype)
gate_view = layer_gate.to(device=key.device, dtype=key.dtype).view(
layer_gate.shape[0],
*([1] * (key.ndim - 1)),
)
scale = torch.where(mask.view(view_shape), gate_view, scale)
gated_kv.append((key * scale, value * scale))
return gated_kv
@staticmethod
def _action_dim_valid_mask(
target: torch.Tensor,
action_dim_is_pad: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
if action_dim_is_pad is None:
return None
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
if mask.ndim == 1:
mask = mask.unsqueeze(0)
if mask.shape[-1] != target.shape[-1]:
raise ValueError(
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
)
if mask.shape[0] == 1 and target.shape[0] != 1:
mask = mask.expand(target.shape[0], -1)
if mask.shape[0] != target.shape[0]:
raise ValueError(
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
)
while mask.ndim < target.ndim:
mask = mask.unsqueeze(1)
return mask
@classmethod
def _mask_action_dim_tensor(
cls,
tensor: torch.Tensor,
*,
action_dim_is_pad: Optional[torch.Tensor],
enabled: bool,
) -> torch.Tensor:
if not enabled:
return tensor
valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad)
if valid_mask is None:
return tensor
return tensor.masked_fill(~valid_mask, 0)
def _run_action_flow_loop(
    self, inputs: _ActionFlowInputs, steps: int
) -> torch.Tensor:
    """Integrate the action trajectory over ``steps`` equal-sized updates.

    In every step the action expert predicts a velocity from the current
    trajectory and that step's modulation, and the trajectory is advanced by
    ``velocity / steps``. When ``config.mask_action_dim_padding`` is set,
    padded action dimensions are zeroed on both the velocity and the updated
    trajectory.

    Raises:
        RuntimeError: If the model has no action expert.
    """
    expert = self._require_action_expert()
    step_size = 1.0 / steps
    state = inputs.trajectory
    pad_indicator = inputs.action_dim_is_pad
    masking_on = self.config.mask_action_dim_padding

    for step in range(steps):
        modulation = inputs.modulations[step]
        velocity = expert.forward_with_context(
            state,
            modulation.conditioning,
            context=inputs.context,
            modulation=modulation,
        )
        velocity = self._mask_action_dim_tensor(
            velocity,
            action_dim_is_pad=pad_indicator,
            enabled=masking_on,
        )
        state = self._mask_action_dim_tensor(
            state + step_size * velocity,
            action_dim_is_pad=pad_indicator,
            enabled=masking_on,
        )
    return state
@torch.no_grad()
def generate_actions_from_inputs(
    self,
    *,
    input_ids: torch.LongTensor,
    pixel_values: Optional[torch.Tensor] = None,
    image_token_pooling: Optional[torch.Tensor] = None,
    image_grids: Optional[torch.Tensor] = None,
    image_num_crops: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    video_token_pooling: Optional[torch.Tensor] = None,
    video_grids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    states: Optional[torch.Tensor] = None,
    action_dim_is_pad: Optional[torch.Tensor] = None,
    num_steps: Optional[int] = None,
    generator: Optional[torch.Generator] = None,
    encoder_kv_states: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample an action trajectory with the action expert.

    When ``encoder_kv_states`` is not supplied, the model is run on the given
    inputs with caching enabled and the per-layer KV states are taken from
    the resulting cache; in that case the encoder attention mask is always
    recomputed from ``input_ids`` / ``attention_mask``. The trajectory starts
    from float32 Gaussian noise of shape
    ``(batch, config.action_horizon, config.max_action_dim)`` and is
    integrated for ``num_steps`` steps (default
    ``config.flow_matching_num_steps``), through the CUDA-graph manager when
    one is installed and accepts the inputs.

    Raises:
        RuntimeError: If the model has no action expert.
        ValueError: If the resolved number of steps is not positive.
    """
    action_expert = self._require_action_expert()

    must_encode = encoder_kv_states is None
    if must_encode:
        vlm_outputs = self(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_token_pooling=image_token_pooling,
            image_grids=image_grids,
            image_num_crops=image_num_crops,
            pixel_values_videos=pixel_values_videos,
            video_token_pooling=video_token_pooling,
            video_grids=video_grids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            use_cache=True,
        )
        encoder_kv_states = self._extract_kv_states(vlm_outputs.past_key_values)
    if must_encode or encoder_attention_mask is None:
        encoder_attention_mask = self._get_encoder_attention_mask(
            input_ids, attention_mask
        )

    # Attenuate depth-token keys/values before the expert attends to them.
    depth_gate, depth_mask = self._depth_gate_from_condition(
        input_ids=input_ids,
        encoder_attention_mask=encoder_attention_mask,
        layer_kv_states=encoder_kv_states,
    )
    encoder_kv_states = self._apply_depth_gate_to_layer_kv_states(
        encoder_kv_states,
        depth_mask,
        depth_gate,
    )

    steps = int(num_steps or self.config.flow_matching_num_steps)
    if steps <= 0:
        raise ValueError(f"num_steps must be >= 1, got {steps}.")

    # Batch size and device follow the first layer's key states.
    reference = encoder_kv_states[0][0]
    batch_size = reference.shape[0]
    device = reference.device

    noise_shape = (
        batch_size,
        self.config.action_horizon,
        self.config.max_action_dim,
    )
    trajectory = torch.randn(
        noise_shape,
        device=device,
        dtype=torch.float32,
        generator=generator,
    )
    trajectory = self._mask_action_dim_tensor(
        trajectory,
        action_dim_is_pad=action_dim_is_pad,
        enabled=self.config.mask_action_dim_padding,
    )

    action_context = action_expert.prepare_context(
        encoder_kv_states=encoder_kv_states,
        encoder_attention_mask=encoder_attention_mask,
        state_embeddings=states,
        batch_size=batch_size,
        seq_len=trajectory.shape[1],
        device=device,
        dtype=trajectory.dtype,
    )

    # One timestep tensor per integration step, at idx / steps.
    flow_timesteps = [
        torch.full(
            (batch_size,), idx / steps, device=device, dtype=trajectory.dtype
        )
        for idx in range(steps)
    ]
    modulation_cache = action_expert.get_or_prepare_modulation_cache(
        flow_timesteps,
        cache_key=(steps, batch_size, device, trajectory.dtype),
    )

    flow_inputs = _ActionFlowInputs(
        trajectory=trajectory,
        context=action_context,
        modulations=modulation_cache,
        action_dim_is_pad=action_dim_is_pad,
    )

    graph_manager = self.action_cuda_graph_manager
    use_graph = graph_manager is not None and graph_manager.can_use_action_flow(
        flow_inputs
    )
    if use_graph:
        return graph_manager.run_action_flow(
            flow_inputs, steps, self._run_action_flow_loop
        )
    return self._run_action_flow_loop(flow_inputs, steps)
def build_batched_images(
    self,
    input_ids: torch.LongTensor,
    pixel_values: torch.Tensor,
    image_token_pooling: torch.Tensor,
    image_grids: torch.Tensor,
    image_num_crops: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Regroup flat per-image inputs into padded per-example tensors.

    Args:
        input_ids: Token ids of shape [N, seq_len]. Occurrences of
            ``config.image_end_token_id`` are counted to work out how many
            images belong to each example (one or two end tokens per image).
        pixel_values: Crops of shape [n_crops, n_patches, pixels_per_patch],
            concatenated over every image in the batch.
        image_token_pooling: Pooling indices of shape [total_pooled, dim],
            concatenated over every image. Negative entries are left untouched.
        image_grids: One row per image. The pooled-row count of an image is
            ``prod(row[:2]) + prod(row[2:])``.
        image_num_crops: Number of crops per image, shape [num_images].

    Returns:
        ``(images, token_pooling)`` with shapes
        [N, max_crops, n_patches, pixels_per_patch] and [N, max_pooled, dim].
        Both are padded with -1. Non-negative pooling indices are shifted so
        they address patches within the example instead of within the image.

    Raises:
        ValueError: If the number of end tokens is neither one nor two per image.
        AssertionError: If the flat inputs are inconsistent with each other.
    """
    # 1) Count the number of images in each example
    raw_counts = (input_ids == self.config.image_end_token_id).sum(1)  # [N]
    total_images = int(image_grids.size(0))
    total_end_tokens = int(raw_counts.sum().item())
    if total_images <= 0:
        counts = raw_counts.new_zeros(raw_counts.shape)
    elif total_end_tokens == total_images:
        counts = raw_counts
    elif total_end_tokens == 2 * total_images:
        counts = raw_counts // 2
    else:
        raise ValueError(
            "Could not infer image counts from image end tokens: "
            f"end_tokens={total_end_tokens}, image_grids={total_images}."
        )
    N = counts.size(0)
    device = input_ids.device
    # Total number of images in the batch
    num_images = total_images
    # Sanity check
    assert image_grids.size(0) == num_images, (
        f"Expected {num_images} image grids, but got {image_grids.size(0)}"
    )
    assert image_num_crops.size(0) == num_images, (
        f"Expected {num_images} image num crops, but got {image_num_crops.size(0)}"
    )
    # 1-1) Compute per-image pooled patch count from image grids
    with torch.no_grad():
        first_prod = image_grids[:, :2].prod(dim=1)  # [num_images]
        second_prod = image_grids[:, 2:].prod(dim=1)  # [num_images]
        num_pooled_patches_per_image = (first_prod + second_prod).to(
            image_num_crops.dtype
        )  # [num_images]
    # pixel_values: [n_crops, n_patches, pixels_per_patch]
    n_crops, n_patches, pixels_per_patch = pixel_values.shape
    # 2) Map each image index -> example index
    # Example: if counts = [2, 1, 3], then this becomes [0,0,1,2,2,2]
    example_ids_for_image = torch.arange(N, device=device).repeat_interleave(
        counts
    )  # [num_images]
    assert example_ids_for_image.numel() == num_images
    # 2-1) Compute crops_per_example by summing per-image crop counts
    crops_per_example = torch.zeros(
        N, dtype=image_num_crops.dtype, device=image_num_crops.device
    )
    crops_per_example.index_add_(0, example_ids_for_image, image_num_crops)  # [N]
    # 2-2) Per-image number of patches = (crops per image) * n_patches
    patches_per_image = image_num_crops * n_patches  # [num_images]
    # 2-3) Compute per-example per-image patch offsets
    counts_list = counts.tolist()
    index_offset_per_example_list = []
    offset_img = 0
    for c in counts_list:
        per_img_patches = patches_per_image[offset_img : offset_img + c]  # [c]
        # Exclusive cumulative sum: [0, img0, img0+img1, ...]. Unlike
        # ``[0] + cumsum[:-1]`` this is empty (not ``[0]``) for an example
        # without images, so mixed and image-free batches stay consistent.
        index_offset = (per_img_patches.cumsum(0) - per_img_patches).tolist()
        index_offset_per_example_list.append(index_offset)
        offset_img += c
    # 2-4) Compute num_pooled_patches_per_example
    num_pooled_patches_per_example = torch.zeros(
        N,
        dtype=num_pooled_patches_per_image.dtype,
        device=num_pooled_patches_per_image.device,
    )
    num_pooled_patches_per_example.index_add_(
        0, example_ids_for_image, num_pooled_patches_per_image
    )
    # Sanity checks
    total_crops = int(crops_per_example.sum().item())
    assert total_crops == n_crops, (
        f"Expected {total_crops} crops, but got {n_crops}"
    )
    total_num_pooled_patches = int(num_pooled_patches_per_example.sum().item())
    assert total_num_pooled_patches == image_token_pooling.size(0), (
        f"Expected {total_num_pooled_patches} pooled patches, but got {image_token_pooling.size(0)}"
    )
    # 3) Build images tensor filled with -1
    M = int(crops_per_example.max().item())
    images = torch.full(
        (N, M, n_patches, pixels_per_patch),
        fill_value=-1,
        dtype=pixel_values.dtype,
        device=pixel_values.device,
    )
    # 4) Fill images with per-example slices from pixel_values
    offset_crop = 0
    for i, num in enumerate(crops_per_example.tolist()):
        num = int(num)
        images[i, :num] = pixel_values[offset_crop : offset_crop + num]
        offset_crop += num
    # Sanity check
    assert offset_crop == n_crops
    # 5) Build new_token_pooling tensor filled with -1
    P = int(num_pooled_patches_per_example.max().item())
    _, dim = image_token_pooling.shape
    new_token_pooling = torch.full(
        (N, P, dim),
        fill_value=-1,
        dtype=image_token_pooling.dtype,
        device=image_token_pooling.device,
    )
    # 6) Fill token_pooling with per-example slices, adding per-image patch offsets
    pooled_per_example_list = num_pooled_patches_per_example.tolist()
    pooled_per_image_list = num_pooled_patches_per_image.tolist()
    patch_offset = 0
    img_offset = 0
    for i, c in enumerate(counts_list):
        num_patches = int(pooled_per_example_list[i])
        # Subsequence of pooled tokens belonging to this example; cloned so the
        # caller's tensor is not modified by the offsetting below.
        cur = image_token_pooling[
            patch_offset : patch_offset + num_patches
        ].clone()  # [num_patches, dim]
        index_offset_per_example = index_offset_per_example_list[i]  # length = c
        per_img_pooled = pooled_per_image_list[img_offset : img_offset + c]  # length = c
        assert len(index_offset_per_example) == len(per_img_pooled)
        # Apply per-image offsets to the (ragged) subsequence
        offset = 0
        for index_offset, n in zip(index_offset_per_example, per_img_pooled):
            index_offset = int(index_offset)
            n = int(n)
            cur_slice = cur[offset : offset + n]
            # Shift real indices only; negative entries are padding markers.
            cur[offset : offset + n] = torch.where(
                cur_slice >= 0,
                cur_slice + index_offset,
                cur_slice,
            )
            offset += n
        new_token_pooling[i, :num_patches] = cur
        patch_offset += num_patches
        img_offset += c
    # Final sanity checks
    assert patch_offset == total_num_pooled_patches
    assert img_offset == num_images
    return images, new_token_pooling
def build_batched_videos(
    self,
    input_ids: torch.LongTensor,
    pixel_values_videos: torch.Tensor,
    video_token_pooling: torch.Tensor,
    video_grids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Regroup flat per-video inputs into padded per-example tensors.

    An example is considered to contain one video when its ``input_ids`` row
    contains the configured end token (``frame_end_token_id`` when
    ``config.use_frame_special_tokens`` is set, otherwise ``image_end_token_id``).

    Args:
        input_ids: Token ids of shape [N, seq_len].
        pixel_values_videos: Frames of shape [n_frames, n_patches, pixels_per_patch],
            concatenated over every video in the batch.
        video_token_pooling: Pooling indices of shape [total_pooled, dim].
        video_grids: One row per video; column 0 is the frame count and the
            product of the whole row is the pooled-row count.

    Returns:
        ``(videos, token_pooling)`` with shapes
        [N, max_frames, n_patches, pixels_per_patch] and [N, max_pooled, dim],
        both padded with -1.

    Raises:
        AssertionError: If the flat inputs are inconsistent with each other.
    """
    cfg = self.config
    end_token_id = (
        cfg.frame_end_token_id
        if cfg.use_frame_special_tokens
        else cfg.image_end_token_id
    )
    # 0/1 flag per example: at most one video per row.
    counts = (input_ids == end_token_id).any(dim=1).long()  # [N]
    batch = counts.size(0)
    device = input_ids.device

    num_videos = int(counts.sum().item())
    assert video_grids.size(0) == num_videos, (
        f"Expected {num_videos} videos, but got {video_grids.size(0)}"
    )

    frames_per_video = video_grids[:, 0]  # [num_videos]
    pooled_per_video = video_grids.prod(dim=1)  # [num_videos]
    n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape

    # Example index owning each video, e.g. counts=[1,0,1] -> [0,2].
    owner = torch.arange(batch, device=device).repeat_interleave(counts)
    assert owner.numel() == num_videos

    frames_per_example = torch.zeros(
        batch, dtype=frames_per_video.dtype, device=device
    )
    frames_per_example.index_add_(0, owner, frames_per_video)  # [N]

    pooled_per_example = torch.zeros(
        batch, dtype=pooled_per_video.dtype, device=pooled_per_video.device
    )
    pooled_per_example.index_add_(0, owner, pooled_per_video)  # [N]

    total_frames = int(frames_per_example.sum().item())
    assert total_frames == n_frames, (
        f"Expected {total_frames} frames, but got {n_frames}"
    )
    total_num_pooled_patches = int(pooled_per_example.sum().item())
    assert total_num_pooled_patches == video_token_pooling.size(0), (
        f"Expected {total_num_pooled_patches} pooled patches, but got {video_token_pooling.size(0)}"
    )

    # Padded frame tensor; unused frame slots stay at -1.
    max_frames = int(frames_per_example.max().item())
    videos = torch.full(
        (batch, max_frames, n_patches, pixels_per_patch),
        fill_value=-1,
        dtype=pixel_values_videos.dtype,
        device=device,
    )
    frame_cursor = 0
    for row, frame_count in enumerate(frames_per_example.tolist()):
        frame_count = int(frame_count)
        videos[row, :frame_count] = pixel_values_videos[
            frame_cursor : frame_cursor + frame_count
        ]
        frame_cursor += frame_count
    assert frame_cursor == n_frames

    # Padded pooling tensor; unused rows stay at -1.
    max_pooled = int(pooled_per_example.max().item())
    _, dim = video_token_pooling.shape
    new_token_pooling = torch.full(
        (batch, max_pooled, dim),
        fill_value=-1,
        dtype=video_token_pooling.dtype,
        device=video_token_pooling.device,
    )
    pooled_cursor = 0
    for row, pooled_count in enumerate(pooled_per_example.tolist()):
        pooled_count = int(pooled_count)
        new_token_pooling[row, :pooled_count] = video_token_pooling[
            pooled_cursor : pooled_cursor + pooled_count
        ]
        pooled_cursor += pooled_count
    assert pooled_cursor == total_num_pooled_patches

    return videos, new_token_pooling
def merge_visual_inputs(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.Tensor] = None,
    image_token_pooling: Optional[torch.Tensor] = None,
    image_grids: Optional[torch.Tensor] = None,
    image_num_crops: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    video_token_pooling: Optional[torch.Tensor] = None,
    video_grids: Optional[torch.Tensor] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Dispatch to the image or video batching helper.

    Returns:
        ``(images, token_pooling)`` from ``build_batched_images`` or
        ``build_batched_videos``, or ``(None, None)`` when no visual input
        was supplied.

    Raises:
        ValueError: If both ``pixel_values`` and ``pixel_values_videos`` are given.
        AssertionError: If visual inputs are given without ``input_ids``.
    """
    has_images = pixel_values is not None
    has_videos = pixel_values_videos is not None

    if has_images and has_videos:
        raise ValueError(
            "pixel_values and pixel_values_videos are provided at the same time"
        )

    if not has_images and not has_videos:
        return None, None

    # Both helpers count end tokens in input_ids to assign visuals to examples.
    assert input_ids is not None
    if has_images:
        return self.build_batched_images(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_token_pooling=image_token_pooling,
            image_grids=image_grids,
            image_num_crops=image_num_crops,
        )
    return self.build_batched_videos(
        input_ids=input_ids,
        pixel_values_videos=pixel_values_videos,
        video_token_pooling=video_token_pooling,
        video_grids=video_grids,
    )
def build_input_embeddings(
    self,
    input_ids: torch.LongTensor,
    images: Optional[torch.FloatTensor] = None,  # image inputs
    token_pooling: Optional[torch.LongTensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Embed ``input_ids`` and add vision features at image-patch positions.

    Args:
        input_ids: Token ids; entries equal to -1 are embedded as token 0.
        images: Optional visual input passed to ``self.vision_backbone``.
        token_pooling: Pooling indices passed along with ``images``.

    Returns:
        ``(x, image_features)`` where ``x`` is the embedding output after
        ``transformer.emb_drop`` and ``image_features`` is the backbone output
        (``None`` when ``images`` is ``None``).
    """
    # Zero out -1 placeholders so the embedding lookup stays in range.
    keep = (input_ids != -1).to(input_ids.dtype)
    input_ids = input_ids * keep
    # shape: (batch_size, seq_len, d_model)
    x = self.transformer.wte(input_ids)

    image_features: Optional[torch.FloatTensor] = None
    if images is not None:
        image_features = self.vision_backbone(images, token_pooling).to(x.device)
        is_image_patch = input_ids.view(-1) == self.config.image_patch_id
        assert is_image_patch.sum() == len(image_features)
        # In-place add through a flattened view of x: one feature row per patch token.
        flat_x = x.view(-1, x.shape[-1])
        flat_x[is_image_patch] += image_features

    # shape: (batch_size, seq_len, d_model)
    x = self.transformer.emb_drop(x)  # type: ignore
    return x, image_features
def _build_native_attention_bias(
    self,
    *,
    inputs_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    token_type_ids: Optional[torch.Tensor],
    past_key_values: Optional[Cache],
) -> torch.Tensor:
    """Build an additive attention bias for the current forward call.

    Args:
        inputs_embeds: Embeddings whose first two dims are (batch, seq_len);
            also fixes the device and dtype of the result.
        attention_mask: ``None`` (all positions up to the current length are
            valid), a 2-D padding mask, or a 4-D mask that is returned as-is
            after being moved to the embeddings' device.
        token_type_ids: Optional mask; when given on the first call
            (``past_length == 0``), positions that are both flagged may attend
            to each other in either direction.
        past_key_values: Cache used to read the number of tokens already seen
            and, if it reports one, a fixed maximum length.

    Returns:
        A tensor that is 0 where attention is allowed and ``finfo(dtype).min``
        elsewhere, broadcastable to (batch, 1, seq_len, key_len).

    Raises:
        ValueError: If ``attention_mask`` is neither 2-D nor 4-D.
    """
    if attention_mask is not None and attention_mask.ndim == 4:
        return attention_mask.to(device=inputs_embeds.device)
    batch_size, seq_len = inputs_embeds.shape[:2]
    past_length = _cache_seq_len_int(past_key_values)
    current_length = past_length + int(seq_len)
    max_cache_len = _cache_max_len_int(past_key_values)
    # With a fixed-size cache the key axis spans the whole cache; otherwise it
    # spans the tokens seen so far plus the current ones.
    attention_mask_len = max_cache_len if max_cache_len > 0 else current_length
    device = inputs_embeds.device
    if attention_mask is None:
        positions = torch.arange(attention_mask_len, device=device)
        valid_mask = positions.unsqueeze(0) < current_length
        valid_mask = valid_mask.expand(batch_size, -1)
    elif attention_mask.ndim == 2:
        valid_mask = torch.zeros(
            (batch_size, attention_mask_len), device=device, dtype=torch.bool
        )
        source_mask = attention_mask.to(device=device, dtype=torch.bool)
        copy_len = min(int(source_mask.shape[-1]), attention_mask_len)
        if copy_len > 0:
            valid_mask[:, :copy_len] = source_mask[:, :copy_len]
        # Cache slots past the current length are not written yet.
        if attention_mask_len > current_length:
            valid_mask[:, current_length:] = False
    else:
        raise ValueError(
            f"Unsupported attention_mask shape for MolmoAct2: {tuple(attention_mask.shape)}"
        )
    valid_mask = valid_mask[:, None, None, :]
    # Rows of the lower-triangular mask that correspond to the current queries.
    causal_mask = torch.tril(
        torch.ones(
            attention_mask_len, attention_mask_len, device=device, dtype=torch.bool
        )
    )[None, None, past_length:current_length, :attention_mask_len]
    if token_type_ids is not None and past_length == 0:
        image_mask = token_type_ids.to(device=device, dtype=torch.bool)
        can_attend_back = image_mask[:, :, None] & image_mask[:, None, :]
        image_len = min(int(token_type_ids.shape[1]), attention_mask_len)
        # The causal mask has a batch dim of 1 while the bidirectional term is
        # per-example. Give the mask its own per-example storage before the
        # in-place update; assigning a batch-B value into the batch-1 view
        # fails for B > 1.
        causal_mask = causal_mask.expand(batch_size, -1, -1, -1).clone()
        causal_mask[:, :, :, :image_len] = (
            causal_mask[:, :, :, :image_len]
            | can_attend_back[:, None, :, :image_len]
        )
    allowed = valid_mask & causal_mask
    return torch.where(
        allowed,
        torch.zeros((), device=device, dtype=inputs_embeds.dtype),
        torch.full(
            (),
            torch.finfo(inputs_embeds.dtype).min,
            device=device,
            dtype=inputs_embeds.dtype,
        ),
    )
@can_return_tuple
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_token_pooling: Optional[torch.Tensor] = None,
    image_grids: Optional[torch.Tensor] = None,
    image_num_crops: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    video_token_pooling: Optional[torch.Tensor] = None,
    video_grids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[Cache] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, MolmoAct2ModelOutputWithPast]:
    """Run the multimodal backbone.

    Exactly one of ``input_ids`` and ``inputs_embeds`` must be supplied.
    Visual inputs are batched with ``merge_visual_inputs`` and added to the
    token embeddings; they cannot be combined with ``inputs_embeds``.
    ``attention_mask`` may be a dict, which is forwarded to the transformer
    unchanged; anything else goes through ``_build_native_attention_bias``.

    Returns:
        ``MolmoAct2ModelOutputWithPast`` carrying the transformer outputs and,
        when visual inputs were given, the image features.

    Raises:
        ValueError: On an invalid combination of ``input_ids``,
            ``inputs_embeds`` and visual inputs.
    """
    # Fall back to config defaults for unset flags.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache

    # Both missing or both present is an error.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError(
            "You must specify exactly one of input_ids or inputs_embeds"
        )

    images, token_pooling = self.merge_visual_inputs(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_token_pooling=image_token_pooling,
        image_grids=image_grids,
        image_num_crops=image_num_crops,
        pixel_values_videos=pixel_values_videos,
        video_token_pooling=video_token_pooling,
        video_grids=video_grids,
    )
    has_images = images is not None
    if has_images and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both images and inputs_embeds at the same time."
        )

    image_features = None
    if inputs_embeds is None:
        inputs_embeds, image_features = self.build_input_embeddings(
            input_ids,
            images,
            token_pooling,
        )

    if cache_position is None:
        first_position = _cache_seq_len_int(past_key_values)
        cache_position = torch.arange(
            first_position,
            first_position + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

    # A dict is taken to be a ready-made mask mapping and passed through.
    if isinstance(attention_mask, dict):
        causal_mask_mapping = attention_mask
    else:
        causal_mask_mapping = self._build_native_attention_bias(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            past_key_values=past_key_values,
        )

    outputs = self.transformer(
        attention_mask=causal_mask_mapping,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )
    return MolmoAct2ModelOutputWithPast(
        last_hidden_state=outputs.last_hidden_state,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if has_images else None,
    )
class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {}
_tied_weights_keys = [] # Weights are not tied
# Reference: fix gemma3 grad acc #37208
accepts_loss_kwargs = False
config: MolmoAct2Config
def __init__(self, config: MolmoAct2Config):
    """Build the base model, the LM head and both CUDA-graph managers."""
    super().__init__(config)
    self.model = MolmoAct2Model(config)
    # Bias-free projection from hidden states to vocabulary logits.
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    self.vocab_size = config.vocab_size
    # The action manager is attached to the inner model; the depth-decode
    # manager is attached to this wrapper.
    self.model.action_cuda_graph_manager = ActionCudaGraphManager(self.model)
    self.depth_decode_cuda_graph_manager = DepthDecodeCudaGraphManager(self)
    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self) -> torch.nn.Module:
    """Return the token-embedding module (``wte``) of the inner transformer."""
    transformer = self.model.transformer
    return transformer.wte
def set_input_embeddings(self, value: torch.nn.Module) -> None:
    """Replace the token-embedding module (``wte``) of the inner transformer."""
    transformer = self.model.transformer
    transformer.wte = value
def set_decoder(self, decoder):
    """Forward ``decoder`` to the inner model's ``set_decoder``."""
    inner = self.model
    inner.set_decoder(decoder)
def get_decoder(self):
    """Return whatever the inner model's ``get_decoder`` returns."""
    inner = self.model
    return inner.get_decoder()
# Make modules available through the conditional class for BC
@property
def language_model(self) -> torch.nn.Module:
    """The inner model's ``transformer`` module."""
    inner = self.model
    return inner.transformer
@property
def vision_backbone(self) -> torch.nn.Module:
    """The inner model's ``vision_backbone`` module."""
    inner = self.model
    return inner.vision_backbone
def _get_robot_stats(self) -> _RobotStats:
    """Load the normalization stats once and cache them on the instance.

    The file name comes from ``config.norm_stats_filename`` (default
    ``norm_stats.json``) and is looked up inside ``config._name_or_path``,
    falling back to ``self.name_or_path``.

    Raises:
        ValueError: If no base directory can be determined.
        FileNotFoundError: If the stats file does not exist.
    """
    cached = getattr(self, "_molmoact2_robot_stats", None)
    if cached is not None:
        return cached

    filename = getattr(self.config, "norm_stats_filename", "norm_stats.json")
    base_dir = getattr(self.config, "_name_or_path", None)
    if not base_dir:
        base_dir = getattr(self, "name_or_path", None)
    if not base_dir:
        raise ValueError(
            "MolmoAct2 normalization stats are not loaded and config._name_or_path is empty; "
            "load the model from a converted HF directory containing norm_stats.json."
        )

    stats_path = os.path.join(str(base_dir), filename)
    if not os.path.isfile(stats_path):
        raise FileNotFoundError(
            f"MolmoAct2 normalization stats file is missing: {stats_path}. "
            "Converted checkpoints must include norm_stats.json."
        )

    with open(stats_path, "r", encoding="utf-8") as handle:
        loaded = _RobotStats(json.load(handle))
    # Remember the parsed stats so later calls skip the file read.
    self._molmoact2_robot_stats = loaded
    return loaded
@staticmethod
def _move_inputs_to_device(
inputs: Mapping[str, Any], device: torch.device
) -> Dict[str, Any]:
out = {}
for key, value in inputs.items():
out[key] = value.to(device) if torch.is_tensor(value) else value
return out
@staticmethod
def _drop_trivial_attention_mask(inputs: Mapping[str, Any]) -> Dict[str, Any]:
out = dict(inputs)
attention_mask = out.get("attention_mask")
if torch.is_tensor(attention_mask) and bool(
attention_mask.to(dtype=torch.bool).all().item()
):
out.pop("attention_mask", None)
return out
@staticmethod
def _count_images(images: Any) -> int:
if images is None:
return 0
if isinstance(images, (list, tuple)):
return len(images)
arr = np.asarray(images) if not torch.is_tensor(images) else images
if getattr(arr, "ndim", 0) == 4:
return int(arr.shape[0])
return 1
@staticmethod
def _build_action_dim_is_pad(
*,
action_dim: int,
max_action_dim: int,
batch_size: int,
device: torch.device,
) -> Optional[torch.Tensor]:
if int(action_dim) > int(max_action_dim):
raise ValueError(
f"Requested action_dim {int(action_dim)} exceeds checkpoint max_action_dim {int(max_action_dim)}."
)
if int(action_dim) == int(max_action_dim):
return None
mask = torch.ones(
(int(batch_size), int(max_action_dim)), device=device, dtype=torch.bool
)
mask[:, : int(action_dim)] = False
return mask
@staticmethod
def _slice_action_dim(actions: torch.Tensor, action_dim: int) -> torch.Tensor:
if actions.shape[-1] < int(action_dim):
raise ValueError(
f"Requested action_dim {int(action_dim)} but chunk only has width {actions.shape[-1]}."
)
return actions[..., : int(action_dim)]
@staticmethod
def _slice_action_chunk(
actions: torch.Tensor, n_obs_steps: int, n_action_steps: Optional[int]
) -> torch.Tensor:
if n_action_steps is None:
return actions
start = int(n_obs_steps) - 1
end = start + int(n_action_steps)
if end > actions.shape[1]:
raise ValueError(
f"Requested actions up to {end} but model produced horizon {actions.shape[1]}."
)
return actions[:, start:end]
def _depth_token_id_to_bin(self) -> Dict[int, int]:
if (
self.config.depth_token_start_id is None
or int(self.config.num_depth_tokens or 0) <= 0
):
return {}
start = int(self.config.depth_token_start_id)
return {start + idx: idx for idx in range(int(self.config.num_depth_tokens))}
def _action_token_id_to_bin(self) -> Dict[int, int]:
if (
self.config.action_token_start_id is None
or int(self.config.num_action_tokens or 0) <= 0
):
return {}
start = int(self.config.action_token_start_id)
return {start + idx: idx for idx in range(int(self.config.num_action_tokens))}
def _require_eos_token_id(self) -> int:
eos_token_id = getattr(self.config, "eos_token_id", None)
if (
eos_token_id is None
and getattr(self, "generation_config", None) is not None
):
eos_token_id = getattr(self.generation_config, "eos_token_id", None)
if isinstance(eos_token_id, (list, tuple)):
eos_token_id = eos_token_id[0] if eos_token_id else None
if eos_token_id is None:
raise RuntimeError(
"Discrete action generation requires `eos_token_id` in the converted HF config."
)
return int(eos_token_id)
def _decode_depth_bins_from_token_ids(
self, token_ids: torch.Tensor
) -> torch.Tensor:
if (
self.config.depth_start_token_id is None
or self.config.depth_end_token_id is None
):
raise RuntimeError(
"Depth generation requires <depth_start>/<depth_end> token IDs."
)
token_id_to_bin = self._depth_token_id_to_bin()
if not token_id_to_bin:
raise RuntimeError(
"Depth generation requires indexed depth tokens in the converted config."
)
depth_token_bins = _extract_discrete_token_bins(
_flatten_generated_token_ids(token_ids),
int(self.config.depth_start_token_id),
int(self.config.depth_end_token_id),
token_id_to_bin,
)
if not depth_token_bins:
raise RuntimeError(
"Model generated no decodable depth tokens between <depth_start>/<depth_end>."
)
return torch.as_tensor([depth_token_bins], device=self.device, dtype=torch.long)
def _consume_generation_tokens(
    self,
    token_ids: torch.Tensor,
    *,
    past_key_values: Optional[Cache],
    attention_mask: Optional[torch.Tensor],
) -> Tuple[MolmoAct2CausalLMOutputWithPast, Optional[torch.Tensor]]:
    """Feed already-chosen tokens through the model to advance the cache.

    Args:
        token_ids: Rank-1 (one token per example) or rank-2 (batch, steps) ids.
        past_key_values: Cache to extend; when ``None`` no explicit
            ``cache_position`` is passed.
        attention_mask: Optional mask; it is right-padded with ones if it is
            shorter than cached length plus new tokens.

    Returns:
        The model output and the (possibly extended) attention mask.

    Raises:
        ValueError: If ``token_ids`` has a rank other than 1 or 2.
    """
    rank = token_ids.ndim
    if rank not in (1, 2):
        raise ValueError(
            f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}."
        )
    next_input_ids = token_ids.unsqueeze(1) if rank == 1 else token_ids
    step_len = int(next_input_ids.shape[1])
    past_length = _cache_seq_len_int(past_key_values)

    next_attention_mask = attention_mask
    if next_attention_mask is not None:
        # Grow the mask so it covers every cached position plus the new tokens.
        missing = int(past_length) + step_len - int(next_attention_mask.shape[-1])
        if missing > 0:
            padding = next_attention_mask.new_ones(
                (next_input_ids.shape[0], missing)
            )
            next_attention_mask = torch.cat(
                (next_attention_mask, padding), dim=-1
            )

    cache_position = None
    if past_key_values is not None:
        cache_position = torch.arange(
            past_length,
            past_length + step_len,
            device=next_input_ids.device,
        )

    output = self(
        input_ids=next_input_ids,
        attention_mask=next_attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
        cache_position=cache_position,
    )
    return output, next_attention_mask
def _make_depth_decode_attention_bias(
    self, inputs: Mapping[str, Any], past_key_values: Cache
) -> torch.Tensor:
    """Precompute a full causal additive bias sized to a fixed-length cache.

    The result has shape (batch, 1, max_cache_len, max_cache_len) in the dtype
    of ``lm_head.weight``: 0 where a query may attend to a key,
    ``finfo(dtype).min`` elsewhere. Key positions covered by
    ``inputs["attention_mask"]`` follow that mask; later positions are valid.

    Raises:
        RuntimeError: If the cache's first layer reports no positive
            ``max_cache_len``.
    """
    layers = getattr(past_key_values, "layers", None)
    max_cache_len = int(getattr(layers[0], "max_cache_len", 0)) if layers else 0
    if max_cache_len <= 0:
        raise RuntimeError(
            "Depth decode fast path requires a cache with a fixed maximum length."
        )

    input_ids = inputs["input_ids"]
    device = input_ids.device
    dtype = self.lm_head.weight.dtype
    batch_size = int(input_ids.shape[0])

    # Key validity: default to valid, then overlay the prompt's padding mask.
    key_is_valid = torch.ones(
        (batch_size, max_cache_len), device=device, dtype=torch.bool
    )
    prompt_mask = inputs.get("attention_mask")
    if prompt_mask is not None:
        prompt_mask = prompt_mask.to(device=device, dtype=torch.bool)
        overlap = min(int(prompt_mask.shape[-1]), max_cache_len)
        if overlap > 0:
            key_is_valid[:, :overlap] = prompt_mask[:, :overlap]

    # Entry [q, k] is true when key k is not after query q.
    positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
    causal = positions.unsqueeze(0) <= positions.unsqueeze(1)
    allowed = causal.unsqueeze(0) & key_is_valid.unsqueeze(1)

    zero = torch.zeros((), device=device, dtype=dtype)
    blocked = torch.full((), torch.finfo(dtype).min, device=device, dtype=dtype)
    return torch.where(allowed.unsqueeze(1), zero, blocked)
def _embed_base_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
# Skips MolmoAct2Embedding's per-call cat([base, new]); safe only for IDs
# below text_config.vocab_size, which is the case for all depth tokens.
wte = self.model.transformer.wte
base_embedding = getattr(wte, "embedding", None)
if base_embedding is None:
return wte(input_ids)
return F.embedding(input_ids, base_embedding)
def _run_depth_decode_step(
    self,
    token_ids: torch.Tensor,
    *,
    past_key_values: Cache,
    attention_bias: torch.Tensor,
) -> Tuple[torch.Tensor, Cache]:
    """Advance the depth decode by feeding ``token_ids`` through the transformer.

    Uses the depth-decode CUDA-graph manager when it accepts the inputs;
    otherwise runs the transformer directly on the slice of ``attention_bias``
    that corresponds to the new query positions.

    Returns:
        The hidden state of the last fed position (sequence dim kept, size 1)
        and the updated cache.

    Raises:
        ValueError: If ``token_ids`` has a rank other than 1 or 2.
    """
    rank = token_ids.ndim
    if rank not in (1, 2):
        raise ValueError(
            f"Expected token_ids to have rank 1 or 2, got {tuple(token_ids.shape)}."
        )
    next_input_ids = token_ids.unsqueeze(1) if rank == 1 else token_ids

    past_length = _cache_seq_len_int(past_key_values)
    end = past_length + int(next_input_ids.shape[1])

    graph_manager = self.depth_decode_cuda_graph_manager
    use_graph = graph_manager.can_use(
        next_input_ids,
        past_key_values=past_key_values,
        attention_bias=attention_bias,
    )
    if use_graph:
        return graph_manager.run(
            next_input_ids,
            past_key_values=past_key_values,
            attention_bias=attention_bias,
            past_length=past_length,
        )

    # Eager path: rows for the new queries, columns up to the new end.
    step_bias = attention_bias[:, :, past_length:end, :end]
    step_positions = torch.arange(
        past_length, end, device=next_input_ids.device, dtype=torch.long
    )
    outputs = self.model.transformer(
        attention_mask=step_bias,
        past_key_values=past_key_values,
        inputs_embeds=self._embed_base_tokens(next_input_ids),
        use_cache=True,
        output_attentions=False,
        output_hidden_states=False,
        cache_position=step_positions,
    )
    return outputs.last_hidden_state[:, -1:, :], outputs.past_key_values
def _project_depth_logits(self, last_hidden: torch.Tensor) -> torch.Tensor:
start = int(self.config.depth_token_start_id)
end_id = start + int(self.config.num_depth_tokens)
return F.linear(last_hidden, self.lm_head.weight[start:end_id])
def _max_depth_decode_steps(self) -> int:
return max(
int(self.config.num_depth_codes or 0) + 8,
int(self.config.action_horizon or 1) * 16,
1,
)
def _make_depth_static_cache(self, inputs: Mapping[str, Any]) -> Cache:
    """Create a static cache large enough for prompt, depth and action tokens.

    The capacity is the prompt length plus ``_max_depth_decode_steps()``, plus
    ``max(8, action_horizon)`` end steps, plus ``max(1, action_horizon * 16)``
    action tokens.
    """
    prompt_len = inputs["input_ids"].shape[1]
    horizon = int(self.config.action_horizon or 1)
    end_step_budget = max(8, horizon)
    action_token_budget = max(1, horizon * 16)
    capacity = (
        prompt_len
        + self._max_depth_decode_steps()
        + end_step_budget
        + action_token_budget
    )
    return self.depth_decode_cuda_graph_manager.make_static_cache(
        max_cache_len=capacity,
    )
def _continue_discrete_generation_from_output(
    self,
    initial_output: MolmoAct2CausalLMOutputWithPast,
    *,
    past_key_values: Optional[Cache],
    attention_mask: Optional[torch.Tensor],
    end_token_id: int,
    max_steps: int,
) -> torch.Tensor:
    """Greedily decode from ``initial_output`` until ``end_token_id`` appears.

    Each step takes the argmax of the last-position logits; decoding stops
    once every example in the batch emits ``end_token_id`` in the same step.
    The end token is included in the result but is not fed back to the model.

    Returns:
        Generated ids of shape (batch, steps).

    Raises:
        RuntimeError: If no token was generated, or the end token was not
            produced within ``max_steps`` steps.
    """
    stop_id = int(end_token_id)
    step_budget = int(max_steps)

    collected: List[torch.Tensor] = []
    output = initial_output
    cache = past_key_values
    mask = attention_mask

    for _ in range(step_budget):
        token = torch.argmax(output.logits[:, -1, :], dim=-1)
        collected.append(token)
        if bool((token == stop_id).all()):
            break
        output, mask = self._consume_generation_tokens(
            token,
            past_key_values=cache,
            attention_mask=mask,
        )
        cache = output.past_key_values
    else:
        # Loop ran out of steps (or never ran) without seeing the end token.
        if not collected:
            raise RuntimeError("Discrete continuation generated no tokens.")
        raise RuntimeError(
            f"Discrete continuation did not emit end token {int(end_token_id)} within {int(max_steps)} steps."
        )

    return torch.stack(collected, dim=1)
def _generate_depth_prefix(
    self,
    inputs: Mapping[str, Any],
    *,
    latest_first_image: Optional[np.ndarray],
    depth_cache: Optional[Mapping[str, Any]],
    enable_adaptive_depth: bool,
) -> _DepthPrefix:
    """Generate the depth-token prefix that follows the prompt.

    Two decode paths exist:

    * ``enable_adaptive_depth=False``: plain greedy decoding through the full
      model until ``<depth_end>`` is produced; bins are recovered afterwards
      from the generated ids.
    * ``enable_adaptive_depth=True`` (batch size 1 only): ``<depth_start>`` is
      forced, then exactly ``config.num_depth_codes`` depth codes are produced
      with logits restricted to the depth tokens. When ``depth_cache``
      provides a previous image and matching previous bins, spans reported as
      unchanged are replayed from the cache instead of being predicted.
      Finally the full vocabulary is decoded greedily until ``<depth_end>``.

    Args:
        inputs: Model inputs; must contain ``input_ids`` and may contain
            ``attention_mask``.
        latest_first_image: Current image used to compute the update mask.
        depth_cache: Optional mapping with ``"image"`` and ``"depth_bins"``
            from a previous call.
        enable_adaptive_depth: Selects the decode path (see above).

    Returns:
        A ``_DepthPrefix`` holding the generated ids, the depth bins, the
        concatenated input ids and attention mask, the extracted KV states,
        the output to continue decoding from, and the cache.

    Raises:
        RuntimeError: If required depth tokens are not configured or
            ``<depth_end>`` is not emitted within the step budget.
        ValueError: If adaptive depth is requested with a batch size other than 1.
    """
    if (
        self.config.depth_start_token_id is None
        or self.config.depth_end_token_id is None
    ):
        raise RuntimeError(
            "Depth reasoning requires single-token <depth_start>/<depth_end>."
        )
    if (
        self.config.depth_token_start_id is None
        or int(self.config.num_depth_tokens or 0) <= 0
    ):
        raise RuntimeError("Depth reasoning requires indexed depth tokens.")
    batch_size = int(inputs["input_ids"].shape[0])
    if batch_size != 1 and enable_adaptive_depth:
        raise ValueError("enable_adaptive_depth=True currently supports batch size 1.")
    # Prefill the prompt into a cache sized for the whole decode.
    static_cache = self._make_depth_static_cache(inputs)
    output = self(**inputs, use_cache=True, past_key_values=static_cache)
    current_output = output
    current_past_key_values = output.past_key_values
    current_attention_mask = inputs.get("attention_mask")
    generated_tokens: List[torch.Tensor] = []
    if not enable_adaptive_depth:
        # --- Path 1: plain greedy decoding through the full model ---
        hit_depth_end = False
        max_steps = self._max_depth_decode_steps()
        for _ in range(max_steps):
            next_token = torch.argmax(current_output.logits[:, -1, :], dim=-1)
            generated_tokens.append(next_token)
            # The token is consumed before the end check, so the returned
            # output/cache already include <depth_end>.
            current_output, current_attention_mask = (
                self._consume_generation_tokens(
                    next_token,
                    past_key_values=current_past_key_values,
                    attention_mask=current_attention_mask,
                )
            )
            current_past_key_values = current_output.past_key_values
            if bool((next_token == int(self.config.depth_end_token_id)).all()):
                hit_depth_end = True
                break
        if not generated_tokens:
            raise RuntimeError("Depth generation produced no tokens.")
        if not hit_depth_end:
            raise RuntimeError(
                f"Depth generation did not emit <depth_end> within {max_steps} steps."
            )
        depth_token_ids = torch.stack(generated_tokens, dim=1)
        full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1)
        full_attention_mask = None
        if current_attention_mask is not None:
            full_attention_mask = current_attention_mask[
                :, : full_input_ids.shape[1]
            ]
        encoder_kv_states = self.model._extract_kv_states(current_past_key_values)
        return _DepthPrefix(
            token_ids=depth_token_ids,
            depth_bins=self._decode_depth_bins_from_token_ids(depth_token_ids),
            full_input_ids=full_input_ids,
            attention_mask=full_attention_mask,
            encoder_kv_states=encoder_kv_states,
            next_output=current_output,
            past_key_values=current_past_key_values,
        )
    # --- Path 2: adaptive decoding with depth-only logits ---
    # <depth_start> is forced rather than predicted.
    depth_start = torch.full(
        (batch_size,),
        int(self.config.depth_start_token_id),
        device=self.device,
        dtype=torch.long,
    )
    # Lookup table: bin index -> vocabulary id of the matching depth token.
    code_token_ids = torch.arange(
        int(self.config.depth_token_start_id),
        int(self.config.depth_token_start_id) + int(self.config.num_depth_tokens),
        device=self.device,
        dtype=torch.long,
    )
    depth_attention_bias = self._make_depth_decode_attention_bias(
        inputs, current_past_key_values
    )
    generated_tokens.append(depth_start)
    last_hidden, current_past_key_values = self._run_depth_decode_step(
        depth_start,
        past_key_values=current_past_key_values,
        attention_bias=depth_attention_bias,
    )
    previous_image = None
    previous_bins = None
    if depth_cache is not None:
        previous_image = depth_cache.get("image")
        previous_bins = depth_cache.get("depth_bins")
    # Selective replay needs both the current and the previous frame data.
    selective = (
        bool(enable_adaptive_depth)
        and latest_first_image is not None
        and previous_image is not None
        and previous_bins is not None
    )
    update_mask = None
    previous_buffer_t = None
    if selective:
        previous_buffer = np.asarray(previous_bins, dtype=np.int64).reshape(-1)
        if previous_buffer.shape[0] == int(self.config.num_depth_codes):
            update_mask = _compute_depth_update_mask(
                latest_first_image,
                _normalize_image_for_cache(previous_image),
                num_depth_codes=int(self.config.num_depth_codes),
            )
            previous_buffer_t = (
                torch.from_numpy(previous_buffer)
                .to(
                    device=self.device,
                    dtype=torch.long,
                )
                .unsqueeze(0)
            )
        else:
            # Cached bins do not match the expected code count: regenerate everything.
            selective = False
    depth_bins = torch.zeros(
        (batch_size, int(self.config.num_depth_codes)),
        device=self.device,
        dtype=torch.long,
    )
    num_depth_codes = int(self.config.num_depth_codes)
    if not selective or update_mask is None or previous_buffer_t is None:
        # Predict every depth code, one decode step each.
        for depth_idx in range(num_depth_codes):
            depth_logits = self._project_depth_logits(last_hidden)
            predicted_bins = depth_logits.squeeze(1).argmax(dim=-1)
            depth_bins[:, depth_idx] = predicted_bins
            chosen_token_ids = code_token_ids[predicted_bins]
            generated_tokens.append(chosen_token_ids)
            last_hidden, current_past_key_values = self._run_depth_decode_step(
                chosen_token_ids,
                past_key_values=current_past_key_values,
                attention_bias=depth_attention_bias,
            )
    else:
        for start_idx, end_idx, should_generate in _build_depth_update_spans(
            update_mask
        ):
            if should_generate:
                # Span flagged for regeneration: predict code by code.
                for depth_idx in range(start_idx, end_idx):
                    depth_logits = self._project_depth_logits(last_hidden)
                    predicted_bins = depth_logits.squeeze(1).argmax(dim=-1)
                    depth_bins[:, depth_idx] = predicted_bins
                    chosen_token_ids = code_token_ids[predicted_bins]
                    generated_tokens.append(chosen_token_ids)
                    last_hidden, current_past_key_values = (
                        self._run_depth_decode_step(
                            chosen_token_ids,
                            past_key_values=current_past_key_values,
                            attention_bias=depth_attention_bias,
                        )
                    )
                continue
            # Span not flagged: copy the cached bins and feed the whole span
            # through the model in a single step.
            replay_bins = previous_buffer_t[:, start_idx:end_idx].expand(
                batch_size, -1
            )
            depth_bins[:, start_idx:end_idx] = replay_bins
            replay_token_ids = code_token_ids[replay_bins]
            generated_tokens.extend(replay_token_ids.unbind(dim=1))
            last_hidden, current_past_key_values = self._run_depth_decode_step(
                replay_token_ids,
                past_key_values=current_past_key_values,
                attention_bias=depth_attention_bias,
            )
    # After the depth codes, decode over the full vocabulary until <depth_end>.
    hit_depth_end = False
    max_depth_end_steps = max(8, int(self.config.action_horizon or 1))
    full_logits = self.lm_head(last_hidden)
    for _ in range(max_depth_end_steps):
        next_token = full_logits.squeeze(1).argmax(dim=-1)
        generated_tokens.append(next_token)
        # As in path 1, the token is consumed before the end check so the
        # returned logits are those following <depth_end>.
        last_hidden, current_past_key_values = self._run_depth_decode_step(
            next_token,
            past_key_values=current_past_key_values,
            attention_bias=depth_attention_bias,
        )
        full_logits = self.lm_head(last_hidden)
        if bool((next_token == int(self.config.depth_end_token_id)).all()):
            hit_depth_end = True
            break
    if not hit_depth_end:
        raise RuntimeError(
            f"Depth generation did not emit <depth_end> within {max_depth_end_steps} steps "
            "after adaptive depth tokens."
        )
    depth_token_ids = torch.stack(generated_tokens, dim=1)
    full_input_ids = torch.cat([inputs["input_ids"], depth_token_ids], dim=1)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        # Every generated token is a real (non-padding) position.
        full_attention_mask = torch.cat(
            (attention_mask, attention_mask.new_ones(depth_token_ids.shape)),
            dim=-1,
        )[:, : full_input_ids.shape[1]]
    else:
        full_attention_mask = None
    # Wrap the last logits so callers can continue decoding from them.
    current_output = MolmoAct2CausalLMOutputWithPast(
        logits=full_logits,
        past_key_values=current_past_key_values,
    )
    encoder_kv_states = self.model._extract_kv_states(current_past_key_values)
    return _DepthPrefix(
        token_ids=depth_token_ids,
        depth_bins=depth_bins,
        full_input_ids=full_input_ids,
        attention_mask=full_attention_mask,
        encoder_kv_states=encoder_kv_states,
        next_output=current_output,
        past_key_values=current_past_key_values,
    )
def _decode_discrete_action_chunk(
self,
generated_token_ids: torch.Tensor,
*,
action_tokenizer: Any,
action_dim: int,
) -> torch.Tensor:
if action_tokenizer is None:
raise ValueError(
"action_mode='discrete' requires an `action_tokenizer` input."
)
if (
self.config.action_start_token_id is None
or self.config.action_end_token_id is None
):
raise RuntimeError(
"Discrete action generation requires <action_start>/<action_end> token IDs."
)
token_id_to_bin = self._action_token_id_to_bin()
if not token_id_to_bin:
raise RuntimeError(
"Discrete action generation requires indexed action tokens in the converted config."
)
discrete_token_ids = _extract_discrete_token_bins(
_flatten_generated_token_ids(generated_token_ids),
int(self.config.action_start_token_id),
int(self.config.action_end_token_id),
token_id_to_bin,
)
if not discrete_token_ids:
raise RuntimeError(
"Model generated no decodable action tokens between <action_start>/<action_end>."
)
try:
decoded = action_tokenizer.decode(
[discrete_token_ids],
time_horizon=int(self.config.action_horizon),
action_dim=int(action_dim),
)
except TypeError:
decoded = action_tokenizer.decode([discrete_token_ids])
action_chunk = np.asarray(decoded, dtype=np.float32)
if action_chunk.ndim == 1:
action_chunk = action_chunk[None, None, :]
elif action_chunk.ndim == 2:
action_chunk = action_chunk[None, :, :]
elif action_chunk.ndim > 3:
action_chunk = action_chunk.reshape(
1, action_chunk.shape[-2], action_chunk.shape[-1]
)
if action_chunk.ndim != 3:
raise RuntimeError(
f"Decoded action chunk has unexpected shape {action_chunk.shape}."
)
return torch.as_tensor(action_chunk, device=self.device, dtype=torch.float32)
@torch.no_grad()
def predict_action(
    self,
    *,
    processor: Any,
    images: Any,
    task: str,
    state: Any,
    norm_tag: str,
    action_mode: str = "continuous",
    enable_depth_reasoning: bool = False,
    enable_adaptive_depth: bool = True,
    depth_cache: Optional[Mapping[str, Any]] = None,
    action_tokenizer: Any = None,
    num_steps: Optional[int] = None,
    n_action_steps: Optional[int] = None,
    generator: Optional[torch.Generator] = None,
    normalize_language: bool = True,
    enable_cuda_graph: bool = True,
    return_dict: bool = True,
) -> Union[MolmoAct2ActionOutput, torch.Tensor]:
    """Predict an action chunk for one robot observation.

    The method validates the request against the checkpoint config, builds
    the robot prompt (task text plus a discretised state string), runs the
    processor, and then produces actions either through
    ``self.model.generate_actions_from_inputs`` (``"continuous"``) or by
    generating action tokens and decoding them with ``action_tokenizer``
    (``"discrete"``). With ``enable_depth_reasoning`` a depth-token prefix
    is generated first and reused by the action stage. The result is sliced
    to the tag's action dimension and requested step count, then
    un-normalised with the stats of ``norm_tag``.

    Args:
        processor: Callable invoked as
            ``processor(text=..., images=..., return_tensors="pt")``.
        images: Image input forwarded to the processor.
        task: Task instruction; ``None``/empty is treated as ``""``.
        state: Robot state, normalised via the robot stats. Required.
        norm_tag: Normalisation tag, validated by the robot stats object.
        action_mode: ``"continuous"`` or ``"discrete"``.
        enable_depth_reasoning: Generate a depth prefix before the actions.
        enable_adaptive_depth: Forwarded to ``_generate_depth_prefix``.
        depth_cache: Previous depth cache, forwarded to
            ``_generate_depth_prefix`` and returned unchanged when no new
            cache is built.
        action_tokenizer: Required when ``action_mode="discrete"``.
        num_steps: Forwarded to ``generate_actions_from_inputs``.
        n_action_steps: Number of steps to keep; falls back to the tag's
            value, then to the tag's action horizon.
        generator: Forwarded to ``generate_actions_from_inputs``.
        normalize_language: Pass the task text through
            ``_normalize_question_text`` first.
        enable_cuda_graph: Passed to ``set_enabled`` on both CUDA-graph
            managers.
        return_dict: Return the full ``MolmoAct2ActionOutput`` when true,
            otherwise only the action tensor.

    Returns:
        ``MolmoAct2ActionOutput`` (actions, generated token ids, depth bins,
        depth cache) or just the float32 action tensor on ``self.device``.

    Raises:
        RuntimeError: If the checkpoint has no action expert or no indexed
            state tokens.
        ValueError: If ``state`` is missing, ``action_mode`` is invalid or
            unsupported by the checkpoint, depth reasoning is requested but
            unavailable, or the horizon / step counts are inconsistent.
    """
    config = self.config

    # ---- request validation -------------------------------------------
    if not bool(config.add_action_expert):
        raise RuntimeError(
            "This MolmoAct2 checkpoint was converted with add_action_expert=False; "
            "use standard Transformers generation for VLM inference."
        )
    if state is None:
        raise ValueError(
            "MolmoAct2 `predict_action` requires `state` for discrete state prompting."
        )

    action_mode = str(action_mode or "continuous")
    if action_mode not in {"continuous", "discrete"}:
        raise ValueError("action_mode must be either 'continuous' or 'discrete'.")
    is_continuous = action_mode == "continuous"
    if is_continuous:
        if config.action_format not in {"continuous", "both"}:
            raise ValueError(
                f"action_mode='continuous' requires checkpoint action_format in {{'continuous', 'both'}}, "
                f"got {config.action_format!r}."
            )
    else:
        if action_tokenizer is None:
            raise ValueError(
                "action_mode='discrete' requires an `action_tokenizer` input."
            )
        if config.action_format not in {"discrete", "both"}:
            raise ValueError(
                f"action_mode='discrete' requires checkpoint action_format in {{'discrete', 'both'}}, "
                f"got {config.action_format!r}."
            )
    if enable_depth_reasoning and not bool(config.enable_depth_reasoning):
        raise ValueError(
            "this model was not trained with `--enable_depth_reasoning`."
        )

    # ---- prompt construction ------------------------------------------
    stats = self._get_robot_stats()
    norm_tag = stats.validate_tag(norm_tag)
    metadata = stats.get_metadata(norm_tag)
    normalized_state = np.asarray(
        stats.normalize_state(state, norm_tag), dtype=np.float32
    )
    num_state_tokens = int(config.num_state_tokens or 0)
    if num_state_tokens <= 0:
        raise RuntimeError(
            "Discrete state prompting requires indexed state tokens in the converted config."
        )
    discrete_state_string = _build_discrete_state_string(
        normalized_state, num_state_tokens
    )

    if enable_depth_reasoning:
        style = "robot_depth_action"
    else:
        style = "robot_action"
    task_text = str(task or "")
    if normalize_language:
        task_text = _normalize_question_text(task_text)
    prompt = _build_robot_text(
        task=task_text,
        style=style,
        discrete_state_string=discrete_state_string,
        setup_type=str(metadata.get("setup_type", "") or ""),
        control_mode=str(metadata.get("control_mode", "") or ""),
        add_setup_tokens=bool(config.add_setup_tokens),
        add_control_tokens=bool(config.add_control_tokens),
        num_images=self._count_images(images),
    )

    inputs = processor(text=prompt, images=images, return_tensors="pt")
    inputs = self._move_inputs_to_device(inputs, self.device)
    inputs = self._drop_trivial_attention_mask(inputs)

    # ---- action dimension / horizon / step resolution -----------------
    action_dim = stats.get_action_dim(norm_tag)
    if action_dim is None:
        action_dim = int(config.max_action_dim)
    action_dim = int(action_dim)

    action_horizon = stats.get_action_horizon(norm_tag) or int(
        config.action_horizon
    )
    if int(action_horizon) > int(config.action_horizon):
        raise ValueError(
            f"Tag action_horizon={int(action_horizon)} exceeds checkpoint action_horizon={int(config.action_horizon)}."
        )

    # Explicit argument first, then the tag's own setting, then the horizon.
    steps_to_keep = n_action_steps
    if steps_to_keep is None:
        steps_to_keep = stats.get_n_action_steps(norm_tag)
    if steps_to_keep is None:
        steps_to_keep = int(action_horizon)
    steps_to_keep = int(steps_to_keep)
    if steps_to_keep < 1:
        raise ValueError(f"n_action_steps must be >= 1, got {steps_to_keep}.")
    if steps_to_keep > int(action_horizon):
        raise ValueError(
            f"Requested n_action_steps={steps_to_keep} exceeds tag action_horizon={int(action_horizon)}."
        )

    batch_size = int(inputs["input_ids"].shape[0])
    action_dim_is_pad = self._build_action_dim_is_pad(
        action_dim=action_dim,
        max_action_dim=int(config.max_action_dim),
        batch_size=batch_size,
        device=self.device,
    )
    self.model.action_cuda_graph_manager.set_enabled(enable_cuda_graph)
    self.depth_decode_cuda_graph_manager.set_enabled(enable_cuda_graph)

    # ---- optional depth prefix (shared by both action modes) ----------
    generated_token_ids = None
    depth_bins = None
    updated_depth_cache = depth_cache
    latest_first_image = None
    depth_prefix = None
    if enable_depth_reasoning:
        latest_first_image = _extract_first_image(images)
        depth_prefix = self._generate_depth_prefix(
            inputs,
            latest_first_image=latest_first_image,
            depth_cache=depth_cache,
            enable_adaptive_depth=bool(enable_adaptive_depth),
        )
        depth_bins = depth_prefix.depth_bins

    # ---- action generation --------------------------------------------
    if is_continuous:
        if enable_depth_reasoning:
            generated_token_ids = depth_prefix.token_ids
            actions = self.model.generate_actions_from_inputs(
                input_ids=depth_prefix.full_input_ids,
                attention_mask=depth_prefix.attention_mask,
                action_dim_is_pad=action_dim_is_pad,
                num_steps=num_steps,
                generator=generator,
                encoder_kv_states=depth_prefix.encoder_kv_states,
                encoder_attention_mask=self.model._get_encoder_attention_mask(
                    depth_prefix.full_input_ids,
                    depth_prefix.attention_mask,
                ),
            )
        else:
            actions = self.model.generate_actions_from_inputs(
                **inputs,
                action_dim_is_pad=action_dim_is_pad,
                num_steps=num_steps,
                generator=generator,
            )
    else:
        # Decoding resumes either from the depth prefix state or from a
        # fresh prefill of the prompt.
        if enable_depth_reasoning:
            resume_output = depth_prefix.next_output
            resume_cache = depth_prefix.past_key_values
            resume_mask = depth_prefix.attention_mask
        else:
            resume_output = self(**inputs, use_cache=True)
            resume_cache = resume_output.past_key_values
            resume_mask = inputs.get("attention_mask")
        action_token_ids = self._continue_discrete_generation_from_output(
            resume_output,
            past_key_values=resume_cache,
            attention_mask=resume_mask,
            end_token_id=self._require_eos_token_id(),
            max_steps=max(1, int(config.action_horizon * 16)),
        )
        if enable_depth_reasoning:
            generated_token_ids = torch.cat(
                [depth_prefix.token_ids, action_token_ids], dim=1
            )
        else:
            generated_token_ids = action_token_ids
        actions = self._decode_discrete_action_chunk(
            generated_token_ids,
            action_tokenizer=action_tokenizer,
            action_dim=action_dim,
        )

    # A new depth cache is only built when a first image was extracted.
    if enable_depth_reasoning and latest_first_image is not None:
        flat_bins = depth_bins.detach().cpu().reshape(-1).numpy()
        updated_depth_cache = {
            "image": latest_first_image,
            "depth_bins": flat_bins.astype(np.int64),
        }

    # ---- post-processing ----------------------------------------------
    actions = self._slice_action_dim(actions, action_dim)
    actions = self._slice_action_chunk(
        actions, int(config.n_obs_steps), steps_to_keep
    )
    actions = stats.unnormalize_action(actions, norm_tag)
    if torch.is_tensor(actions):
        actions = actions.to(device=self.device, dtype=torch.float32)
    else:
        actions = torch.as_tensor(actions, device=self.device, dtype=torch.float32)

    output = MolmoAct2ActionOutput(
        actions=actions,
        generated_token_ids=generated_token_ids,
        depth_bins=depth_bins,
        depth_cache=updated_depth_cache,
    )
    return output if return_dict else actions
@can_return_tuple
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: Optional[torch.Tensor] = None,
    image_token_pooling: Optional[torch.Tensor] = None,
    image_grids: Optional[torch.Tensor] = None,
    image_num_crops: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    video_token_pooling: Optional[torch.Tensor] = None,
    video_grids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, MolmoAct2CausalLMOutputWithPast]:
    r"""
    Run the backbone (``self.model``) and project its last hidden state
    through ``self.lm_head``.

    All arguments except ``labels`` and ``logits_to_keep`` are forwarded
    unchanged to ``self.model``. ``logits_to_keep`` selects the sequence
    positions for which logits are computed: an ``int`` keeps the trailing
    ``logits_to_keep`` positions (``0`` keeps every position), any other
    value is used directly as the index along the sequence axis. When
    ``labels`` is given, ``self.loss_function`` is evaluated on the logits.

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, MolmoAct2ForConditionalGeneration

    >>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...")
    >>> processor = AutoProcessor.from_pretrained("...")

    >>> prompt = "What's the content of the image?"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> messages = [{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "image", "image": image}]}]
    >>> inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True)

    >>> # Generate
    >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
    >>> generated_tokens = generated_ids[:, inputs['input_ids'].size(1):]
    >>> processor.post_process_image_text_to_text(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a bustling street scene in what appears to be a Chinatown area. There's ..."
    ```"""
    backbone_out = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_token_pooling=image_token_pooling,
        image_grids=image_grids,
        image_num_crops=image_num_crops,
        pixel_values_videos=pixel_values_videos,
        video_token_pooling=video_token_pooling,
        video_grids=video_grids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        token_type_ids=token_type_ids,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    # Only project the positions that are actually needed. For an int the
    # slice starts at -logits_to_keep, so the default 0 yields slice(0, None)
    # and every position is kept.
    if isinstance(logits_to_keep, int):
        keep = slice(-logits_to_keep, None)
    else:
        keep = logits_to_keep
    logits = self.lm_head(backbone_out.last_hidden_state[:, keep, :])

    loss = (
        None
        if labels is None
        else self.loss_function(
            logits=logits, labels=labels, vocab_size=self.vocab_size
        )
    )

    return MolmoAct2CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
        image_hidden_states=backbone_out.image_hidden_states,
    )
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_token_pooling: Optional[torch.Tensor] = None,
    image_grids: Optional[torch.Tensor] = None,
    image_num_crops: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.Tensor] = None,
    video_token_pooling: Optional[torch.Tensor] = None,
    video_grids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Optional[Union[int, torch.Tensor]] = None,
    **kwargs,
):
    """Build per-step model inputs, attaching visual tensors only at prefill.

    The text-side inputs come from the parent implementation. The image and
    video tensors are added to the result only when there is no cache yet,
    or when the cache exposes ``get_seq_length`` and reports length 0.

    Returns:
        The ``model_inputs`` mapping produced by the parent class, extended
        with the visual inputs on the first step.
    """
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )

    if past_key_values is None:
        is_prefill = True
    elif hasattr(past_key_values, "get_seq_length"):
        # A cache object may exist before anything was written to it.
        is_prefill = int(past_key_values.get_seq_length()) == 0
    else:
        is_prefill = False

    if not is_prefill:
        return model_inputs

    visual_inputs = (
        ("pixel_values", pixel_values),
        ("image_token_pooling", image_token_pooling),
        ("image_grids", image_grids),
        ("image_num_crops", image_num_crops),
        ("pixel_values_videos", pixel_values_videos),
        ("video_token_pooling", video_token_pooling),
        ("video_grids", video_grids),
    )
    for key, value in visual_inputs:
        model_inputs[key] = value
    return model_inputs
# Adapted from transformers.models.gemma3.modeling_gemma3
@staticmethod
def create_masks_for_generate(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    position_ids: Optional[torch.Tensor],
    token_type_ids: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    """Create generation masks using the text sub-config.

    Delegates to the module-level ``create_masks_for_generate`` with
    ``config.get_text_config()``. When ``token_type_ids`` is provided and
    the step covers more than one position (``input_embeds.shape[1] != 1``),
    an ``or_mask_function`` built by ``token_type_ids_mask_function`` is
    passed along as well. Extra ``**kwargs`` are accepted but not forwarded.
    """
    mask_kwargs = dict(
        config=config.get_text_config(),
        input_embeds=input_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    # No token-type information, or a single-position step: plain masks.
    if token_type_ids is None or input_embeds.shape[1] == 1:
        return create_masks_for_generate(**mask_kwargs)

    # The token-type mask is supplied as an `or` mask function.
    mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
        token_type_ids.to(cache_position.device)
    )
    return create_masks_for_generate(**mask_kwargs)
# Always register for multi-modal features: associate MolmoAct2Config with
# MolmoAct2ForConditionalGeneration on the AutoModelForImageTextToText auto class.
AutoModelForImageTextToText.register(MolmoAct2Config, MolmoAct2ForConditionalGeneration)