Source code for llm_samplers.top_k

from typing import Any

import torch

from .base import BaseSampler


class TopKSampler(BaseSampler):
    """
    Top-K sampling considers only the 'k' most probable tokens.

    All other tokens are set to have zero probability (their logits are
    replaced by ``-inf`` before sampling).
    """

    def __init__(self, k: int = 50):
        """
        Initialize the Top-K sampler.

        Args:
            k: Number of top tokens to consider

        Raises:
            ValueError: If ``k`` is not positive
        """
        super().__init__()
        if k <= 0:
            raise ValueError("k must be positive")
        self.k = k

    def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply Top-K filtering to the logits.

        The vocabulary is taken to be the last dimension of ``logits``; any
        leading dimensions are filtered independently.

        Args:
            logits: Raw logits from the model

        Returns:
            torch.Tensor: Tensor of the same shape as ``logits`` in which
            every entry outside the top ``k`` of its row is ``-inf``
        """
        vocab_size = logits.size(-1)
        k = min(self.k, vocab_size)  # Ensure k doesn't exceed vocabulary size

        # Get the top k values and indices along the vocabulary dimension
        top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)

        # Start from a tensor filled with -inf so that every token outside
        # the top k ends up with zero probability
        filtered_logits = torch.full_like(logits, float("-inf"))

        # Put the top k values back in their original positions. Scatter
        # along the same (last) dimension topk used, so logits of any rank
        # are handled rather than only 2-D ones.
        filtered_logits.scatter_(-1, top_k_indices, top_k_values)

        return filtered_logits

    def sample(
        self,
        model: Any,
        input_ids: torch.Tensor,
        max_length: int = 100,
        num_return_sequences: int = 1,
        **kwargs,
    ) -> torch.Tensor:
        """
        Generate text using Top-K sampling.

        At most ``max_length`` tokens are appended to ``input_ids``.
        Generation stops early once every sequence in the batch has produced
        ``model.config.eos_token_id``. A sequence that has already produced
        EOS is padded with EOS while the remaining sequences finish.

        Args:
            model: The language model to sample from
            input_ids: Input token IDs; assumed to be 2-D (batch, sequence),
                since new tokens are concatenated along dim 1
            max_length: Maximum number of sampling steps to run
            num_return_sequences: Number of sequences to return (accepted
                for interface compatibility; not used by this method)
            **kwargs: Additional arguments (not used by this method)

        Returns:
            torch.Tensor: Generated token IDs (the prompt followed by the
            sampled tokens)
        """
        generated = input_ids.clone()

        # NOTE(review): assumed to be a single int or None -- confirm that
        # list-valued EOS ids are not expected here.
        eos_token_id = model.config.eos_token_id

        # One flag per batch row: has this sequence emitted EOS yet?
        finished = torch.zeros(
            generated.size(0), dtype=torch.bool, device=generated.device
        )

        for _ in range(max_length):
            logits = self._get_logits(model, generated)
            logits = self._apply_sampling(logits)
            next_tokens = self._sample_from_logits(logits, num_samples=1)

            if eos_token_id is not None:
                # Rows that are already done keep emitting EOS instead of
                # freshly sampled tokens.
                next_tokens = next_tokens.masked_fill(
                    finished.unsqueeze(-1), eos_token_id
                )

            generated = torch.cat([generated, next_tokens], dim=1)

            if eos_token_id is None:
                # No EOS id configured: nothing to stop on, run to max_length
                continue

            # Stop only once all sequences have generated an EOS token;
            # stopping on the first EOS would truncate the other rows.
            finished = finished | (next_tokens == eos_token_id).reshape(
                finished.size(0), -1
            ).any(dim=-1)
            if finished.all():
                break

        return generated