Example #1
    def __init__(self, d_v):
        """Constructor

        :param d_v: Dimensionality of the positional encoding
        """
        super(PositionalEncoding, self).__init__()

        self.d_v = d_v

        # Denominators 10000^(2i / d_v) for i = 0 .. d_v/2 - 1, as in the
        # sinusoidal positional encoding of Vaswani et al. (2017)
        self.denoms = as_variable(
            torch.pow(10000, 2 * torch.arange(0, d_v / 2) / d_v))
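For context, a minimal sketch of how denominators of this form are typically used to build the sinusoidal encoding of Vaswani et al. (2017). The interleaving and variable names below are assumptions for illustration, not part of the class above:

import torch

# Hedged sketch: build one sinusoidal encoding vector for time step t.
d_v = 8
denoms = torch.pow(10000, 2 * torch.arange(0, d_v // 2).float() / d_v)  # d_v/2 denominators

t = 3.0
enc = torch.empty(d_v)
enc[0::2] = torch.sin(t / denoms)  # even indices: sine terms
enc[1::2] = torch.cos(t / denoms)  # odd indices: cosine terms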
Example #2
    def forward(self, v, seq_mask=None):
        """Forward method

        :param v: Value tensor [B x T_i x d_v x H x W FloatTensor]
        :param seq_mask: The mask for items in the sequence that can be attended to [B x T_i ByteTensor]
        :return: B x T_i x d_v x H x W
        """

        B, T_i, _, H, W = v.shape
        use_cuda = module_is_cuda(self)

        # Prepare full mask over inputs
        if seq_mask is None:
            seq_mask = torch.ones(B, T_i)
            if use_cuda:
                seq_mask = seq_mask.cuda()
            seq_mask = as_variable(seq_mask)

        # Compute masked v
        expanded_seq_mask = seq_mask.contiguous().view(B, T_i, 1, 1, 1).expand(
            B, T_i, self.d_v, H, W)
        masked_v = expanded_seq_mask * v

        # Compute the query/key vectors by average-pooling over the spatial dims
        q = masked_v.mean(-1).mean(-1)  # B x T_i x d_v
        # Pass block input through multi-head attention (masked values will be
        # ignored due to prod_mask)
        prod_mask = seq_mask.view(B, 1, T_i).expand(B, T_i, T_i)
        mha_output = self.mha_module(
            v, q, q, prod_mask=prod_mask)  # B x T_i x d_v x H x W
        # Apply skip connection and normalization
        ff_input = self.apply_batch_norm(masked_v + mha_output)

        # Pass output through feed-forward module
        ff_input_combined_B_T_i = ff_input.contiguous().view(
            -1, self.d_v, H, W)  # B*T_i x d_v x H x W
        ff_output_combined_B_T_i = self.ff_module(
            ff_input_combined_B_T_i)  # B*T_i x d_v x H x W
        ff_output = ff_output_combined_B_T_i.view(B, T_i, self.d_v, H,
                                                  W)  # B x T_i x d_v x H x W
        # Apply skip connection and normalization
        block_output = self.apply_batch_norm(ff_input + ff_output)

        return block_output
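A hedged usage sketch for this encoder block, showing how a seq_mask might be built for a padded batch of variable-length sequences. The class name EncoderBlock and its constructor are assumptions; only the forward signature above comes from the snippet:

import torch

B, T_i, d_v, H, W = 2, 5, 16, 8, 8
v = torch.randn(B, T_i, d_v, H, W)

# Per-item sequence lengths; positions at or past the length are padding.
lengths = torch.tensor([5, 3])
seq_mask = (torch.arange(T_i).view(1, T_i) < lengths.view(B, 1)).float()  # B x T_i

# block = EncoderBlock(d_v=d_v)        # assumed constructor
# out = block(v, seq_mask=seq_mask)    # B x T_i x d_v x H x W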
Example #3
    def forward_train(self, preceding_frames, middle_frames, following_frames):
        """Forward method used during training. This has access to the middle frames, so it can do a single forward
        pass to compute all next frames.

        :param preceding_frames: The frames before the sequence to predict (B x K x C x H x W)
        :param middle_frames: The frames to predict (B x T x C x H x W)
        :param following_frames: The frames after the sequence to predict (B x F x C x H x W)
        :return: Dict with key 'pred' holding the predicted middle frames (B x T x C x H x W)
        """

        B, K, _, H, W = preceding_frames.shape
        T = middle_frames.shape[1]
        F = following_frames.shape[1]
        use_cuda = module_is_cuda(self)

        # Create input mask
        encoder_input_mask = as_variable(torch.ones(B, K + F))
        if use_cuda:
            encoder_input_mask = encoder_input_mask.cuda()

        # Create input time steps [B x K+F]
        encoder_time_input = torch.cat(
            [torch.arange(0, K),
             torch.arange(K + T, K + T + F)]).view(1, K + F).expand(B, K + F)
        encoder_time_input = as_variable(encoder_time_input)
        if use_cuda:
            encoder_time_input = encoder_time_input.cuda()

        # Combine preceding and following frame sequences
        input_frames = torch.cat([preceding_frames, following_frames],
                                 dim=1)  # B x K+F x C x H x W
        # Encode the input frames
        encoder_input_reps = self.forward_frame_encoder(input_frames)
        encoder_input = encoder_input_reps[-1]
        encoder_outputs = self.encoder(
            encoder_input, encoder_input_mask,
            encoder_time_input)  # B x K+F x d_v x H x W

        # Create start token
        start_token = torch.zeros(B, 1, self.C, H, W)
        start_token = as_variable(start_token)
        if use_cuda:
            start_token = start_token.cuda()

        # Encode the decoder inputs: the start token followed by all but the last middle frame (teacher forcing)
        frame_encoder_input = torch.cat(
            [start_token, middle_frames[:, :-1, :, :, :]], dim=1)
        dec_input_frame_reps = self.forward_frame_encoder(frame_encoder_input)

        # Create decoder time steps [B x T]
        dec_time_input = torch.arange(K, K + T).view(1, T).expand(B, T)
        dec_time_input = as_variable(dec_time_input)
        if use_cuda:
            dec_time_input = dec_time_input.cuda()

        # Create the causal (lower-triangular) decoder product mask
        dec_prod_mask = torch.tril(torch.ones(T, T)).view(1, T,
                                                          T).expand(B, T, T)
        dec_prod_mask = as_variable(dec_prod_mask)
        if use_cuda:
            dec_prod_mask = dec_prod_mask.cuda()

        # Pass information through the decoder
        decoder_output = self.decoder(encoder_outputs, encoder_input_mask,
                                      dec_input_frame_reps[-1], dec_time_input,
                                      dec_prod_mask)
        # Pass self-attention decoder outputs through image decoder
        output_reps = self.forward_frame_decoder(decoder_output,
                                                 dec_input_frame_reps)

        return {'pred': output_reps[-1]}
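The torch.tril call above is what makes teacher forcing safe: output step i may only attend to decoder inputs 0..i. A quick illustration for T = 4:

import torch

T = 4
print(torch.tril(torch.ones(T, T)))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# Row i is the mask for output step i: it can see inputs 0..i only.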
Example #4
    def forward(self, T, preceding_frames, following_frames):
        """Forward method

        :param T: The number of new frames to generate
        :param preceding_frames: The frames before the sequence to predict (B x K x C x H x W)
        :param following_frames: The frames after the sequence to predict (B x F x C x H x W)
        :return: Dict with key 'pred' holding the T generated frames (B x T x C x H x W)
        """

        B, K, _, H, W = preceding_frames.shape
        F = following_frames.shape[1]
        use_cuda = module_is_cuda(self)

        # Create input mask
        encoder_input_mask = as_variable(torch.ones(B, K + F))
        if use_cuda:
            encoder_input_mask = encoder_input_mask.cuda()

        # Create input time steps [B x K+F]
        encoder_time_input = torch.cat([torch.arange(0, K), torch.arange(K + T, K + T + F)]) \
            .view(1, K + F).expand(B, K + F)
        encoder_time_input = as_variable(encoder_time_input)
        if use_cuda:
            encoder_time_input = encoder_time_input.cuda()

        # Combine preceding and following frame sequences
        input_frames = torch.cat([preceding_frames, following_frames],
                                 dim=1)  # B x K+F x C x H x W
        # Encode the input frames [B x K+F x d_v x H x W]
        encoder_input_reps = self.forward_frame_encoder(input_frames)
        encoder_outputs = self.encoder(encoder_input_reps[-1],
                                       encoder_input_mask, encoder_time_input)

        # Create start token
        start_token = torch.zeros(B, 1, self.C, H, W)
        start_token = as_variable(start_token)
        if use_cuda:
            start_token = start_token.cuda()

        # Encode the decoder input (only the start token is available at inference time)
        dec_input_frame_reps = self.forward_frame_encoder(start_token)

        # Create decoder time steps [B x T]
        dec_time_input_full = torch.arange(K, K + T).view(1, T).expand(B, T)
        dec_time_input_full = as_variable(dec_time_input_full)
        if use_cuda:
            dec_time_input_full = dec_time_input_full.cuda()

        # Create the causal (lower-triangular) decoder product mask
        dec_prod_mask_full = torch.tril(torch.ones(T,
                                                   T)).view(1, T,
                                                            T).expand(B, T, T)
        dec_prod_mask_full = as_variable(dec_prod_mask_full)
        if use_cuda:
            dec_prod_mask_full = dec_prod_mask_full.cuda()

        # Pass information through the decoder
        decoder_output = self.decoder(encoder_outputs, encoder_input_mask,
                                      dec_input_frame_reps[-1],
                                      dec_time_input_full, dec_prod_mask_full)

        # Pass self-attention decoder outputs through image decoder
        output_reps = self.forward_frame_decoder(decoder_output,
                                                 dec_input_frame_reps)

        return {'pred': output_reps[-1]}
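A hedged usage sketch for this inference path. The class name VideoTransformer and its constructor are assumptions; the call signature follows the forward method above:

import torch

B, K, T, F, C, H, W = 1, 4, 3, 4, 3, 64, 64
preceding = torch.randn(B, K, C, H, W)    # K known frames before the gap
following = torch.randn(B, F, C, H, W)    # F known frames after the gap

# model = VideoTransformer(...)           # assumed constructor
# out = model(T, preceding, following)    # generate T middle frames
# out['pred'].shape                       # expected: B x T x C x H x W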
Example #5
    def forward(self,
                q_dec,
                kv_dec,
                kv_enc,
                enc_seq_mask=None,
                prod_mask=None):
        """Forward method

        Note: q_dec and kv_dec are separate because you sometimes have to decode a different number of steps from
        what you computed so far. A concrete example is decoding one step at a time (you have to attend to all
        previously-generated outputs [T_o_old], but you're only producing one output at a time [T_o_new]).

        :param q_dec: The term used to compute Q in the decoder input MHA module [B x T_o_new x d_v x H x W]
        :param kv_dec: The term used for both K and V in the decoder input MHA module [B x T_o_old x d_v x H x W]
        :param kv_enc: The term used for both K and V in the combined encoder-decoder MHA module [B x T_i x d_v x H x W]
        :param enc_seq_mask: The mask for items in the encoder output that can be attended to [B x T_i ByteTensor]
        :param prod_mask: Binary mask of values that can be attended to in the decoder input
                          [B x T_o_new x T_o_old ByteTensor]. If entry (b, i, j) is 0, then the i'th output cannot
                          depend on the j'th decoder input for the b'th item.
        :return: B x T_o_new x d_v x H x W
        """

        B, T_o_new, _, H, W = q_dec.shape
        _, T_i, _, _, _ = kv_enc.shape
        use_cuda = module_is_cuda(self)

        # Prepare sequence mask if none is provided
        if enc_seq_mask is None:
            enc_seq_mask = torch.ones(B, T_i)
            if use_cuda:
                enc_seq_mask = enc_seq_mask.cuda()
            enc_seq_mask = as_variable(enc_seq_mask)
        # Expand the encoder sequence mask into a product mask
        enc_prod_mask = enc_seq_mask.view(B, 1, T_i).expand(B, T_o_new, T_i)

        # Apply MHA module on decoder inputs
        q_dec_vec = q_dec.mean(-1).mean(-1)  # B x T_o_new x d_v
        kv_dec_vec = kv_dec.mean(-1).mean(-1)  # B x T_o_old x d_v
        # B x T_o_new x d_v x H x W
        dec_only_mha_output = self.dec_only_mha_module.forward(
            kv_dec, kv_dec_vec, q_dec_vec, prod_mask=prod_mask)
        # Apply skip connection and normalization
        comb_enc_dec_mha_input = self.apply_batch_norm(
            dec_only_mha_output + q_dec)  # B x T_o_new x d_v x H x W

        # Apply MHA module to combine decoder and encoder information
        kv_enc_vec = kv_enc.mean(-1).mean(-1)  # B x T_i x d_v
        comb_enc_dec_mha_input_vec = comb_enc_dec_mha_input.mean(-1).mean(
            -1)  # B x T_o_new x d_v
        comb_enc_dec_mha_output = self.comb_enc_dec_mha_module(
            kv_enc,
            kv_enc_vec,
            comb_enc_dec_mha_input_vec,
            prod_mask=enc_prod_mask)
        ff_input = self.apply_batch_norm(comb_enc_dec_mha_output +
                                         comb_enc_dec_mha_input)

        # Pass output through feed-forward module
        ff_input_combined_B_T_o = ff_input.contiguous().view(
            -1, self.d_v, H, W)  # B*T_o_new x d_v x H x W
        ff_output_combined_B_T_o = self.ff_module(
            ff_input_combined_B_T_o)  # B*T_o_new x d_v x H x W
        ff_output = ff_output_combined_B_T_o.view(
            B, T_o_new, self.d_v, H, W)  # B x T_o_new x d_v x H x W
        # Apply skip connection and normalization
        block_output = self.apply_batch_norm(ff_input + ff_output)

        return block_output
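To make the docstring's note about q_dec vs. kv_dec concrete, here is a hedged sketch of one-step-at-a-time decoding; dec_block and its construction are assumptions:

import torch

B, d_v, H, W = 1, 16, 8, 8
kv_enc = torch.randn(B, 6, d_v, H, W)     # encoder outputs, T_i = 6

# Suppose 4 decoder steps exist so far and we want to produce step 5:
kv_dec = torch.randn(B, 4, d_v, H, W)     # T_o_old = 4: all outputs so far
q_dec = kv_dec[:, -1:]                    # T_o_new = 1: query only the newest step
# out = dec_block(q_dec, kv_dec, kv_enc)  # assumed call; out: B x 1 x d_v x H x W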