Source code for llm_samplers.anti_slop

from typing import Any, List, Optional, Set

import torch

from .base import BaseSampler


class AntiSlopSampler(BaseSampler):
    """
    Anti-Slop sampling down-weights probabilities at word & phrase level.

    Token level: every ID in ``disallowed_tokens`` has its sampling
    probability scaled by ``penalty`` before each sampling step.

    Phrase level: generation is aborted as soon as the newest tokens complete
    one of ``disallowed_phrases``; the attempt is then restarted from the
    prompt, up to ``max_retries`` attempts in total. A retry re-samples with
    the same (token-penalised) distribution, so it relies on sampling
    randomness to take a different path.
    """

    def __init__(
        self,
        disallowed_tokens: Optional[Set[int]] = None,
        disallowed_phrases: Optional[List[List[int]]] = None,
        penalty: float = 0.5,
        max_retries: int = 3,
    ):
        """
        Initialize the Anti-Slop sampler.

        Args:
            disallowed_tokens: Set of token IDs to penalize
            disallowed_phrases: List of token ID sequences to penalize
            penalty: Probability penalty factor (0-1). 1 leaves disallowed
                tokens untouched, 0 bans them outright.
            max_retries: Maximum number of generation attempts

        Raises:
            ValueError: If ``penalty`` is outside [0, 1] or ``max_retries``
                is smaller than 1.
        """
        super().__init__()

        if not 0 <= penalty <= 1:
            raise ValueError("penalty must be between 0 and 1")
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        self.disallowed_tokens = disallowed_tokens or set()
        self.disallowed_phrases = disallowed_phrases or []
        self.penalty = penalty
        self.max_retries = max_retries

    def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply Anti-Slop filtering to the logits.

        The penalty is applied in log space (``logit + log(penalty)``), which
        multiplies the token's unnormalised probability by ``penalty``
        whatever the sign of the logit. Multiplying the raw logit instead
        would *raise* the probability of tokens whose logit is negative.

        Args:
            logits: Raw logits from the model; the last dimension indexes the
                vocabulary.

        Returns:
            torch.Tensor: Anti-Slop filtered logits (a new tensor; the input
            is not modified).
        """
        filtered_logits = logits.clone()
        if not self.disallowed_tokens:
            return filtered_logits

        # log(0) evaluates to -inf, so penalty == 0 removes the tokens
        # entirely. NOTE: if that bans every vocabulary entry the downstream
        # softmax has nothing left to normalise.
        log_penalty = torch.log(
            torch.tensor(float(self.penalty), dtype=torch.float64)
        ).item()
        token_ids = torch.tensor(
            sorted(self.disallowed_tokens),
            dtype=torch.long,
            device=logits.device,
        )
        filtered_logits[..., token_ids] += log_penalty
        return filtered_logits

    def _check_phrases(self, generated: torch.Tensor) -> bool:
        """
        Check if any disallowed phrases are present in the generated sequence.

        A phrase counts as present when it occurs anywhere in at least one
        row of the batch. Empty phrases are ignored.

        Args:
            generated: Generated token sequence of shape (batch, seq_len)

        Returns:
            bool: True if a disallowed phrase is found
        """
        seq_len = generated.shape[1]
        for phrase in self.disallowed_phrases:
            phrase_len = len(phrase)
            if phrase_len == 0 or phrase_len > seq_len:
                continue
            target = torch.tensor(phrase, device=generated.device)
            # (batch, n_windows, phrase_len): every contiguous window per row.
            windows = generated.unfold(1, phrase_len, 1)
            if bool((windows == target).all(dim=-1).any()):
                return True
        return False

    def _ends_with_disallowed_phrase(self, generated: torch.Tensor) -> bool:
        """
        Check whether any row of ``generated`` ends with a disallowed phrase.

        During step-by-step generation only the newest token can complete a
        phrase, so inspecting the tail is sufficient and avoids rescanning
        the whole sequence (including the prompt) after every token.

        Args:
            generated: Token sequence of shape (batch, seq_len)

        Returns:
            bool: True if the last tokens of some row equal a disallowed phrase
        """
        seq_len = generated.shape[1]
        for phrase in self.disallowed_phrases:
            phrase_len = len(phrase)
            if phrase_len == 0 or phrase_len > seq_len:
                continue
            target = torch.tensor(phrase, device=generated.device)
            tail = generated[:, seq_len - phrase_len:]
            if bool((tail == target).all(dim=-1).any()):
                return True
        return False

    def sample(
        self,
        model: Any,
        input_ids: torch.Tensor,
        max_length: int = 100,
        num_return_sequences: int = 1,
        **kwargs,
    ) -> torch.Tensor:
        """
        Generate text using Anti-Slop sampling.

        Each attempt starts from ``input_ids`` and appends up to
        ``max_length`` tokens. An attempt is abandoned as soon as the newest
        tokens of an unfinished row complete a disallowed phrase; at most
        ``max_retries`` attempts are made. Phrases that lie entirely inside
        the prompt are ignored because re-sampling cannot remove them.

        Args:
            model: The language model to sample from. ``model.config.eos_token_id``
                (an int or a sequence of ints) is used for early stopping when
                available; without it generation always runs ``max_length`` steps.
            input_ids: Input token IDs of shape (batch, seq_len)
            max_length: Maximum number of new tokens appended to ``input_ids``
            num_return_sequences: Accepted for interface compatibility; not
                used by this implementation.
            **kwargs: Additional arguments (currently unused)

        Returns:
            torch.Tensor: Generated token IDs (prompt included). Rows that
            reach EOS before the others are padded with the EOS id. If every
            attempt hit a disallowed phrase, the last attempt is returned,
            cut off right after the offending phrase.
        """
        generated = input_ids.clone()
        device = generated.device

        # Resolve EOS once. A missing config / None EOS simply disables early
        # stopping instead of raising on the comparison below.
        config = getattr(model, "config", None)
        eos_token_id = getattr(config, "eos_token_id", None)
        eos_ids = None
        pad_id = None
        if eos_token_id is not None:
            candidate = torch.as_tensor(eos_token_id, device=device).reshape(-1)
            if candidate.numel() > 0:
                eos_ids = candidate
                pad_id = int(candidate[0])

        current_generated = generated
        for _attempt in range(self.max_retries):
            current_generated = generated.clone()
            finished = torch.zeros(
                generated.shape[0], dtype=torch.bool, device=device
            )

            for _step in range(max_length):
                # _get_logits / _sample_from_logits are not defined in this
                # class; presumably inherited from BaseSampler.
                logits = self._get_logits(model, current_generated)
                logits = self._apply_sampling(logits)
                next_tokens = self._sample_from_logits(logits, num_samples=1)

                if eos_ids is not None:
                    # Rows that already emitted EOS keep emitting padding.
                    next_tokens = next_tokens.masked_fill(
                        finished.unsqueeze(1), pad_id
                    )
                current_generated = torch.cat(
                    [current_generated, next_tokens], dim=1
                )

                # Check for disallowed phrases (only rows still generating).
                if self._ends_with_disallowed_phrase(current_generated[~finished]):
                    break

                # Stop once every sequence has generated an EOS token.
                if eos_ids is not None:
                    finished = finished | (next_tokens == eos_ids).any(dim=-1)
                    if bool(finished.all()):
                        return current_generated
            else:
                # Reached max_length without completing a disallowed phrase.
                return current_generated

        # If we've exhausted all retries, return the last generated sequence
        return current_generated