from typing import Any
import torch
from torch import nn
from .base import BaseSampler
class QAlignSampler(BaseSampler):
    """QAlign sampler that uses MCMC to improve model outputs with a reward model.

    This implementation is based on the paper
    "Sample, Don't Search: Rethinking Test-Time Alignment for Language Models"
    by Faria et al. (2025).
    Paper: https://arxiv.org/abs/2504.03790

    QAlign uses Markov Chain Monte Carlo (MCMC) to align model outputs at test
    time without requiring model fine-tuning. It converges to sampling from the
    optimal aligned distribution as test-time compute scales.

    The chain in this class operates on fixed-width token tensors: an initial
    continuation is sampled, then each MCMC step re-samples a suffix of every
    row and accepts or rejects it row by row based on the reward difference.

    NOTE(review): ``_get_logits`` and ``_sample_from_logits`` are not defined in
    this class; they are presumably provided by ``BaseSampler``. The code below
    assumes ``_sample_from_logits`` returns a ``(batch, 1)`` tensor of token
    ids -- confirm against the base class.
    """

    def __init__(
        self,
        reward_model: nn.Module,
        num_steps: int = 10,
        temperature: float = 1.0,
        beta: float = 1.0,
        proposal_temp: float = 1.0,
    ):
        """Initialize QAlign sampler.

        Args:
            reward_model: The reward model to use for scoring outputs. It is
                called with a batch of token ids and must return one score per
                row.
            num_steps: Number of MCMC steps to perform.
            temperature: Temperature for sampling the initial sequence from the
                base model.
            beta: Inverse temperature for the reward model (higher = more
                emphasis on reward).
            proposal_temp: Temperature for the proposal distribution.

        Raises:
            ValueError: If ``num_steps`` is negative, or if ``temperature`` or
                ``proposal_temp`` is not strictly positive.
        """
        super().__init__()
        if num_steps < 0:
            raise ValueError("Number of MCMC steps must be non-negative")
        if temperature <= 0:
            raise ValueError("Temperature must be positive")
        if proposal_temp <= 0:
            raise ValueError("Proposal temperature must be positive")
        self.reward_model = reward_model
        self.num_steps = num_steps
        self.temperature = temperature
        self.beta = beta
        self.proposal_temp = proposal_temp

    @torch.no_grad()
    def sample(
        self,
        model: Any,
        input_ids: torch.Tensor,
        max_length: int = 100,
        num_return_sequences: int = 1,
        **kwargs,
    ) -> torch.Tensor:
        """Sample from the model using QAlign.

        An initial continuation is sampled token by token at ``temperature``.
        Then ``num_steps`` Metropolis-Hastings steps are run: each step
        proposes a re-sampled suffix per row and accepts it with probability
        ``min(1, exp(beta * (r_proposal - r_current)))``.

        Args:
            model: The language model to sample from. If ``model.config``
                exposes ``eos_token_id`` (an int or a list of ints), rows that
                emit it are treated as finished and padded with that id.
            input_ids: Prompt token ids of shape ``(batch, prompt_len)``.
            max_length: Maximum total length (prompt + generated tokens).
            num_return_sequences: Accepted for interface compatibility;
                currently ignored (one sequence is returned per input row).
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            torch.Tensor: Token ids of shape ``(batch, length)`` where
            ``length`` is at most ``max(max_length, prompt_len)``. The prompt
            columns are never modified.
        """
        batch_size = input_ids.shape[0]
        prompt_length = input_ids.shape[1]
        current_ids = input_ids.clone()

        eos_ids = self._eos_token_ids(model)
        eos_tensor = (
            torch.tensor(eos_ids, device=input_ids.device) if eos_ids else None
        )
        # Per-row completion flags, so a row that hit EOS early stops producing
        # real tokens while the other rows keep going.
        finished = torch.zeros(
            batch_size, dtype=torch.bool, device=input_ids.device
        )

        # Generate initial sequences
        for _ in range(max_length - prompt_length):
            logits = self._get_logits(model, current_ids)
            next_tokens = self._sample_from_logits(logits / self.temperature)
            # Ensure next_tokens has the correct batch dimension
            if next_tokens.shape[0] != batch_size:
                next_tokens = next_tokens.expand(batch_size, -1)
            if eos_tensor is not None:
                # Rows that already finished are padded with EOS rather than
                # continuing to generate past the end of the sequence.
                next_tokens = next_tokens.masked_fill(
                    finished.unsqueeze(-1), eos_ids[0]
                )
                finished = finished | (next_tokens == eos_tensor).any(dim=-1)
            current_ids = torch.cat([current_ids, next_tokens], dim=1)
            # Stop once every row has produced an EOS token
            if eos_tensor is not None and bool(finished.all()):
                break

        # Nothing to refine: no MCMC steps requested or no generated tokens
        # (re-sampling would otherwise have to touch the prompt).
        if self.num_steps <= 0 or current_ids.shape[1] <= prompt_length:
            return current_ids

        # Perform MCMC steps. The reward of the current state is cached and
        # only replaced on acceptance, so each step costs one reward call.
        # NOTE(review): the acceptance rule uses only the reward difference;
        # no proposal-likelihood or length-ratio correction is applied --
        # confirm against the paper before relying on exact stationarity.
        current_reward = self._batch_reward(current_ids)
        for _ in range(self.num_steps):
            proposal_ids = self._generate_proposal(
                model, current_ids, prompt_length
            )
            proposal_reward = self._batch_reward(proposal_ids)

            # min(1, exp(x)) == exp(min(x, 0)); clamping first avoids overflow
            log_ratio = self.beta * (proposal_reward - current_reward)
            acceptance_prob = torch.exp(log_ratio.clamp(max=0.0))

            # Accept or reject each row independently
            accept = torch.rand_like(acceptance_prob) < acceptance_prob
            current_ids = torch.where(
                accept.unsqueeze(-1), proposal_ids, current_ids
            )
            current_reward = torch.where(
                accept, proposal_reward, current_reward
            )
        return current_ids

    @torch.no_grad()
    def _generate_proposal(
        self,
        model: Any,
        current_ids: torch.Tensor,
        prompt_length: int = 1,
    ) -> torch.Tensor:
        """Generate a proposal sequence using the base model.

        For every row a position is drawn uniformly from the generated region
        (never inside the prompt, and not past the first EOS token when the
        model defines one). The tokens from that position onwards are then
        re-sampled one at a time at ``proposal_temp``. If an EOS token is
        sampled, the remainder of the row is filled with that EOS id.

        Args:
            model: The language model.
            current_ids: Current token ids of shape ``(batch, seq_len)``.
            prompt_length: Number of leading columns that belong to the prompt
                and must be left untouched. Values below 1 are treated as 1 so
                the model is never called with an empty context.

        Returns:
            torch.Tensor: Proposed token ids with the same shape as
            ``current_ids``. When there is nothing to re-sample, an unmodified
            copy is returned.
        """
        batch_size = current_ids.shape[0]
        seq_len = current_ids.shape[1]
        proposal_ids = current_ids.clone()

        first_free = max(int(prompt_length), 1)
        if seq_len <= first_free:
            return proposal_ids

        eos_ids = self._eos_token_ids(model)
        eos_tensor = (
            torch.tensor(eos_ids, device=current_ids.device) if eos_ids else None
        )

        # Process each batch item separately to handle different positions
        for b in range(batch_size):
            # Exclusive upper bound for the resample position: one past the
            # first EOS in the generated region, so padding is never the
            # starting point of a proposal.
            last_free = seq_len
            if eos_tensor is not None:
                generated = current_ids[b, first_free:]
                eos_hits = (
                    (generated.unsqueeze(-1) == eos_tensor).any(dim=-1).nonzero()
                )
                if eos_hits.numel() > 0:
                    last_free = first_free + int(eos_hits[0]) + 1

            start = int(torch.randint(first_free, last_free, (1,)).item())
            for i in range(start, seq_len):
                logits = self._get_logits(model, proposal_ids[b : b + 1, :i])
                next_token = self._sample_from_logits(
                    logits / self.proposal_temp
                )
                token = next_token.reshape(-1)[0]
                proposal_ids[b, i] = token
                if eos_tensor is not None and bool((token == eos_tensor).any()):
                    # Row is complete: pad the tail instead of generating on.
                    proposal_ids[b, i + 1 :] = eos_ids[0]
                    break
        return proposal_ids

    def _compute_reward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compute reward for the given sequences using the reward model.

        Args:
            input_ids: Token ids passed unchanged to the reward model.

        Returns:
            torch.Tensor: Whatever the reward model returns for ``input_ids``,
            computed with gradient tracking disabled.
        """
        with torch.no_grad():
            reward = self.reward_model(input_ids)
        return reward

    def _batch_reward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return rewards as a float32 ``(batch,)`` tensor on the ids' device.

        Normalizing here lets the acceptance step work whether the reward
        model returns ``(batch,)``, ``(batch, 1)`` or a scalar for a single row.
        """
        reward = torch.as_tensor(self._compute_reward(input_ids))
        return reward.reshape(input_ids.shape[0]).to(
            device=input_ids.device, dtype=torch.float32
        )

    @staticmethod
    def _eos_token_ids(model: Any) -> list:
        """Return the model's EOS token ids as a (possibly empty) list of ints.

        Reads ``model.config.eos_token_id`` and tolerates a missing config, a
        missing attribute, ``None``, a single id, or a list/tuple of ids.
        """
        eos = getattr(getattr(model, "config", None), "eos_token_id", None)
        if eos is None:
            return []
        if isinstance(eos, (list, tuple)):
            return [int(token_id) for token_id in eos]
        return [int(eos)]