Implement beam search, a decoding algorithm widely used in sequence-to-sequence models (machine translation, text generation, speech recognition). Given a starting token sequence, a function that returns the next-token probability distribution, and search parameters, produce the top-scoring output sequences.
Your code must pass all provided unit tests.
```python
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int,
) -> List[List[int]]:
    pass
```
- `input_seq`: Initial token sequence (the prompt). Generated tokens are appended after this prefix
- `next_token_fn`: Takes a token sequence and returns a probability distribution over the vocabulary. `next_token_fn(seq)[i]` is the probability of token `i` being next
- `max_token`: Maximum number of new tokens to generate
- `beam_size`: Number of candidate sequences (beams) to keep at each step
- `stop_word_id`: Token ID that signals end of generation. A sequence is "completed" once it generates this token
Returns: a list of sequences (each including the `input_seq` prefix), sorted by cumulative log-probability from highest to lowest. Completed sequences (ending with `stop_word_id`) are preferred; if none exist, return the best active beams.
- Use log-probability to avoid numerical underflow
- Higher (less negative) log-probability = better sequence
- Beam search explores multiple paths, avoiding greedy pitfalls
Beam Search vs. Greedy vs. Exhaustive:

| Strategy | Candidates Kept | Quality | Cost |
| --- | --- | --- | --- |
| Greedy | 1 (best at each step) | Can miss optimal paths | O(T × V) |
| Beam Search | B (top-B at each step) | Balances quality and cost | O(T × B × V) |
| Exhaustive | All | Optimal | O(V^T) (intractable) |
- Token A has probability 0.4 at step 1 and leads to the stop token with probability 0.9
- Token B has probability 0.5 at step 1 and leads to the stop token with probability 0.3

Greedy picks B (0.5 > 0.4), yielding score log(0.5) + log(0.3) ≈ -1.90. Beam search with width ≥ 2 also explores path A, finding score log(0.4) + log(0.9) ≈ -1.02, which is significantly better.
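These two scores can be verified directly with a few lines of Python (natural logarithm, matching the numbers above):

```python
import math

# Greedy path: token B (p = 0.5), then stop (p = 0.3)
greedy_score = math.log(0.5) + math.log(0.3)

# Beam path: token A (p = 0.4), then stop (p = 0.9)
beam_score = math.log(0.4) + math.log(0.9)

print(f"{greedy_score:.2f}")  # -1.90
print(f"{beam_score:.2f}")    # -1.02
```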
1. Initialize: one beam containing `input_seq` with score 0.0
2. For each generation step (up to `max_token`):
   - Expand: for each active beam, get `next_token_fn(seq)` and create a candidate for each token
   - Score: each candidate's score = parent's score + log(p_token)
   - Separate: candidates ending with `stop_word_id` go to the completed list
   - Prune: keep the top `beam_size` active candidates (by score)
3. Return: completed sequences sorted by score; if none exist, use the active beams
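Putting the steps together, here is a minimal sketch in Python. It is one possible implementation of the signature above, not the only valid one; details like tie-breaking and whether to stop early when no active beams remain are design choices:

```python
import math
from typing import Callable, List, Tuple

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int,
) -> List[List[int]]:
    # Each beam is (cumulative log-prob, sequence); start with the prompt at score 0.0.
    active: List[Tuple[float, List[int]]] = [(0.0, input_seq[:])]
    completed: List[Tuple[float, List[int]]] = []

    for _ in range(max_token):
        candidates: List[Tuple[float, List[int]]] = []
        for score, seq in active:
            probs = next_token_fn(seq)
            for token_id, p in enumerate(probs):
                if p <= 0.0:  # skip impossible tokens; log(0) is undefined
                    continue
                cand = (score + math.log(p), seq + [token_id])
                if token_id == stop_word_id:
                    completed.append(cand)   # sequence is finished
                else:
                    candidates.append(cand)  # sequence keeps growing
        # Prune: keep only the top beam_size active candidates by score.
        candidates.sort(key=lambda c: c[0], reverse=True)
        active = candidates[:beam_size]
        if not active:  # every surviving beam has completed
            break

    results = completed if completed else active
    results.sort(key=lambda r: r[0], reverse=True)
    return [seq for _, seq in results]
```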
We accumulate log-probability instead of raw probability:

- Avoids numerical underflow from multiplying many small values
- log(p1 × p2 × ... × pn) = log(p1) + log(p2) + ... + log(pn)
- Higher (less negative) log-probability = better sequence
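The underflow problem is easy to demonstrate with toy numbers: multiplying 200 probabilities of 0.001 each vanishes to exactly 0.0 in double precision, while the log-space sum stays perfectly representable:

```python
import math

# Multiplying many small probabilities underflows to 0.0 (information lost)...
p = 0.001
raw = 1.0
for _ in range(200):
    raw *= p
print(raw)  # 0.0

# ...but summing log-probabilities stays representable.
log_total = sum(math.log(p) for _ in range(200))
print(log_total)  # ≈ -1381.55
```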
```python
def simple_next_token(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force the stop token
    return [0.3, 0.5, 0.2]

result = beam_search(
    input_seq=[0],
    next_token_fn=simple_next_token,
    max_token=2,
    beam_size=2,
    stop_word_id=2,
)
```
- Maintain separate lists for active beams and completed sequences
- Skip tokens with probability ≤ 0 to avoid log(0) errors
- Use list slicing (`input_seq[:]`) to create a copy, not a reference
- Sort candidates by score in descending order (highest = best)
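The last two pitfalls above are common sources of subtle bugs; a quick demonstration with toy values:

```python
import math

# Copy vs. reference: `parent + [token]` builds a new list, so sibling
# beams never share (and never corrupt) each other's token storage.
parent = [0, 1]
child_a = parent + [2]
child_b = parent + [3]
assert parent == [0, 1]  # the parent beam is untouched

# log(0) guard: only expand tokens with strictly positive probability.
probs = [0.0, 0.7, 0.3]
expanded = [(i, math.log(p)) for i, p in enumerate(probs) if p > 0.0]
```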