    def forward(self, device, initial_hidden, name_targets, max_name_len,
                start_token, batch_size, logit_modifier_fxn, token_sampler,
                **kwargs):
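        """
        Forward pass over a batch, unrolled over all timesteps (name decoder).

        Mirrors the decoders below: teacher-forces when name_targets is provided,
        otherwise feeds back the previously sampled token starting from start_token.
        Returns log-probabilities [B; T; V] and sampled tokens [B; T].
        """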

        use_teacher_forcing = name_targets is not None
        decoder_hidden = initial_hidden
        logit_probs = []
        output_tokens = []

        for i in range(max_name_len):

            if use_teacher_forcing:
                input_token = name_targets[:, i].unsqueeze(1)
            # Non-teacher forcing - initialize with START; otherwise use previous input
            elif i == 0:
                input_token = torch.LongTensor(
                    [start_token] * batch_size).unsqueeze(1).to(device)

            # Embed the input tokens [B; 1] -> [B; 1; E]
            input_embed = self.vocab_embedding(input_token)

            rnn_output, decoder_hidden = self.rnn(input_embed, decoder_hidden)

            logits = self.proj(rnn_output)
            logit_prob = F.log_softmax(logits, dim=-1)

            logit_probs.append(logit_prob)

            # Default to greedy decoding if no sampler configuration was supplied
            if logit_modifier_fxn is None:
                logit_modifier_fxn = partial(top_k_logits, k=0)
            if token_sampler is None:
                token_sampler = 'greedy'
            input_token = sample_next_token(
                logits,
                logit_modifier_fxn=logit_modifier_fxn,
                sampler=token_sampler)
            output_tokens.append(input_token)

        # Concatenate per-step log-probabilities into a single tensor [B; T; V]
        logit_probs = torch.cat(logit_probs, dim=1)
        # Concatenate sampled tokens along the step dimension [B; T]
        output_tokens = torch.cat(output_tokens, dim=1)

        return logit_probs, output_tokens
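# A minimal sketch of the `top_k_logits` / `sample_next_token` helpers the loop above
# relies on; the real implementations in this codebase may differ. With k=0 the logits
# pass through unchanged and the 'greedy' sampler reduces to an argmax, which is what
# the "by default greedy" configuration amounts to.
import torch
import torch.nn.functional as F


def top_k_logits(logits, k):
    """Mask all but the k largest logits per row; k=0 leaves the logits untouched."""
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k, dim=-1)
    min_values = values[..., -1, None]  # k-th largest value per row
    return torch.where(logits < min_values,
                       torch.full_like(logits, float('-inf')),
                       logits)


def sample_next_token(logits, logit_modifier_fxn, sampler='greedy'):
    """Pick the next token id from [B; 1; V] logits, returning a [B; 1] LongTensor."""
    logits = logit_modifier_fxn(logits.squeeze(1))
    if sampler == 'greedy':
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)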
    def forward(self, device, initial_hidden, calorie_encoding,
                name_encoding,
                ingr_encodings, ingr_masks, 
                targets, user_items, user_item_names, user_item_masks,
                max_len, batch_size, start_token,
                logit_modifier_fxn, token_sampler,
                visualize=False):
        """
        Forward pass over a batch, unrolled over all timesteps

        Arguments:
            device {torch.device} -- Torch device
            initial_hidden {torch.Tensor} -- Initial hidden state for the decoder RNN [L; B; H]
            calorie_encoding {torch.Tensor} -- Calorie level encoding [B; H]
            name_encoding {torch.Tensor} -- Name GRU encoding for the recipe name [L; B; H]
            ingr_encodings {torch.Tensor} -- MLP-encodings for each ingredient in recipe [B; Ni; H]
            ingr_masks {torch.Tensor} -- Positional binary mask of non-pad ingredients in recipe
                [B; Ni]
            targets {torch.Tensor} -- Target (gold) token indices. If provided, will teacher-force
                [B; T]
            user_items {torch.Tensor} -- Indices of prior items encountered per user (most recent k)
                [B; Most-recent-k]
            user_item_names {torch.Tensor} -- Token indices of the names of those prior items
                [B; Most-recent-k; Max-name-len]
            user_item_masks {torch.Tensor} -- Positional binary mask of user encountered items in most
                recent k. [B; Most-recent-k]
            max_len {int} -- Unroll to a maximum of this many timesteps
            batch_size {int} -- Number of examples in a batch
            start_token {int} -- Start token index to use as initial input for non-teacher-forcing

        Keyword Arguments:
            logit_modifier_fxn {func} -- Function applied to the logits before sampling the next
                token. Only used if not teacher-forcing (typically partial(top_k_logits, k=0) for
                greedy decoding).
            token_sampler {str} -- Sampling strategy for the next token, e.g. 'greedy'.
                Only used if not teacher-forcing.
            visualize {bool} -- Whether to accumulate items for visualization (default: {False})

        Returns:
            torch.Tensor -- Logit probabilities for each step in the batch [B; T; V]
            torch.Tensor -- Output tokens /step /batch [B; T]
            {Optional tensors if visualizing}
                torch.Tensor -- Positional ingredient attention weights /step /batch [B; T; Ni]
                torch.Tensor -- Prior item attention weights /step /batch [B; T; Most-recent-k]
        """
        # Initialize variables
        logit_probs = []
        use_teacher_forcing = targets is not None
        input_token = None
        decoder_hidden = initial_hidden

        # Accumulation of attention weights
        ingr_attns_for_plot = []
        prior_item_attns_for_plot = []
        output_tokens = []

        # Key projections
        ingr_proj_key = self.ingr_attention.key_layer(ingr_encodings)

        if self.item_embedding is None:
            # Prior item values: mean-pool the vocab embeddings of each item's name tokens
            # user_item_names [B; Most-recent-k; Max-name-len] -> prior_item_values [B; Most-recent-k; H]
            prior_item_values = torch.mean(self.vocab_embedding(user_item_names), dim=-2)
            prior_item_keys = self.prior_item_attention.key_layer(prior_item_values)
        else:
            # Prior item key projections
            prior_item_values = self.item_embedding.weight[user_items]
            prior_item_keys = self.prior_item_attention.key_layer(prior_item_values)

        # Unroll the decoder RNN for max_len steps
        for i in range(max_len):
            # Teacher forcing - use prior target token
            if use_teacher_forcing:
                input_token = targets[:, i].unsqueeze(1)
            # Non-teacher forcing - initialize with START; otherwise use previous input
            elif i == 0:
                input_token = torch.LongTensor([start_token] * batch_size).unsqueeze(1).to(device)

            # Embed the input token [B; 1] -> [B; 1; E]
            input_embed = self.vocab_embedding(input_token)

            # Query -> decoder hidden state
            query = decoder_hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

            # Current item ingredient attention
            ingr_context, ingr_alpha = self.ingr_attention(
                query=query,
                proj_key=ingr_proj_key,
                value=ingr_encodings,
                mask=ingr_masks
            )
            if visualize:
                ingr_attns_for_plot.append(ingr_alpha)

            # Prior item attention
            prior_item_context, prior_item_alpha = self.prior_item_attention(
                query=query,
                proj_key=prior_item_keys,
                value=prior_item_values,
                mask=user_item_masks
            )
            if visualize:
                prior_item_attns_for_plot.append(prior_item_alpha)

            # Take a single step
            _, decoder_hidden, pre_output = self.forward_step(
                input_embed=input_embed,
                decoder_hidden=decoder_hidden,
                calorie_encoding=calorie_encoding.unsqueeze(1),
                name_encoding=name_encoding[-1].unsqueeze(1),
                context=[ingr_context],
                attention_fusion_context=prior_item_context
            )

            # Project output to vocabulary space
            logits = self.proj(pre_output)
            logit_prob = F.log_softmax(logits, dim=-1)

            # Debug: dump step inputs if NaNs appear in the log-probabilities
            if torch.sum(torch.isnan(logit_prob)) > 0:
                print('NAN DETECTED')
                print('INPUT:\n{}'.format(input_token))
                print('NAME ENCODING:\n{}'.format(name_encoding))
                print('INGR CONTEXT:\n{}'.format(ingr_context))
                print('USER_ITEM:\n{}'.format(user_item_names))
                print('USER_ITEM_MASK:\n{}'.format(user_item_masks))
                print('PRIOR_CONTEXT:\n{}'.format(prior_item_context))
                print('CALORIE ENCODING:\n{}'.format(calorie_encoding))
                raise Exception('NAN DETECTED')

            logit_probs.append(logit_prob)

            # Save input token for next iteration (if not teacher-forcing)
            if not use_teacher_forcing:
                input_token = sample_next_token(
                    logits, logit_modifier_fxn=logit_modifier_fxn, sampler=token_sampler
                )
                output_tokens.append(input_token)

        # Concatenate per-step log-probabilities into a single tensor [B; T; V]
        logit_probs = torch.cat(logit_probs, dim=1)

        # Concatenate sampled tokens along the step dimension [B; T]
        if not use_teacher_forcing:
            output_tokens = torch.cat(output_tokens, dim=1)
        if visualize:
            ingr_attns_for_plot, prior_item_attns_for_plot = [
                torch.cat(tensors, dim=1) for tensors in [
                    ingr_attns_for_plot, prior_item_attns_for_plot
                ]
            ]
            return logit_probs, output_tokens, ingr_attns_for_plot, \
                prior_item_attns_for_plot

        return logit_probs, output_tokens
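# For reference, the attention modules above (self.ingr_attention, self.prior_item_attention)
# are used only through a `key_layer` projection and a (query, proj_key, value, mask) call that
# returns a context vector and attention weights. A masked additive (Bahdanau-style) attention
# that fits this interface might look as follows -- a sketch under assumed dimensions, not the
# original implementation (the `copy=` argument used by prior_tech_attention below is not modeled).
import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size, key_size=None, query_size=None):
        super().__init__()
        key_size = key_size or hidden_size
        query_size = query_size or hidden_size
        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, proj_key, value, mask):
        # query: [B; 1; H], proj_key / value: [B; N; H], mask: [B; N]
        query = self.query_layer(query)
        scores = self.energy_layer(torch.tanh(query + proj_key)).squeeze(-1)  # [B; N]
        scores = scores.masked_fill(mask == 0, float('-inf'))
        alphas = F.softmax(scores, dim=-1).unsqueeze(1)  # [B; 1; N]
        context = torch.bmm(alphas, value)               # [B; 1; H]
        return context, alphas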
Example #3
    def forward(self,
                device,
                initial_hidden,
                calorie_encoding,
                name_encoding,
                ingr_encodings,
                ingr_masks,
                targets,
                user_prior_technique_masks,
                max_len,
                batch_size,
                start_token,
                logit_modifier_fxn,
                token_sampler,
                visualize=False):
        """
        Forward pass over a batch, unrolled over all timesteps

        Arguments:
            device {torch.device} -- Torch device
            initial_hidden {torch.Tensor} -- Initial hidden state for the decoder RNN [L; B; H]
            calorie_encoding {torch.Tensor} -- Calorie level encoding [B; H]
            name_encoding {torch.Tensor} -- Recipe encoding for final name [L; B; H]
            ingr_encodings {torch.Tensor} -- MLP-encodings for each ingredient in recipe [B; Ni; H]
            ingr_masks {torch.Tensor} -- Positional binary mask of non-pad ingredients in recipe
                [B; Ni]
            targets {torch.Tensor} -- Target (gold) token indices. If provided, will teacher-force
                [B; T]
            user_prior_technique_masks {torch.Tensor} -- Vector representing user's normalized exposure
                to each technique [B; Nt]
            max_len {int} -- Unroll to a maximum of this many timesteps
            batch_size {int} -- Number of examples in a batch
            start_token {int} -- Start token index to use as initial input for non-teacher-forcing

        Keyword Arguments:
            logit_modifier_fxn {func} -- Function applied to the logits before sampling the next
                token. Only used if not teacher-forcing (typically partial(top_k_logits, k=0) for
                greedy decoding).
            token_sampler {str} -- Sampling strategy for the next token, e.g. 'greedy'.
                Only used if not teacher-forcing.
            visualize {bool} -- Whether to accumulate items for visualization (default: {False})

        Returns:
            torch.Tensor -- Logit probabilities for each step in the batch [B; T; V]
            torch.Tensor -- Output tokens /step /batch [B; T]
            {Optional tensors if visualizing}
                torch.Tensor -- Positional ingredient attention weights /step /batch [B; T; Ni]
                torch.Tensor -- Prior technique attention weights /step /batch [B; T; Nt]
        """
        # Initialize variables
        logit_probs = []
        use_teacher_forcing = targets is not None
        input_token = None
        decoder_hidden = initial_hidden

        # Accumulation of attention weights
        ingr_attns_for_plot = []
        prior_tech_attns_for_plot = []
        output_tokens = []

        # Key projections
        ingr_proj_key = self.ingr_attention.key_layer(ingr_encodings)
        prior_tech_proj_key = self.prior_tech_attention.key_layer(
            self.prior_tech_key_projection(self.technique_embedding.weight))

        # Unroll the decoder RNN for max_len steps
        for i in range(max_len):
            # Teacher forcing - use prior target token
            if use_teacher_forcing:
                input_token = targets[:, i].unsqueeze(1)
            # Non-teacher forcing - initialize with START; otherwise use previous input
            elif i == 0:
                input_token = torch.LongTensor(
                    [start_token] * batch_size).unsqueeze(1).to(device)

            # Embed the input token [B; 1] -> [B; 1; E]
            input_embed = self.vocab_embedding(input_token)

            # Query -> decoder hidden state
            query = decoder_hidden[-1].unsqueeze(
                1)  # [#layers, B, D] -> [B, 1, D]

            # Current item ingredient attention
            ingr_context, ingr_alpha = self.ingr_attention(
                query=query,
                proj_key=ingr_proj_key,
                value=ingr_encodings,
                mask=ingr_masks)
            if visualize:
                ingr_attns_for_plot.append(ingr_alpha)

            # Prior technique exposure attention
            tech_embed_values = torch.stack([self.technique_embedding.weight] *
                                            batch_size,
                                            dim=0)
            personal_tech_context, personal_tech_alpha = self.prior_tech_attention(
                query=query,
                proj_key=prior_tech_proj_key,
                value=tech_embed_values,
                mask=user_prior_technique_masks,
                copy=user_prior_technique_masks,
            )
            if visualize:
                prior_tech_attns_for_plot.append(personal_tech_alpha)

            # Take a single step
            _, decoder_hidden, pre_output = self.forward_step(
                input_embed=input_embed,
                decoder_hidden=decoder_hidden,
                name_encoding=name_encoding[-1].unsqueeze(1),
                calorie_encoding=calorie_encoding.unsqueeze(1),
                context=[ingr_context],
                personal_tech_context=personal_tech_context)

            # Project output to vocabulary space
            logits = self.proj(pre_output)
            logit_prob = F.log_softmax(logits, dim=-1)

            if torch.sum(torch.isnan(logit_prob)) > 0:
                print('!!!!!!!! NAN LOGIT DETECTED !!!!!!!!!!!!!!')
                for tens_name, tens in [
                    ('input tokens', input_token),
                    # ('technique context', technique_context),
                    # ('technique mask', technique_masks),
                    # ('ingredient context', ingr_context),
                    # ('ingredient mask', ingr_masks),
                    ('prior tech masks', user_prior_technique_masks),
                    # ('tech embed values', tech_embed_values),
                    # ('query', query),
                    # ('personal tech projection key', prior_tech_proj_key),
                    ('personal tech context', personal_tech_context),
                    # ('residual context', context),
                    # ('calorie encoding', calorie_encoding),
                ]:
                    print('=======================')
                    print(tens_name)
                    print(tens.size())
                    print(tens.detach().cpu().numpy().tolist())
                    print('Number of NaNs: {}'.format(
                        torch.sum(torch.isnan(tens))))
                raise Exception('NAN DETECTED')

            logit_probs.append(logit_prob)

            # Save input token for next iteration (if not teacher-forcing)
            if not use_teacher_forcing:
                input_token = sample_next_token(
                    logits,
                    logit_modifier_fxn=logit_modifier_fxn,
                    sampler=token_sampler)
                output_tokens.append(input_token)

        # Concatenate per-step log-probabilities into a single tensor [B; T; V]
        logit_probs = torch.cat(logit_probs, dim=1)

        # Concatenate sampled tokens along the step dimension [B; T]
        if not use_teacher_forcing:
            output_tokens = torch.cat(output_tokens, dim=1)
        if visualize:
            ingr_attns_for_plot, prior_tech_attns_for_plot = [
                torch.cat(tensors, dim=1) for tensors in
                [ingr_attns_for_plot, prior_tech_attns_for_plot]
            ]
            return logit_probs, output_tokens, ingr_attns_for_plot, \
                prior_tech_attns_for_plot

        return logit_probs, output_tokens
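# The decoders above return log_softmax outputs of shape [B; T; V], so a training loop
# presumably scores them with NLL loss against the gold tokens. A self-contained sketch
# with stand-in tensors; the shapes, padding index, and variable names are illustrative
# assumptions, not part of the original code.
import torch
import torch.nn.functional as F

B, T, V, PAD_IDX = 4, 12, 100, 0
raw_scores = torch.randn(B, T, V, requires_grad=True)  # stand-in for decoder pre-softmax scores
logit_probs = F.log_softmax(raw_scores, dim=-1)         # same form as the returned logit_probs
targets = torch.randint(1, V, (B, T))                   # stand-in gold token indices [B; T]

loss = F.nll_loss(
    logit_probs.view(B * T, V),   # flatten steps into one batch of predictions
    targets.view(B * T),          # matching flat target indices
    ignore_index=PAD_IDX)         # skip padded positions when averaging
loss.backward()                   # would propagate into the decoder in a real training step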