This Uber Machine Learning Engineer phone screen asks you to implement multi-head self-attention in PyTorch from scratch. The interviewer is usually checking whether you understand the mechanics of attention: how queries, keys, and values are computed, how scaled dot-product attention combines them, and how the per-head outputs are concatenated and projected.
Implement the multi-head self-attention mechanism using PyTorch. The multi-head self-attention mechanism takes in a sequence of vectors and outputs a sequence of vectors. It does this by first computing queries, keys, and values for each vector in the input sequence. Then, it applies the scaled dot-product attention mechanism to these queries, keys, and values. Finally, it concatenates the outputs of the different attention heads and applies a linear transformation to the concatenated output.
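The "multi-head" part boils down to reshaping the embedding dimension into (num_heads, head_dim) pieces and moving the head axis ahead of the sequence axis so attention runs independently per head. A minimal shape walkthrough (dimensions chosen for illustration):

```python
import torch

seq_len, batch_size, embed_dim, num_heads = 10, 32, 512, 8
head_dim = embed_dim // num_heads  # 64

x = torch.randn(seq_len, batch_size, embed_dim)
# Split the embedding into heads: (seq, batch, heads, head_dim)
x = x.reshape(seq_len, batch_size, num_heads, head_dim)
# Move the head axis ahead of the sequence axis: (batch, heads, seq, head_dim)
x = x.permute(1, 2, 0, 3)
print(x.shape)  # torch.Size([32, 8, 10, 64])
```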
The input sequence has shape (sequence_length, batch_size, embedding_dim). Here is an example of how the multi-head self-attention mechanism might be used:
```python
import torch
import torch.nn.functional as F

input_seq = torch.randn(10, 32, 512)  # sequence_length=10, batch_size=32, embedding_dim=512
num_heads = 8

def multi_head_self_attention(input_seq, num_heads):
    seq_len, batch_size, embed_dim = input_seq.shape
    head_dim = embed_dim // num_heads
    # Without learned projections, queries, keys, and values are all the input,
    # split into num_heads pieces along the embedding dimension
    x = input_seq.reshape(seq_len, batch_size, num_heads, head_dim)
    x = x.permute(1, 2, 0, 3)  # (batch, heads, seq, head_dim)
    queries = keys = values = x
    # Compute the scaled attention scores
    attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) / (head_dim ** 0.5)
    # Compute the attention weights
    attention_weights = F.softmax(attention_scores, dim=-1)
    # Compute the context vectors
    context_vectors = torch.matmul(attention_weights, values)
    # Concatenate the heads back along the embedding dimension
    output_seq = context_vectors.permute(2, 0, 1, 3).reshape(seq_len, batch_size, embed_dim)
    # Apply a linear transformation to the concatenated output
    output_seq = torch.nn.Linear(embed_dim, embed_dim)(output_seq)
    return output_seq

output_seq = multi_head_self_attention(input_seq, num_heads)
```
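A quick sanity check on the attention weights: after the softmax, each query position's weights over the keys should sum to 1. A standalone sketch (toy dimensions chosen for illustration):

```python
import torch
import torch.nn.functional as F

queries = torch.randn(4, 6, 8)  # (batch_size, sequence_length, dim)
keys = torch.randn(4, 6, 8)
# Scaled dot-product scores, as in the example above
scores = torch.matmul(queries, keys.transpose(-1, -2)) / (8 ** 0.5)
weights = F.softmax(scores, dim=-1)
# Each row of weights is a probability distribution over the keys
print(torch.allclose(weights.sum(dim=-1), torch.ones(4, 6)))  # True
```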
This sketch uses the torch.matmul function to compute the dot product of two tensors, the torch.nn.Linear module to apply a linear transformation to a tensor, and the torch.nn.functional.softmax function to compute the softmax of a tensor. Here is a possible solution to the problem:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        assert self.head_dim * num_heads == embedding_dim, "Embedding dimension must be divisible by number of heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # x: (batch_size, sequence_length, embedding_dim);
        # transpose first if your input is (sequence_length, batch_size, embedding_dim)
        batch_size, sequence_length, embedding_dim = x.shape
        # Split the embedding into self.num_heads different pieces
        x = x.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        # Permute the tensor to prepare for the self-attention operation
        x = x.permute(0, 2, 1, 3)  # (batch, heads, seq, head_dim)
        # Compute queries, keys, and values
        queries = self.queries(x)
        keys = self.keys(x)
        values = self.values(x)
        # Compute the attention scores
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) / (self.head_dim ** 0.5)
        # Compute the attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)
        # Compute the context vectors
        context_vectors = torch.matmul(attention_weights, values)
        # Concatenate the heads back along the embedding dimension
        context_vectors = context_vectors.permute(0, 2, 1, 3)
        context_vectors = context_vectors.reshape(batch_size, sequence_length, embedding_dim)
        # Apply a final linear transformation to the concatenated output
        return self.fc_out(context_vectors)
```
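For a sanity check on shapes, PyTorch ships a built-in nn.MultiheadAttention module. A minimal sketch of calling it on the same dimensions (note that, unlike a batch-first implementation, it defaults to (sequence_length, batch_size, embedding_dim) ordering):

```python
import torch
import torch.nn as nn

# Built-in reference implementation; expects (seq_len, batch, embed) by default
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.randn(10, 32, 512)
# Self-attention: query, key, and value are all the same input sequence
output, attn_weights = mha(x, x, x)
print(output.shape)        # torch.Size([10, 32, 512])
print(attn_weights.shape)  # averaged over heads: torch.Size([32, 10, 10])
```

This is a handy cross-check in an interview setting: run both implementations on the same input and confirm the output shapes match.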