def forward(self, device, initial_hidden, name_targets, max_name_len,
            start_token, batch_size, logit_modifier_fxn, token_sampler,
            **kwargs):
    """
    Unroll the name decoder RNN for up to max_name_len steps, teacher-forcing
    on name_targets when they are provided and otherwise feeding back the
    token sampled at the previous step.
    """
    use_teacher_forcing = name_targets is not None
    decoder_hidden = initial_hidden

    logit_probs = []
    output_tokens = []

    # Default to greedy decoding if no logit modifier / sampler was provided
    if logit_modifier_fxn is None:
        logit_modifier_fxn = partial(top_k_logits, k=0)
    if token_sampler is None:
        token_sampler = 'greedy'

    for i in range(max_name_len):
        # Teacher forcing - use prior target token
        if use_teacher_forcing:
            input_token = name_targets[:, i].unsqueeze(1)
        # Non-teacher forcing - initialize with START; otherwise use previous input
        elif i == 0:
            input_token = torch.LongTensor(
                [start_token] * batch_size).unsqueeze(1).to(device)

        # Project input to vocab space
        input_embed = self.vocab_embedding(input_token)
        rnn_output, decoder_hidden = self.rnn(input_embed, decoder_hidden)

        # Project output to vocabulary space
        logits = self.proj(rnn_output)
        logit_prob = F.log_softmax(logits, dim=-1)
        logit_probs.append(logit_prob)

        # Sample the next input token (overwritten by the target when teacher-forcing)
        input_token = sample_next_token(
            logits, logit_modifier_fxn=logit_modifier_fxn, sampler=token_sampler)
        output_tokens.append(input_token)

    # Return logit probabilities in tensor form
    logit_probs = torch.cat(logit_probs, dim=1)

    # Concatenate along step dimension for visualizations
    output_tokens = torch.cat(output_tokens, dim=1)

    return logit_probs, output_tokens
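# --- Illustrative sketch (not part of the model) -----------------------------
# What the greedy default above is assumed to reduce to: with
# partial(top_k_logits, k=0) leaving the logits effectively unmodified,
# sample_next_token with sampler='greedy' picks the argmax token id per batch
# element and keeps the [B, 1] shape expected as the next RNN input.
# `_demo_greedy_next_token` is a hypothetical helper for illustration only,
# not the repo's sampler.
def _demo_greedy_next_token(logits):
    # logits: [B, 1, V] -> next token ids: [B, 1]
    return logits.argmax(dim=-1)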
def forward(self, device, initial_hidden, calorie_encoding, name_encoding,
            ingr_encodings, ingr_masks, targets, user_items, user_item_names,
            user_item_masks, max_len, batch_size, start_token,
            logit_modifier_fxn, token_sampler, visualize=False):
    """
    Forward pass over a batch, unrolled over all timesteps

    Arguments:
        device {torch.device} -- Torch device
        initial_hidden {torch.Tensor} -- Initial hidden state for the decoder RNN [L; B; H]
        calorie_encoding {torch.Tensor} -- Calorie level encoding [B; H]
        name_encoding {torch.Tensor} -- Name GRU encoding (final hidden states) [L; B; H]
        ingr_encodings {torch.Tensor} -- MLP-encodings for each ingredient in recipe [B; Ni; H]
        ingr_masks {torch.Tensor} -- Positional binary mask of non-pad ingredients in recipe [B; Ni]
        targets {torch.Tensor} -- Target (gold) token indices. If provided, will teacher-force [B; T]
        user_items {torch.Tensor} -- Indices of prior items encountered per user (most recent k) [B; Most-recent-k]
        user_item_names {torch.Tensor} -- Name token indices for each prior item encountered per user
        user_item_masks {torch.Tensor} -- Positional binary mask of user-encountered items in most recent k [B; Most-recent-k]
        max_len {int} -- Unroll to a maximum of this many timesteps
        batch_size {int} -- Number of examples in a batch
        start_token {int} -- Start token index to use as initial input for non-teacher-forcing
        logit_modifier_fxn {func} -- Function to modify logits before sampling the next token
            (e.g. partial(top_k_logits, k=0)). Only used if not teacher-forcing.
        token_sampler {str} -- Sampling strategy for the next token (e.g. 'greedy').
            Only used if not teacher-forcing.

    Keyword Arguments:
        visualize {bool} -- Whether to accumulate items for visualization (default: {False})

    Returns:
        torch.Tensor -- Logit probabilities for each step in the batch [B; T; V]
        torch.Tensor -- Output tokens /step /batch [B; T]
        {Optional tensors if visualizing}
        torch.Tensor -- Positional ingredient attention weights /step /batch [B; T; Ni]
        torch.Tensor -- Prior item attention weights /step /batch [B; T; Most-recent-k]
    """
    # Initialize variables
    logit_probs = []
    use_teacher_forcing = targets is not None
    input_token = None
    decoder_hidden = initial_hidden

    # Accumulation of attention weights
    ingr_attns_for_plot = []
    prior_item_attns_for_plot = []
    output_tokens = []

    # Key projections
    ingr_proj_key = self.ingr_attention.key_layer(ingr_encodings)
    if self.item_embedding is None:
        # Prior item key projections -- average vocab embeddings over each
        # prior item's name tokens (user_item_names shape B, MAX_NAME,)
        prior_item_values = torch.mean(self.vocab_embedding(user_item_names), dim=-2)
        prior_item_keys = self.prior_item_attention.key_layer(prior_item_values)
    else:
        # Prior item key projections from learned item embeddings
        prior_item_values = self.item_embedding.weight[user_items]
        prior_item_keys = self.prior_item_attention.key_layer(prior_item_values)

    # Unroll the decoder RNN for max_len steps
    for i in range(max_len):
        # Teacher forcing - use prior target token
        if use_teacher_forcing:
            input_token = targets[:, i].unsqueeze(1)
        # Non-teacher forcing - initialize with START; otherwise use previous input
        elif i == 0:
            input_token = torch.LongTensor(
                [start_token] * batch_size).unsqueeze(1).to(device)

        # Project input to vocab space
        input_embed = self.vocab_embedding(input_token)

        # Query -> decoder hidden state
        query = decoder_hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # Current item ingredient attention
        ingr_context, ingr_alpha = self.ingr_attention(
            query=query, proj_key=ingr_proj_key, value=ingr_encodings,
            mask=ingr_masks)
        if visualize:
            ingr_attns_for_plot.append(ingr_alpha)

        # Prior item attention
        prior_item_context, prior_item_alpha = self.prior_item_attention(
            query=query, proj_key=prior_item_keys, value=prior_item_values,
            mask=user_item_masks)
        if visualize:
            prior_item_attns_for_plot.append(prior_item_alpha)

        # Take a single step
        _, decoder_hidden, pre_output = self.forward_step(
            input_embed=input_embed,
            decoder_hidden=decoder_hidden,
            calorie_encoding=calorie_encoding.unsqueeze(1),
            name_encoding=name_encoding[-1].unsqueeze(1),
            context=[ingr_context],
            attention_fusion_context=prior_item_context)

        # Project output to vocabulary space
        logits = self.proj(pre_output)
        logit_prob = F.log_softmax(logits, dim=-1)

        # Debug segment -- surface the candidate tensors if a NaN sneaks into the logits
        if torch.sum(torch.isnan(logit_prob)) > 0:
            print('NAN DETECTED')
            print('INPUT:\n{}'.format(input_token))
            print('NAME ENCODING:\n{}'.format(name_encoding))
            print('INGR CONTEXT:\n{}'.format(ingr_context))
            print('USER_ITEM:\n{}'.format(user_item_names))
            print('USER_ITEM_MASK:\n{}'.format(user_item_masks))
            print('PRIOR_CONTEXT:\n{}'.format(prior_item_context))
            print('CALORIE ENCODING:\n{}'.format(calorie_encoding))
            raise Exception('NAN DETECTED')

        logit_probs.append(logit_prob)

        # Save input token for next iteration (if not teacher-forcing)
        if not use_teacher_forcing:
            input_token = sample_next_token(
                logits, logit_modifier_fxn=logit_modifier_fxn,
                sampler=token_sampler)
            output_tokens.append(input_token)

    # Return logit probabilities in tensor form
    logit_probs = torch.cat(logit_probs, dim=1)

    # Concatenate along step dimension for visualizations
    if not use_teacher_forcing:
        output_tokens = torch.cat(output_tokens, dim=1)
    if visualize:
        ingr_attns_for_plot, prior_item_attns_for_plot = [
            torch.cat(tensors, dim=1) for tensors in
            [ingr_attns_for_plot, prior_item_attns_for_plot]
        ]
        return logit_probs, output_tokens, ingr_attns_for_plot, \
            prior_item_attns_for_plot

    return logit_probs, output_tokens
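# --- Illustrative sketch (not part of the model) -----------------------------
# How the fallback branch above (self.item_embedding is None) is assumed to
# build one vector per prior item: embed every token of each prior item's name
# with the shared vocab embedding, then average over the name-token dimension.
# Assumed shapes: user_item_names [B, k, Nn] (k = most-recent-k prior items,
# Nn = name tokens per item), embedding dim H, result [B, k, H].
# `_demo_prior_item_values` is a hypothetical helper for illustration only.
def _demo_prior_item_values(vocab_embedding, user_item_names):
    embedded = vocab_embedding(user_item_names)  # [B, k, Nn, H]
    return embedded.mean(dim=-2)                 # [B, k, H]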
def forward(self, device, initial_hidden, calorie_encoding, name_encoding,
            ingr_encodings, ingr_masks, targets, user_prior_technique_masks,
            max_len, batch_size, start_token, logit_modifier_fxn,
            token_sampler, visualize=False):
    """
    Forward pass over a batch, unrolled over all timesteps

    Arguments:
        device {torch.device} -- Torch device
        initial_hidden {torch.Tensor} -- Initial hidden state for the decoder RNN [L; B; H]
        calorie_encoding {torch.Tensor} -- Calorie level encoding [B; H]
        name_encoding {torch.Tensor} -- Recipe encoding for final name [L; B; H]
        ingr_encodings {torch.Tensor} -- MLP-encodings for each ingredient in recipe [B; Ni; H]
        ingr_masks {torch.Tensor} -- Positional binary mask of non-pad ingredients in recipe [B; Ni]
        targets {torch.Tensor} -- Target (gold) token indices. If provided, will teacher-force [B; T]
        user_prior_technique_masks {torch.Tensor} -- Vector representing user's normalized exposure to each technique [B; Nt]
        max_len {int} -- Unroll to a maximum of this many timesteps
        batch_size {int} -- Number of examples in a batch
        start_token {int} -- Start token index to use as initial input for non-teacher-forcing
        logit_modifier_fxn {func} -- Function to modify logits before sampling the next token
            (e.g. partial(top_k_logits, k=0)). Only used if not teacher-forcing.
        token_sampler {str} -- Sampling strategy for the next token (e.g. 'greedy').
            Only used if not teacher-forcing.

    Keyword Arguments:
        visualize {bool} -- Whether to accumulate items for visualization (default: {False})

    Returns:
        torch.Tensor -- Logit probabilities for each step in the batch [B; T; V]
        torch.Tensor -- Output tokens /step /batch [B; T]
        {Optional tensors if visualizing}
        torch.Tensor -- Positional ingredient attention weights /step /batch [B; T; Ni]
        torch.Tensor -- Prior technique attention weights /step /batch [B; T; Nt]
    """
    # Initialize variables
    logit_probs = []
    use_teacher_forcing = targets is not None
    input_token = None
    decoder_hidden = initial_hidden

    # Accumulation of attention weights
    ingr_attns_for_plot = []
    prior_tech_attns_for_plot = []
    output_tokens = []

    # Key projections
    ingr_proj_key = self.ingr_attention.key_layer(ingr_encodings)
    prior_tech_proj_key = self.prior_tech_attention.key_layer(
        self.prior_tech_key_projection(self.technique_embedding.weight))

    # Unroll the decoder RNN for max_len steps
    for i in range(max_len):
        # Teacher forcing - use prior target token
        if use_teacher_forcing:
            input_token = targets[:, i].unsqueeze(1)
        # Non-teacher forcing - initialize with START; otherwise use previous input
        elif i == 0:
            input_token = torch.LongTensor(
                [start_token] * batch_size).unsqueeze(1).to(device)

        # Project input to vocab space
        input_embed = self.vocab_embedding(input_token)

        # Query -> decoder hidden state
        query = decoder_hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # Current item ingredient attention
        ingr_context, ingr_alpha = self.ingr_attention(
            query=query, proj_key=ingr_proj_key, value=ingr_encodings,
            mask=ingr_masks)
        if visualize:
            ingr_attns_for_plot.append(ingr_alpha)

        # Prior technique exposure attention
        tech_embed_values = torch.stack(
            [self.technique_embedding.weight] * batch_size, dim=0)
        personal_tech_context, personal_tech_alpha = self.prior_tech_attention(
            query=query,
            proj_key=prior_tech_proj_key,
            value=tech_embed_values,
            mask=user_prior_technique_masks,
            copy=user_prior_technique_masks,
        )
        if visualize:
            prior_tech_attns_for_plot.append(personal_tech_alpha)

        # Take a single step
        _, decoder_hidden, pre_output = self.forward_step(
            input_embed=input_embed,
            decoder_hidden=decoder_hidden,
            name_encoding=name_encoding[-1].unsqueeze(1),
            calorie_encoding=calorie_encoding.unsqueeze(1),
            context=[ingr_context],
            personal_tech_context=personal_tech_context)

        # Project output to vocabulary space
        logits = self.proj(pre_output)
        logit_prob = F.log_softmax(logits, dim=-1)

        # Debug segment -- dump candidate tensors if a NaN sneaks into the logits
        if torch.sum(torch.isnan(logit_prob)) > 0:
            print('!!!!!!!! NAN LOGIT DETECTED !!!!!!!!!!!!!!')
            for tens_name, tens in [
                ('input tokens', input_token),
                # ('technique context', technique_context),
                # ('technique mask', technique_masks),
                # ('ingredient context', ingr_context),
                # ('ingredient mask', ingr_masks),
                ('prior tech masks', user_prior_technique_masks),
                # ('tech embed values', tech_embed_values),
                # ('query', query),
                # ('personal tech projection key', prior_tech_proj_key),
                ('personal tech context', personal_tech_context),
                # ('residual context', context),
                # ('calorie encoding', calorie_encoding),
            ]:
                print('=======================')
                print(tens_name)
                print(tens.size())
                print(tens.detach().cpu().numpy().tolist())
                print('Number of NaNs: {}'.format(
                    torch.sum(torch.isnan(tens))))
            raise Exception('NAN DETECTED')

        logit_probs.append(logit_prob)

        # Save input token for next iteration (if not teacher-forcing)
        if not use_teacher_forcing:
            input_token = sample_next_token(
                logits, logit_modifier_fxn=logit_modifier_fxn,
                sampler=token_sampler)
            output_tokens.append(input_token)

    # Return logit probabilities in tensor form
    logit_probs = torch.cat(logit_probs, dim=1)

    # Concatenate along step dimension for visualizations
    if not use_teacher_forcing:
        output_tokens = torch.cat(output_tokens, dim=1)
    if visualize:
        ingr_attns_for_plot, prior_tech_attns_for_plot = [
            torch.cat(tensors, dim=1) for tensors in
            [ingr_attns_for_plot, prior_tech_attns_for_plot]
        ]
        return logit_probs, output_tokens, ingr_attns_for_plot, \
            prior_tech_attns_for_plot

    return logit_probs, output_tokens
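# --- Illustrative sketch (not part of the model) -----------------------------
# The loop above materializes batch_size copies of the technique embedding
# table so every example attends over the same [Nt, H] value matrix:
#   torch.stack([self.technique_embedding.weight] * batch_size, dim=0) -> [B, Nt, H]
# A lighter-weight equivalent, assuming the attention module only reads the
# values (no in-place writes), is to expand a broadcast view without copying.
# `_demo_batched_technique_values` is a hypothetical helper for illustration only.
def _demo_batched_technique_values(technique_weight, batch_size):
    # technique_weight: [Nt, H] -> view of shape [B, Nt, H], no data copy
    return technique_weight.unsqueeze(0).expand(batch_size, -1, -1)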