Beam Search Decoding
Implement beam search, a decoding algorithm widely used in sequence-to-sequence models (machine translation, text generation, speech recognition). Given a starting token sequence, a function that returns the next-token probability distribution, and search parameters, produce the top-scoring output sequences.
Your code must pass all provided unit tests.
```python
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    pass
```
Return a list of sequences (each including the input_seq prefix), sorted by cumulative log-probability from highest to lowest. Completed sequences (ending with stop_word_id) are preferred; if none exist, return the best active beams.
Beam Search vs. Greedy vs. Exhaustive

| Strategy | Candidates Kept | Quality | Cost |
| --- | --- | --- | --- |
| Greedy | 1 (best at each step) | Can miss optimal paths | O(T × V) |
| Beam Search | B (top-B at each step) | Balances quality and cost | O(T × B × V) |
| Exhaustive | All | Optimal | O(V^T) (intractable) |
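To make the cost column concrete, here is a quick back-of-the-envelope comparison with hypothetical sizes (a 50,000-token vocabulary, beam width 5, 20 generation steps — not values from the problem):

```python
V, B, T = 50_000, 5, 20  # hypothetical vocabulary size, beam width, sequence length

greedy_cost = T * V       # tokens scored by greedy decoding
beam_cost = T * B * V     # tokens scored by beam search
exhaustive_cost = V ** T  # sequences enumerated by exhaustive search

print(f"greedy: {greedy_cost:,} scores")      # 1,000,000
print(f"beam:   {beam_cost:,} scores")        # 5,000,000
print(f"exhaustive: a {len(str(exhaustive_cost))}-digit number of sequences")
```

Beam search costs only a factor of B more than greedy, while exhaustive search is hopeless even for short sequences.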
Consider a two-step example where the first token can be A (probability 0.4) or B (probability 0.5), and the best continuation after B has probability 0.3 while the best continuation after A has probability 0.9:
Greedy picks B (0.5 > 0.4), yielding score log(0.5) + log(0.3) ≈ -1.90
Beam search with width >= 2 explores path A, finding score log(0.4) + log(0.9) ≈ -1.02 — significantly better!
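The arithmetic behind those two scores, using the probabilities from the example above:

```python
import math

# Greedy commits to B (0.5 > 0.4) and is stuck with its weak continuation (0.3).
greedy_score = math.log(0.5) + math.log(0.3)

# A beam of width >= 2 also keeps A (0.4), whose continuation (0.9) wins overall.
beam_score = math.log(0.4) + math.log(0.9)

print(round(greedy_score, 2))  # -1.9
print(round(beam_score, 2))    # -1.02
```

The locally worse first choice leads to the globally better sequence, which is exactly the failure mode beam search mitigates.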
1. Initialize: one beam containing input_seq with score 0.0.
2. For each generation step (up to max_token): expand every active beam with each possible next token, adding log(prob) to the beam's cumulative score; move candidates ending in stop_word_id to a completed list; keep only the top beam_size active candidates.
3. Return: completed sequences sorted by score (highest first); if none completed, fall back to the best active beams.
We accumulate log-probability instead of raw probability: multiplying many probabilities in (0, 1] quickly underflows floating-point to zero, while summing their logarithms is numerically stable and preserves the same ranking, since log(p₁ · p₂ · … · pₙ) = log p₁ + log p₂ + … + log pₙ.
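A small demonstration of why this matters (the per-token probability 0.1 and the sequence length 400 are arbitrary illustrative values):

```python
import math

probs = [0.1] * 400  # 400 tokens, each with probability 0.1

# Raw product underflows: 0.1**400 is far below the smallest float64.
product = 1.0
for p in probs:
    product *= p
print(product)  # 0.0

# Summed log-probabilities stay comfortably representable.
log_score = sum(math.log(p) for p in probs)
print(log_score)  # ≈ -921.03
```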
```python
def simple_next_token(seq):
    if len(seq) >= 2:
        return [0.0, 0.0, 1.0]  # force stop
    return [0.3, 0.5, 0.2]

result = beam_search(
    input_seq=[0],
    next_token_fn=simple_next_token,
    max_token=2,
    beam_size=2,
    stop_word_id=2
)
```
```python
import math
from typing import List, Callable

def beam_search(
    input_seq: List[int],
    next_token_fn: Callable[[List[int]], List[float]],
    max_token: int,
    beam_size: int,
    stop_word_id: int
) -> List[List[int]]:
    # Each beam: (sequence, cumulative_log_prob)
    beams = [(input_seq[:], 0.0)]
    completed = []

    for _ in range(max_token):
        all_candidates = []
        for seq, score in beams:
            # Get probability distribution for next token
            probs = next_token_fn(seq)
            # Expand this beam with each possible next token
            for token_id, prob in enumerate(probs):
                if prob <= 0:
                    continue
                new_seq = seq + [token_id]
                new_score = score + math.log(prob)
                # Separate completed from active candidates
                if token_id == stop_word_id:
                    completed.append((new_seq, new_score))
                else:
                    all_candidates.append((new_seq, new_score))

        # If no active candidates remain, stop
        if not all_candidates:
            break

        # Keep only top beam_size active candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_size]

    # If no sequence completed with stop word, return best active beams
    if not completed:
        completed = beams

    # Sort by score (highest first) and return sequences only
    completed.sort(key=lambda x: x[1], reverse=True)
    return [seq for seq, _ in completed]
```