Exemple #1
0
    def forward(self, src, src2, output_n=25, input_n=50, itera=1):
        """Run the encoder/decoder + GCN predictor on a motion sequence.

        The first pass sends ``src`` through the encoder and the decoder,
        adds the padded GCN input to the result, moves both into DCT space,
        refines them with the GCN and maps the result back with the inverse
        DCT.  Every extra pass (``itera > 1``) appends the last ``output_n``
        predicted frames to the running input window and predicts again
        through the encoder-only path.

        :param src: [batch_size,seq_len,feat_dim]
        :param src2: not used by this method; kept so existing callers work
        :param output_n: number of frames predicted per pass
        :param input_n: forwarded to ``src_reformat``; also the length of the
            sliding input window used by the extra passes
        :param itera: number of prediction passes
        :return: concatenation along dim 1 of the first pass's full GCN
            output and the last ``output_n`` frames of each later pass, each
            with a singleton axis inserted at dim 2
        """
        subseq_length = self.kernel_size + output_n

        dct_m, idct_m = util.get_dct_matrix(subseq_length)
        dct_m = torch.from_numpy(dct_m).float().to(device=self.device)
        idct_m = torch.from_numpy(idct_m).float().to(device=self.device)
        # Only the first dct_n coefficients are used; the leading axis lets
        # the matrices broadcast over the batch in torch.matmul.
        dct_basis = dct_m[:self.dct_n].unsqueeze(dim=0)
        idct_basis = idct_m[:, :self.dct_n].unsqueeze(dim=0)

        def gcn_refine(att_out, input_gcn):
            # Residual add, DCT of both streams, GCN, then inverse DCT of the
            # first dct_n output channels.
            att_out = att_out + input_gcn
            att_dct = torch.matmul(dct_basis, att_out).transpose(1, 2)
            in_dct = torch.matmul(dct_basis, input_gcn).transpose(1, 2)
            gcn_out = self.gcn(torch.cat([in_dct, att_dct], dim=-1))
            return torch.matmul(
                idct_basis, gcn_out[:, :, :self.dct_n].transpose(1, 2))

        src, _ = src_reformat(
            src, input_n, substract_last_pose=False)  # [32, 50, 66]

        encode_in = src.clone()
        decode_in = src.clone()
        encode_in = self.init_lin_encode(encode_in)
        encode_in = self.pos_encoding(encode_in)
        encode_out = self.encoder(
            encode_in,
            mask=None)  # in : [32, 50, d_model] out : [32, 50, d_model]

        decode_in = self.init_lin_encode(decode_in)
        decode_in = self.pos_encoding(decode_in)
        # NOTE(review): `self` is passed explicitly as the first argument,
        # unlike the `self.encoder(...)` call above. If `self.decoder` is a
        # regular submodule this looks like a mistake -- confirm against the
        # decoder's signature before changing it.
        decode_out = self.decoder(self,
                                  decode_in,
                                  encode_out,
                                  src_mask=None,
                                  tgt_mask=None)

        att_out = self.inter_lin_decode(decode_out)
        att_out = self.res_conv_block(att_out.transpose(1, 2)).transpose(1, 2)

        idx = list(range(-self.kernel_size, 0, 1)) + [
            -1
        ] * output_n  # [-10, -9, ..., -1, -1 ..., -1], indexes used for GCN input

        out_gcn = gcn_refine(att_out, src[:, idx])
        outputs = [out_gcn.unsqueeze(2)]

        # Seed the feedback window with the observed sequence. The original
        # code read `new_in` before assigning it, so itera > 1 always raised.
        new_in = src
        for _ in range(itera - 1):
            # Slide the window: drop the oldest frames, append the newest
            # output_n predicted frames.
            new_in = torch.cat(
                [new_in[:, -input_n + output_n:], out_gcn[:, -output_n:]],
                dim=1)
            # NOTE(review): this path uses `init_lin` / `inter_lin` and skips
            # the decoder, whereas the first pass uses `init_lin_encode` /
            # `inter_lin_decode` -- confirm these attributes exist and that
            # the asymmetry is intended.
            att_in = self.init_lin(new_in)
            att_in = self.pos_encoding(att_in)
            att_out = self.encoder(
                att_in, mask=None
            )  # in : [32, 50, d_model] out : [32, 50, d_model]

            att_out = self.inter_lin(att_out)
            att_out = self.res_conv_block(att_out.transpose(1, 2)).transpose(
                1, 2)

            out_gcn = gcn_refine(att_out, new_in[:, idx])
            outputs.append(out_gcn[:, -output_n:].unsqueeze(2))

        return torch.cat(outputs, dim=1)
Exemple #2
0
    def forward(self, src, output_n=25, input_n=50, itera=1):
        """Predict future frames with motion attention followed by a GCN.

        Keys are built from the first ``input_n - output_n`` frames, the
        query from the last ``kernel_size`` frames, and the values are the
        DCT coefficients of every sliding window of length
        ``kernel_size + output_n``.  The attended DCT features are
        concatenated with the DCT of the padded last frames and refined by
        the GCN.  When ``itera > 1`` the prediction is appended to the
        sequence and keys, values and query are extended for the next pass.

        :param src: [batch_size,seq_len,feat_dim]; only the first
            ``input_n`` frames are used
        :param output_n: number of frames predicted per iteration
        :param input_n: number of observed frames
        :param itera: number of prediction iterations
        :return: the GCN outputs of all iterations, each given a new axis at
            dim 2 and concatenated along that axis
        """
        dct_n = self.dct_n
        src = src[:, :input_n]  # [bs,in_n,dim]
        src_tmp = src.clone()
        bs = src.shape[0]
        src_key_tmp = src_tmp.transpose(1, 2)[:, :, :(input_n - output_n)].clone()
        src_query_tmp = src_tmp.transpose(1, 2)[:, :, -self.kernel_size:].clone()

        dct_m, idct_m = util.get_dct_matrix(self.kernel_size + output_n)
        # Follow the input's device instead of hard-coding .cuda(), so the
        # method also runs on CPU and on a non-default GPU.
        dct_m = torch.from_numpy(dct_m).float().to(src.device)
        idct_m = torch.from_numpy(idct_m).float().to(src.device)
        # Truncated bases with a leading axis for batch broadcasting.
        dct_basis = dct_m[:dct_n].unsqueeze(dim=0)
        idct_basis = idct_m[:, :dct_n].unsqueeze(dim=0)

        # vn sliding windows of length vl over the observed frames.
        vn = input_n - self.kernel_size - output_n + 1
        vl = self.kernel_size + output_n
        idx = np.expand_dims(np.arange(vl), axis=0) + \
              np.expand_dims(np.arange(vn), axis=1)
        src_value_tmp = src_tmp[:, idx].clone().reshape(
            [bs * vn, vl, -1])
        src_value_tmp = torch.matmul(dct_basis, src_value_tmp).reshape(
            [bs, vn, dct_n, -1]).transpose(2, 3).reshape(
            [bs, vn, -1])  # [32,40,66*11]

        # Last kernel_size frames followed by the final frame repeated
        # output_n times.
        idx = list(range(-self.kernel_size, 0, 1)) + [-1] * output_n
        outputs = []

        key_tmp = self.convK(src_key_tmp / 1000.0)
        for i in range(itera):
            query_tmp = self.convQ(src_query_tmp / 1000.0)
            # Small offset added to the scores before they are normalised by
            # their sum over the key axis.
            score_tmp = torch.matmul(query_tmp.transpose(1, 2), key_tmp) + 1e-15
            att_tmp = score_tmp / (torch.sum(score_tmp, dim=2, keepdim=True))
            dct_att_tmp = torch.matmul(att_tmp, src_value_tmp)[:, 0].reshape(
                [bs, -1, dct_n])

            input_gcn = src_tmp[:, idx]
            dct_in_tmp = torch.matmul(dct_basis, input_gcn).transpose(1, 2)
            dct_in_tmp = torch.cat([dct_in_tmp, dct_att_tmp], dim=-1)
            dct_out_tmp = self.gcn(dct_in_tmp)
            out_gcn = torch.matmul(idct_basis,
                                   dct_out_tmp[:, :, :dct_n].transpose(1, 2))
            outputs.append(out_gcn.unsqueeze(2))
            if itera > 1:
                # update key-value query
                out_tmp = out_gcn.clone()[:, 0 - output_n:]
                src_tmp = torch.cat([src_tmp, out_tmp], dim=1)

                # Negative (end-relative) start offsets of the kernel_size
                # new windows that end inside the extended sequence.
                vn = 1 - 2 * self.kernel_size - output_n
                vl = self.kernel_size + output_n
                idx_dct = np.expand_dims(np.arange(vl), axis=0) + \
                          np.expand_dims(np.arange(vn, -self.kernel_size - output_n + 1), axis=1)

                # New keys come from the first new window without its last
                # frame.
                src_key_tmp = src_tmp[:, idx_dct[0, :-1]].transpose(1, 2)
                key_new = self.convK(src_key_tmp / 1000.0)
                key_tmp = torch.cat([key_tmp, key_new], dim=2)

                src_dct_tmp = src_tmp[:, idx_dct].clone().reshape(
                    [bs * self.kernel_size, vl, -1])
                src_dct_tmp = torch.matmul(dct_basis, src_dct_tmp).reshape(
                    [bs, self.kernel_size, dct_n, -1]).transpose(2, 3).reshape(
                    [bs, self.kernel_size, -1])
                src_value_tmp = torch.cat([src_value_tmp, src_dct_tmp], dim=1)

                src_query_tmp = src_tmp[:, -self.kernel_size:].transpose(1, 2)

        outputs = torch.cat(outputs, dim=2)
        return outputs
Exemple #3
0
    def forward(self, src, output_n=25, input_n=50, itera=1):
        """Run the encoder + GCN predictor with a velocity-extrapolated input.

        The first pass encodes ``src``, adds the last
        ``kernel_size + output_n`` input frames as a residual, and feeds the
        DCT of that together with the DCT of the padded GCN input (whose
        padded part is shifted by a velocity-based offset) to the GCN.  Every
        extra pass (``itera > 1``) appends the last ``output_n`` predicted
        frames to the running input window and predicts again.

        :param src: [batch_size,seq_len,feat_dim]
        :param output_n: number of frames predicted per pass
        :param input_n: forwarded to ``src_reformat``; also the length of the
            sliding input window used by the extra passes
        :param itera: number of prediction passes
        :return: concatenation along dim 1 of the first pass's full GCN
            output and the last ``output_n`` frames of each later pass, each
            with a singleton axis inserted at dim 2
        """
        subseq_length = self.kernel_size + output_n

        dct_m, idct_m = util.get_dct_matrix(subseq_length)
        dct_m = torch.from_numpy(dct_m).float().to(device=self.device)
        idct_m = torch.from_numpy(idct_m).float().to(device=self.device)
        # Only the first dct_n coefficients are used; the leading axis lets
        # the matrices broadcast over the batch in torch.matmul.
        dct_basis = dct_m[:self.dct_n].unsqueeze(dim=0)
        idct_basis = idct_m[:, :self.dct_n].unsqueeze(dim=0)

        def gcn_refine(att_out, input_gcn):
            # DCT of both streams, GCN, then inverse DCT of the first dct_n
            # output channels.
            att_dct = torch.matmul(dct_basis, att_out).transpose(1, 2)
            in_dct = torch.matmul(dct_basis, input_gcn).transpose(1, 2)
            gcn_out = self.gcn(torch.cat([in_dct, att_dct], dim=-1))
            return torch.matmul(
                idct_basis, gcn_out[:, :, :self.dct_n].transpose(1, 2))

        src, _ = src_reformat(
            src, input_n, substract_last_pose=False)  # [32, 50, 66]

        # Time ramp 2 - exp(-t) for t = 0 .. output_n - 1, one row per
        # predicted frame. The ramp length must equal output_n for the expand
        # below to succeed; it was previously hard-coded to 10.
        dt = torch.from_numpy(2.0 - np.exp(-np.arange(output_n))).to(
            device=self.device).expand(self.in_features,
                                       output_n).transpose(0, 1)
        vel = src[:, -1] - src[:, -2]  # last-frame finite difference
        dx = vel.unsqueeze(1) * dt

        new_in = src.clone()
        att_in = src.clone()
        att_in = self.pos_encoding(att_in)
        att_in = self.init_lin(att_in)
        att_out = self.encoder(
            att_in,
            mask=None)  # in : [32, 50, d_model] out : [32, 50, d_model]
        att_out = self.inter_conv(att_out.transpose(1, 2)).transpose(1, 2)
        att_out = att_out + new_in[:, -self.kernel_size - output_n:]

        idx = list(range(-self.kernel_size, 0, 1)) + [
            -1
        ] * output_n  # [-10, -9, ..., -1, -1 ..., -1], indexes used for GCN input
        # Indexing with a list copies, so the in-place update below does not
        # touch `src`.
        input_gcn = src[:, idx]
        input_gcn[:, self.kernel_size:] = (
            input_gcn[:, self.kernel_size:] + dx)  # add velocity

        out_gcn = gcn_refine(att_out, input_gcn)
        outputs = [out_gcn.unsqueeze(2)]

        for _ in range(itera - 1):
            # Slide the window: drop the oldest frames, append the newest
            # output_n predicted frames.
            new_in = torch.cat(
                [new_in[:, -input_n + output_n:], out_gcn[:, -output_n:]],
                dim=1)
            # NOTE(review): unlike the first pass, no positional encoding and
            # no velocity offset are applied here -- confirm this is intended.
            att_in = self.init_lin(new_in)
            att_out = self.encoder(
                att_in, mask=None
            )  # in : [32, 50, d_model] out : [32, 50, d_model]
            att_out = self.inter_conv(att_out.transpose(1, 2)).transpose(
                1, 2)
            att_out = att_out + new_in[:, -self.kernel_size - output_n:]

            out_gcn = gcn_refine(att_out, new_in[:, idx])
            outputs.append(out_gcn[:, -output_n:].unsqueeze(2))

        return torch.cat(outputs, dim=1)