# Source code for llm_samplers.xtc

from typing import Any

import torch

from .base import BaseSampler


class XTCSampler(BaseSampler):
    """
    XTC (Exclude Top Choices) sampling "turns truncation on its head."

    Instead of pruning low-probability tokens, XTC targets the most probable
    ones under certain conditions to enhance creativity. In this
    implementation a token is excluded when it is one of the ``top_k`` most
    probable tokens of its row and its probability is strictly greater than
    both ``exclusion_threshold`` and ``min_probability``.
    """

    def __init__(
        self,
        top_k: int = 5,
        exclusion_threshold: float = 0.3,
        min_probability: float = 0.1,
    ):
        """
        Initialize the XTC sampler.

        Args:
            top_k: Number of top tokens to potentially exclude.
            exclusion_threshold: Probability threshold for exclusion; a top-k
                token is excluded only if its probability is strictly above it.
            min_probability: Minimum probability to consider for exclusion; a
                top-k token must also be strictly above this value to be
                excluded.

        Raises:
            ValueError: If ``top_k`` is not positive, or if either probability
                bound is outside the open interval (0, 1).
        """
        super().__init__()
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        if not 0 < exclusion_threshold < 1:
            raise ValueError("exclusion_threshold must be in (0, 1)")
        if not 0 < min_probability < 1:
            raise ValueError("min_probability must be in (0, 1)")
        self.top_k = top_k
        self.exclusion_threshold = exclusion_threshold
        self.min_probability = min_probability

    def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply XTC filtering to the logits.

        Excluded tokens get a logit of ``-inf``. If excluding would leave a
        row with no selectable token (every entry ``-inf``), that row is left
        unfiltered, so the returned distribution can always be sampled.

        Args:
            logits: Raw logits from the model, with the vocabulary along the
                last dimension (``sample`` passes a ``(batch, vocab)`` tensor).

        Returns:
            torch.Tensor: A new tensor of XTC filtered logits; ``logits`` itself
            is not modified.
        """
        # Convert logits to probabilities
        probs = torch.softmax(logits, dim=-1)

        # torch.topk raises if k exceeds the vocabulary size, so clamp it.
        k = min(self.top_k, probs.size(-1))
        top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)

        # A top-k token is excluded only if it clears both probability bars
        # (i.e. effectively the larger of the two).
        # NOTE(review): the published XTC description keeps the least probable
        # above-threshold token; this rule can remove all of them — confirm
        # this is intended.
        exclude_top = (top_k_probs > self.exclusion_threshold) & (
            top_k_probs > self.min_probability
        )
        exclusion_mask = torch.zeros_like(probs, dtype=torch.bool)
        # top-k indices are distinct within a row, so one scatter writes each
        # candidate's decision exactly once.
        exclusion_mask.scatter_(-1, top_k_indices, exclude_top)

        # Never remove every remaining candidate of a row: softmax over an
        # all -inf row yields NaN and sampling from it fails.
        has_survivor = (~exclusion_mask & ~torch.isneginf(logits)).any(
            dim=-1, keepdim=True
        )
        exclusion_mask &= has_survivor

        # Set excluded tokens to negative infinity
        return logits.masked_fill(exclusion_mask, float("-inf"))

    @staticmethod
    def _eos_token_ids(model: Any, device: torch.device) -> "torch.Tensor | None":
        """Return ``model.config.eos_token_id`` as a 1-D tensor, or ``None`` if unset or empty."""
        config = getattr(model, "config", None)
        eos_token_id = getattr(config, "eos_token_id", None)
        if eos_token_id is None:
            return None
        # Accept either a single id or a list of ids.
        eos_ids = torch.as_tensor(eos_token_id, device=device).reshape(-1)
        return eos_ids if eos_ids.numel() > 0 else None

    def sample(
        self,
        model: Any,
        input_ids: torch.Tensor,
        max_length: int = 100,
        num_return_sequences: int = 1,
        **kwargs,
    ) -> torch.Tensor:
        """
        Generate text using XTC sampling.

        Args:
            model: The language model to sample from. If ``model.config``
                defines ``eos_token_id`` (an int or a list of ints), generation
                stops early once every sequence has emitted one of those ids.
            input_ids: Input token IDs as a 2-D ``(batch, seq_len)`` tensor.
            max_length: Maximum number of new tokens to append after
                ``input_ids``.
            num_return_sequences: Number of sequences to return per input row;
                each row of ``input_ids`` is repeated this many times
                (consecutively) before sampling.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            torch.Tensor: Generated token IDs, prompt included. Sequences that
            finish before the rest of the batch are padded with the (first)
            EOS token id.

        Raises:
            ValueError: If ``num_return_sequences`` is less than 1.
        """
        if num_return_sequences < 1:
            raise ValueError("num_return_sequences must be positive")

        generated = (
            input_ids.repeat_interleave(num_return_sequences, dim=0)
            if num_return_sequences > 1
            else input_ids.clone()
        )

        # Resolve EOS ids once instead of on every step.
        eos_ids = self._eos_token_ids(model, generated.device)
        finished = torch.zeros(
            generated.size(0), dtype=torch.bool, device=generated.device
        )

        for _ in range(max_length):
            logits = self._get_logits(model, generated)
            logits = self._apply_sampling(logits)
            next_tokens = self._sample_from_logits(logits, num_samples=1)

            if eos_ids is None:
                generated = torch.cat([generated, next_tokens], dim=1)
                continue

            # Sequences that already emitted EOS keep emitting it, so their
            # output stays terminated while the rest of the batch continues.
            next_tokens = next_tokens.masked_fill(
                finished.unsqueeze(-1), int(eos_ids[0])
            )
            generated = torch.cat([generated, next_tokens], dim=1)

            # Stop only once *all* sequences have generated an EOS token.
            finished |= torch.isin(next_tokens, eos_ids).any(dim=-1)
            if bool(finished.all()):
                break

        return generated