Example #1
    def generate(self, input, hidden, generated_seq_len):
        """
        Arguments:
            - input: A mini-batch of input tokens (NOT sequences!)
                            shape: (batch_size)
            - hidden: The initial hidden states for every layer of the stacked RNN.
                            shape: (num_layers, batch_size, hidden_size)
            - generated_seq_len: The length of the sequence to generate.
                           Note that this can be different than the length used
                           for training (self.seq_len)
        Returns:
            - Sampled sequences of tokens
                        shape: (generated_seq_len, batch_size)
        """
        self.seq_len = 1
        input = input.view(1, -1)       # (1, batch_size), same shape as the sampled inputs below
        samples = input
        for i in range(generated_seq_len):
            logits, hidden = self.forward(input, hidden)

            # Sample one token per sequence from the softmax distribution over the vocabulary
            probs = F.softmax(logits, dim=-1)
            tokens = Categorical(probs=probs).sample()
            tokens = tokens.view(1, -1)

            # Append output to samples
            samples = torch.cat((samples, tokens), dim=0)
            # Make input to next time step
            input = tokens
        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return samples
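
A minimal usage sketch for the `generate` signature shared by these examples; the model handle, sizes, and `generated_seq_len` below are illustrative assumptions, not part of the original code.

import torch

# Illustrative sizes only (assumptions, not from the original code)
batch_size, num_layers, hidden_size, vocab_size = 4, 2, 128, 10000

seed_tokens = torch.randint(0, vocab_size, (batch_size,))        # (batch_size,)
init_hidden = torch.zeros(num_layers, batch_size, hidden_size)   # (num_layers, batch_size, hidden_size)

# `model` would be an instance of one of the classes above (hypothetical handle):
# samples = model.generate(seed_tokens, init_hidden, generated_seq_len=20)
# Most implementations here keep the seed as row 0, so `samples` has
# generated_seq_len + 1 rows, each of shape (batch_size,).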
Example #2
    def generate(self, input, hidden, generated_seq_len):
        gen_samples = input.view(1, -1)
        # embedded_inp shape is (1, batch_size, emb_size)
        embedded_inp = self.embedding_layer(gen_samples)

        for t in range(generated_seq_len):
            inp_x = embedded_inp[0]
            hidden_next = []
            for layer_no in range(self.num_layers):
                cur_t_out = self.recurrent_layers[layer_no](inp_x,
                                                            hidden[layer_no])
                # This is the input for next layer
                inp_x = cur_t_out
                # next hidden state
                hidden_next.append(cur_t_out)

            hidden = torch.stack(hidden_next)
            logits = self.output_layer(inp_x).detach()
            # (batch_size, vocab_size)
            softmax_probs = F.softmax(logits, dim=1)
            out_idx = Categorical(probs=softmax_probs).sample()
            # out_idx : (1, batch_size)
            out_idx = out_idx.view(1, -1)

            gen_samples = torch.cat((gen_samples, out_idx), dim=0)
            embedded_inp = self.embedding_layer(out_idx)

        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return gen_samples
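
A small aside on the sampling step used above: `torch.distributions.Categorical` also accepts raw logits and normalizes them internally, so the explicit softmax is optional. A minimal sketch with made-up shapes:

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

logits = torch.randn(4, 10)   # illustrative (batch_size, vocab_size)

# Explicit softmax, as in the snippet above
idx_a = Categorical(probs=F.softmax(logits, dim=1)).sample()
# Equivalent: pass logits directly
idx_b = Categorical(logits=logits).sample()

assert idx_a.shape == idx_b.shape == (4,)   # one sampled token index per sequence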
Example #3
  def generate(self, input, hidden, generated_seq_len):
    input = input.long()
    samples = input.view(1, -1)         # (1, batch_size)

    # Input to hidden layer - embedding of input
    emb_input = self.embedding(samples)    # (1, batch_size, emb_size)

    # For each time step
    for t in range(generated_seq_len):

        # Next hidden layer
        new_hidden = []

        # Input at this time step for each layer
        x_t = emb_input[0]      # (batch_size, emb_size)
        prev_hidden = hidden

        for i in range(self.num_layers):


            concatenated_input = torch.cat([x_t, prev_hidden[i]], dim=1)

            r_t = torch.sigmoid(self.r_list[i](concatenated_input))
            z_t = torch.sigmoid(self.z_list[i](concatenated_input))
            
            new_h_tm1 = r_t * prev_hidden[i]
            concatenated_h_input = torch.cat([x_t, new_h_tm1], dim=1)

            hp_t = torch.tanh(self.hp_list[i](concatenated_h_input))

            h_t = (1-z_t) * (prev_hidden[i]) + (z_t * hp_t)

            x_t = h_t
            new_hidden.append(h_t)

        # Carry the updated hidden states over to the next time step
        hidden = torch.stack(new_hidden)
        
        output = self.output_layer(x_t)

        probs = F.softmax(output, dim=1)    # (batch_size, vocab_size)
        token_out = Categorical(probs=probs).sample()   # (batch_size)
        token_out = token_out.view(1, -1)               # (1, batch_size)

        # Append output to samples
        samples = torch.cat((samples, token_out), dim=0)

        # Embed the sampled tokens as the input to the next time step
        emb_input = self.embedding(token_out)   # (1, batch_size, emb_size)

    # Row 0 is the seed input, followed by generated_seq_len sampled rows
    return samples
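
For reference, a self-contained sketch of the single GRU step that the inner loop above computes, using stand-alone `nn.Linear` layers in place of `self.r_list[i]`, `self.z_list[i]`, and `self.hp_list[i]`; the sizes are assumptions for illustration.

import torch
import torch.nn as nn

emb_size, hidden_size, batch_size = 32, 64, 4   # illustrative sizes

r_lin = nn.Linear(emb_size + hidden_size, hidden_size)   # stands in for self.r_list[i]
z_lin = nn.Linear(emb_size + hidden_size, hidden_size)   # stands in for self.z_list[i]
h_lin = nn.Linear(emb_size + hidden_size, hidden_size)   # stands in for self.hp_list[i]

x_t = torch.randn(batch_size, emb_size)
h_prev = torch.zeros(batch_size, hidden_size)

xh = torch.cat([x_t, h_prev], dim=1)
r_t = torch.sigmoid(r_lin(xh))                                       # reset gate
z_t = torch.sigmoid(z_lin(xh))                                       # update gate
h_tilde = torch.tanh(h_lin(torch.cat([x_t, r_t * h_prev], dim=1)))   # candidate state
h_t = (1 - z_t) * h_prev + z_t * h_tilde                             # new hidden state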
Example #4
    def generate(self, input, hidden, generated_seq_len):
        # Compute the forward pass, as in the self.forward method (above).
        # You'll probably want to copy substantial portions of that code here.
        #
        # We "seed" the generation by providing the first inputs.
        # Subsequent inputs are generated by sampling from the output distribution,
        # as described in the tex (Problem 5.3)
        # Unlike for self.forward, you WILL need to apply the softmax activation
        # function here in order to compute the parameters of the categorical
        # distributions to be sampled from at each time-step.
        """
        Arguments:
            - input: A mini-batch of input tokens (NOT sequences!)
                            shape: (batch_size)
            - hidden: The initial hidden states for every layer of the stacked RNN.
                            shape: (num_layers, batch_size, hidden_size)
            - generated_seq_len: The length of the sequence to generate.
                           Note that this can be different than the length used 
                           for training (self.seq_len)
        Returns:
            - Sampled sequences of tokens
                        shape: (generated_seq_len, batch_size)
        """
        gen_samples = input.view(1, -1)
        # embedded_inp shape is (1, batch_size, emb_size)
        embedded_inp = self.embedding_layer(gen_samples)

        for t in range(generated_seq_len):
            inp_x = embedded_inp[0]
            hidden_next = []
            for layer_no in range(self.num_layers):
                cur_t_out = self.recurrent_layers[layer_no](inp_x,
                                                            hidden[layer_no])
                # This is the input for next layer
                inp_x = cur_t_out
                # next hidden state
                hidden_next.append(cur_t_out)

            hidden = torch.stack(hidden_next)
            logits = self.output_layer(inp_x).detach()
            # (batch_size, vocab_size)
            softmax_probs = F.softmax(logits, dim=1)
            out_idx = Categorical(probs=softmax_probs).sample()
            # out_idx : (1, batch_size)
            out_idx = out_idx.view(1, -1)

            gen_samples = torch.cat((gen_samples, out_idx), dim=0)
            embedded_inp = self.embedding_layer(out_idx)

        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return gen_samples
Example #5
    def generate(self, input, hidden, generated_seq_len):
        self.seq_len = 1
        input = input.view(1, -1)       # (1, batch_size), same shape as the sampled inputs below
        samples = input
        for i in range(generated_seq_len):
            logits, hidden = self.forward(input, hidden)

            # Sample one token per sequence from the softmax distribution over the vocabulary
            probs = F.softmax(logits, dim=-1)
            tokens = Categorical(probs=probs).sample()
            tokens = tokens.view(1, -1)

            # Append output to samples
            samples = torch.cat((samples, tokens), dim=0)
            # Make input to next time step
            input = tokens
        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return samples
Example #6
  def generate(self, input, hidden, generated_seq_len):
    # TODO ========================
    # Compute the forward pass, as in the self.forward method (above).
    # You'll probably want to copy substantial portions of that code here.
    # 
    # We "seed" the generation by providing the first inputs.
    # Subsequent inputs are generated by sampling from the output distribution, 
    # as described in the tex (Problem 5.3)
    # Unlike for self.forward, you WILL need to apply the softmax activation 
    # function here in order to compute the parameters of the categorical 
    # distributions to be sampled from at each time-step.

    """
    Arguments:
        - input: A mini-batch of input tokens (NOT sequences!)
                        shape: (batch_size)
        - hidden: The initial hidden states for every layer of the stacked RNN.
                        shape: (num_layers, batch_size, hidden_size)
        - generated_seq_len: The length of the sequence to generate.
                       Note that this can be different than the length used 
                       for training (self.seq_len)
    Returns:
        - Sampled sequences of tokens
                    shape: (generated_seq_len, batch_size)
    """
   

    input = input.long()
    samples = input.view(1, -1)         # (1, batch_size)

    # Input to hidden layer - embedding of input
    emb_input = self.embedding(samples)    # (1, batch_size, emb_size)

    # For each time step
    for t in range(generated_seq_len):

        # Next hidden layer
        new_hidden = []

        # Input at this time step for each layer
        x_t = emb_input[0]      # (batch_size, emb_size)
        prev_hidden = hidden

        for i in range(self.num_layers):


            concatenated_input = torch.cat([x_t, prev_hidden[i]], dim=1)
            h_t = self.layer_list[i](concatenated_input)
            h_t = torch.tanh(h_t)
            
            x_t = h_t
            
            new_hidden.append(h_t)

        # Carry the updated hidden states over to the next time step
        hidden = torch.stack(new_hidden)
        
        output = self.layer_list[-1](x_t)

        probs = F.softmax(output, dim=1)    # (batch_size, vocab_size)
        token_out = Categorical(probs=probs).sample()   # (batch_size)
        token_out = token_out.view(1, -1)               # (1, batch_size)

        # Append output to samples
        samples = torch.cat((samples, token_out), dim=0)

        # Embed the sampled tokens as the input to the next time step
        emb_input = self.embedding(token_out)   # (1, batch_size, emb_size)

    # Row 0 is the seed input, followed by generated_seq_len sampled rows
    return samples
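
Similarly, the inner loop above reduces to a single vanilla RNN step; a self-contained sketch with a stand-alone `nn.Linear` in place of `self.layer_list[i]` and assumed sizes:

import torch
import torch.nn as nn

emb_size, hidden_size, batch_size = 32, 64, 4   # illustrative sizes

rnn_lin = nn.Linear(emb_size + hidden_size, hidden_size)   # stands in for self.layer_list[i]

x_t = torch.randn(batch_size, emb_size)
h_prev = torch.zeros(batch_size, hidden_size)

# h_t = tanh(W [x_t, h_{t-1}] + b)
h_t = torch.tanh(rnn_lin(torch.cat([x_t, h_prev], dim=1)))   # (batch_size, hidden_size)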
Example #7
    def generate(self, input, hidden, generated_seq_len):
        """
        Generate a sample sequence from the GRU.

        This is similar to the forward method but instead of having ground
        truth input for each time step, you are now required to sample the token
        with maximum probability at each time step and feed it as input at the
        next time step.

        Arguments:
            - input: A mini-batch of input tokens (NOT sequences!)
                            shape: (batch_size)
            - hidden: The initial hidden states for every layer of the stacked RNN.
                            shape: (num_layers, batch_size, hidden_size)
            - generated_seq_len: The length of the sequence to generate.
                           Note that this can be different than the length used
                           for training (self.seq_len)
        Returns:
            - Sampled sequences of tokens
                        shape: (generated_seq_len, batch_size)
        """
        if input.is_cuda:
            device = input.get_device()
        else:
            device = torch.device("cpu")

        inputs = input.unsqueeze(0)                       # (1, batch_size)
        # Apply the embedding layer on the seed tokens
        embed_out = self.word_embeddings(inputs)          # (1, batch_size, emb_size)

        # Tensor to store the logits produced at each generated time step
        logits = torch.zeros(generated_seq_len, self.batch_size, self.vocab_size).to(device)
        samples = inputs

        # For each time step
        for timestep in range(generated_seq_len):
            hidden_states = []
            ip = embed_out[0]                             # (batch_size, emb_size)

            # For each layer of the stacked GRU
            for layer in range(self.num_layers):
                reset_val = torch.sigmoid(self.r[layer](torch.cat([ip, hidden[layer]], 1)))
                forget_val = torch.sigmoid(self.z[layer](torch.cat([ip, hidden[layer]], 1)))
                h_tilde = torch.tanh(self.h[layer](torch.cat((ip, reset_val * hidden[layer]), 1)))
                h_t = (1 - forget_val) * hidden[layer] + forget_val * h_tilde
                # Apply dropout on this layer, but not for the recurrent units
                ip = self.dropout(h_t)
                hidden_states.append(h_t)

            # Store the output of this time step and carry the hidden states forward
            logits[timestep] = self.out_layer(ip)
            hidden = torch.stack(hidden_states)

            # Sample the next token from the softmax distribution;
            # token index 1 (a special token in this setup) is never sampled
            outp = F.softmax(logits[timestep], dim=-1)
            outp[:, 1] = 0
            token_out = Categorical(probs=outp).sample()  # (batch_size)
            next_inp = token_out.view(1, -1)              # (1, batch_size)
            samples = torch.cat((samples, next_inp), dim=0)

            # Embed the sampled tokens as the input to the next time step
            embed_out = self.word_embeddings(next_inp)

        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return samples
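
The snippet above forbids token index 1 by zeroing its probability after the softmax (Categorical renormalizes the remaining mass). An equivalent sketch that masks the logit instead, with an assumed banned index:

import torch
from torch.distributions import Categorical

logits = torch.randn(4, 10)   # illustrative (batch_size, vocab_size)
banned_token = 1              # assumed index of a special token to exclude

masked_logits = logits.clone()
masked_logits[:, banned_token] = float("-inf")             # its probability becomes exactly zero
next_tokens = Categorical(logits=masked_logits).sample()   # (batch_size,)
assert (next_tokens != banned_token).all()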
Example #8
    def generate(self, inputs, hidden, generated_seq_len):
        """
        Generate a sample sequence from the RNN.

        This is similar to the forward method but instead of having ground
        truth input for each time step, you are now required to sample the token
        with maximum probability at each time step and feed it as input at the
        next time step.

        Arguments:
            - input: A mini-batch of input tokens (NOT sequences!)
                            shape: (batch_size)
            - hidden: The initial hidden states for every layer of the stacked RNN.
                            shape: (num_layers, batch_size, hidden_size)
            - generated_seq_len: The length of the sequence to generate.
                           Note that this can be different than the length used
                           for training (self.seq_len)
        Returns:
            - Sampled sequences of tokens
                        shape: (generated_seq_len, batch_size)
        """
        if inputs.is_cuda:
            device = inputs.get_device()
        else:
            device = torch.device("cpu")

        inputs = inputs.unsqueeze(0)                      # (1, batch_size)
        # Apply the embedding layer on the seed tokens
        embed_out = self.embeddings(inputs)               # (1, batch_size, emb_size)

        # Tensor to store the logits produced at each generated time step
        logits = torch.zeros(generated_seq_len, self.batch_size, self.vocab_size).to(device)
        samples = inputs

        # For each time step
        for timestep in range(generated_seq_len):
            input_ = embed_out[0]                         # (batch_size, emb_size)

            # For each layer of the stacked RNN
            for layer in range(self.num_layers):
                # Compute the new hidden state and apply the tanh activation
                hidden[layer] = torch.tanh(self.layers[layer](torch.cat([input_, hidden[layer]], 1)))
                # Apply dropout on this layer, but not for the recurrent units
                input_ = self.dropout(hidden[layer])

            # Store the output of this time step
            logits[timestep] = self.out_layer(input_)

            # Sample the next token from the softmax distribution
            outp = F.softmax(logits[timestep], dim=-1)
            token_out = Categorical(probs=outp).sample()  # (batch_size)
            next_inp = token_out.view(1, -1)              # (1, batch_size)
            samples = torch.cat((samples, next_inp), dim=0)

            # Embed the sampled tokens as the input to the next time step
            embed_out = self.embeddings(next_inp)

        # Row 0 is the seed input, followed by generated_seq_len sampled rows
        return samples
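
Both of the last two snippets apply `self.dropout` inside the sampling loop. That is harmless as long as generation runs under `model.eval()`, because `nn.Dropout` becomes the identity in eval mode; a quick demonstration:

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(2, 4)

dropout.train()
print(dropout(x))   # roughly half the entries zeroed, the rest scaled by 1 / (1 - p) = 2

dropout.eval()
print(dropout(x))   # identity, so dropout calls inside generate() have no effect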
Example #9
    def forward(
        self,
        enc,
        in_adj_phase,
        loc_idxs,
        all_cand_idxs,
        power_1h,
        temperature=1.0,
        top_p=1.0,
        teacher_force_orders=None,
    ):
        timings = TimingCtx()
        with timings("dec.prep"):
            device = next(self.parameters()).device

            if (loc_idxs == -1).all():
                return (
                    torch.empty(*all_cand_idxs.shape[:2],
                                dtype=torch.long,
                                device=device).fill_(EOS_IDX),
                    torch.empty(*all_cand_idxs.shape[:2],
                                dtype=torch.long,
                                device=device).fill_(EOS_IDX),
                    torch.zeros(*all_cand_idxs.shape, device=device),
                )

            # embedding for the last decoded order
            order_emb = torch.zeros(enc.shape[0],
                                    self.order_emb_size,
                                    device=device)

            # power embedding, constant for each lstm step
            assert len(
                power_1h.shape) == 2 and power_1h.shape[1] == 7, power_1h.shape
            power_emb = self.power_lin(power_1h)

            # return values: chosen order idxs, candidate idxs, and logits
            all_order_idxs = []
            all_sampled_idxs = []
            all_logits = []

            order_enc = torch.zeros(enc.shape[0],
                                    81,
                                    self.order_emb_size,
                                    device=enc.device)

            self.lstm.flatten_parameters()
            hidden = (
                torch.zeros(self.lstm_layers,
                            enc.shape[0],
                            self.lstm_size,
                            device=device),
                torch.zeros(self.lstm_layers,
                            enc.shape[0],
                            self.lstm_size,
                            device=device),
            )

            # reuse same dropout weights for all steps
            dropout_in = (torch.zeros(
                enc.shape[0],
                1,
                enc.shape[2] + self.order_emb_size + self.power_emb_size,
                device=enc.device,
            ).bernoulli_(1 - self.lstm_dropout).div_(
                1 - self.lstm_dropout).requires_grad_(False))
            dropout_out = (torch.zeros(
                enc.shape[0], 1, self.lstm_size,
                device=enc.device).bernoulli_(1 - self.lstm_dropout).div_(
                    1 - self.lstm_dropout).requires_grad_(False))

            # find max # of valid cand idxs per step
            max_cand_per_step = (all_cand_idxs != EOS_IDX).sum(dim=2).max(
                dim=0).values  # [S]

            if self.relfeat_output:
                src_relfeat_w = self.order_relfeat_src_decoder_w(enc)
                dst_relfeat_w = self.order_relfeat_dst_decoder_w(enc)

        for step in range(all_cand_idxs.shape[1]):
            with timings("dec.loc_enc"):
                num_cands = max_cand_per_step[step]
                cand_idxs = all_cand_idxs[:, step, :num_cands].long().contiguous()

                if self.avg_embedding:
                    # no attention: average across loc embeddings
                    loc_enc = torch.mean(enc, dim=1)
                else:
                    # do static attention; set alignments to:
                    # - master_alignments for the right loc_idx when not in_adj_phase
                    in_adj_phase = in_adj_phase.view(-1, 1)
                    alignments = compute_alignments(loc_idxs, step,
                                                    self.master_alignments)
                    # print('alignments', alignments.mean(), alignments.std())
                    loc_enc = torch.matmul(alignments.unsqueeze(1),
                                           enc).squeeze(1)

            with timings("dec.lstm"):
                input_list = [loc_enc, order_emb, power_emb]
                lstm_input = torch.cat(input_list, dim=1).unsqueeze(1)
                if self.training and self.lstm_dropout > 0.0:
                    lstm_input = lstm_input * dropout_in

                out, hidden = self.lstm(lstm_input, hidden)
                if self.training and self.lstm_dropout > 0.0:
                    out = out * dropout_out

                out = out.squeeze(1).unsqueeze(2)

            with timings("dec.cand_emb"):
                cand_emb = self.cand_embedding(cand_idxs)

            with timings("dec.logits"):
                logits = torch.matmul(cand_emb, out).squeeze(2)  # [B, <=469]

                if self.featurize_output:
                    # a) featurize based on one-hot features
                    cand_order_feats = self.order_feats[cand_idxs]
                    order_w = torch.cat(
                        (
                            self.order_decoder_w(cand_order_feats),
                            self.order_decoder_b(cand_order_feats),
                        ),
                        dim=-1,
                    )

                    if self.relfeat_output:
                        cand_srcs = self.order_srcs[cand_idxs]
                        cand_dsts = self.order_dsts[cand_idxs]

                        # b) featurize based on the src and dst encoder features
                        self.get_order_loc_feats(cand_srcs, src_relfeat_w,
                                                 order_w)
                        self.get_order_loc_feats(cand_dsts, dst_relfeat_w,
                                                 order_w)

                        # c) featurize based on the src and dst order embeddings
                        self.get_order_loc_feats(
                            cand_srcs,
                            order_enc,
                            order_w,
                            enc_lin=self.order_emb_relfeat_src_decoder_w,
                        )
                        self.get_order_loc_feats(
                            cand_dsts,
                            order_enc,
                            order_w,
                            enc_lin=self.order_emb_relfeat_dst_decoder_w,
                        )

                    # add some ones to out so that the last element of order_w is a bias
                    out_with_ones = torch.cat(
                        (out,
                         torch.ones((out.shape[0], 1, 1), device=out.device)),
                        dim=1)
                    order_scores_featurized = torch.bmm(order_w, out_with_ones)
                    logits += order_scores_featurized.squeeze(-1)

            with timings("dec.invalid_mask"):
                # unmask where there are no actions or the sampling will crash. The
                # losses at these points will be masked out later, so this is safe.
                invalid_mask = ~(cand_idxs != EOS_IDX).any(dim=1)
                if invalid_mask.all():
                    # early exit
                    logging.debug(
                        f"Breaking at step {step} because no more orders to give"
                    )
                    for _step in range(
                            step, all_cand_idxs.shape[1]):  # fill in garbage
                        all_order_idxs.append(
                            torch.empty(
                                all_cand_idxs.shape[0],
                                dtype=torch.long,
                                device=all_cand_idxs.device,
                            ).fill_(EOS_IDX))
                        all_sampled_idxs.append(
                            torch.empty(
                                all_cand_idxs.shape[0],
                                dtype=torch.long,
                                device=all_cand_idxs.device,
                            ).fill_(EOS_IDX))
                    break

                cand_mask = cand_idxs != EOS_IDX
                cand_mask[invalid_mask] = 1

            with timings("dec.logits_mask"):
                # make logits for invalid actions a large negative
                logits = torch.min(logits,
                                   cand_mask.float() * 1e9 + LOGIT_MASK_VAL)
                all_logits.append(logits)

            with timings("dec.logits_temp_top_p"):
                with torch.no_grad():
                    filtered_logits = logits.detach().clone()
                    top_p_min = top_p.min().item() if isinstance(
                        top_p, torch.Tensor) else top_p
                    if top_p_min < 0.999:
                        filtered_logits.masked_fill_(
                            top_p_filtering(filtered_logits, top_p=top_p),
                            -1e9)
                    filtered_logits /= temperature

            with timings("dec.sample"):
                sampled_idxs = Categorical(logits=filtered_logits).sample()
                all_sampled_idxs.append(sampled_idxs)

            with timings("dec.order_idxs"):
                order_idxs = torch.gather(cand_idxs, 1,
                                          sampled_idxs.view(-1, 1)).view(-1)
                all_order_idxs.append(order_idxs)

            with timings("dec.order_emb"):
                order_input = (teacher_force_orders[:, step]
                               if teacher_force_orders is not None
                               else order_idxs.masked_fill(
                                   order_idxs == EOS_IDX, 0))

                order_emb = self.order_embedding(order_input)
                if self.featurize_output:
                    order_emb += self.order_feat_lin(
                        self.order_feats[order_input])

                if self.relfeat_output:
                    order_enc = order_enc + order_emb[:, None] * alignments[:, :, None]

        with timings("dec.fin"):
            stacked_order_idxs = torch.stack(all_order_idxs, dim=1)
            stacked_sampled_idxs = torch.stack(all_sampled_idxs, dim=1)
            stacked_logits = cat_pad_sequences(
                [x.unsqueeze(1) for x in all_logits],
                seq_dim=2,
                cat_dim=1,
                pad_value=LOGIT_MASK_VAL,
            )
            r = stacked_order_idxs, stacked_sampled_idxs, stacked_logits

        logging.debug(f"Timings[dec, {enc.shape[0]}x{step}] {timings}")

        return r
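
The `top_p_filtering` helper called in `dec.logits_temp_top_p` is not shown in this excerpt. It appears to return a boolean mask of candidates to drop so that only the nucleus of tokens with cumulative probability up to `top_p` survives; a hedged sketch of such a mask (assuming a scalar `top_p`, whereas the code above may also pass a per-row tensor):

import torch
import torch.nn.functional as F

def top_p_mask(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Boolean mask of entries to remove so that only the smallest set of tokens
    whose cumulative probability exceeds top_p is kept (sketch only; the project's
    actual top_p_filtering may differ)."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    # Shift right so the first token crossing the threshold is still kept,
    # and never remove the highest-probability token.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    return remove.scatter(-1, sorted_idx, remove)

# Usage (mirroring the masked_fill_ call above):
# filtered_logits.masked_fill_(top_p_mask(filtered_logits, top_p=0.9), -1e9)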