예제 #1
0
    def _forward_encoder(self, seq, adjs):
        """Run the encoder stack, pool, and apply an optional bottleneck.

        Each of the ``self.n_layers`` layers does the following:
        - applies ``encoder_pre_{i}``;
        - splits the channel axis (dim 1) at ``x.shape[1] // 2``;
        - transforms only the second half with ``encoder_0_{i}(half, i, adjs)``;
        - concatenates the halves back along dim 1;
        - applies ``encoder_1_{i}(x, i, adjs)`` and ``encoder_post_{i}``.

        The result is max-reduced over dim 2, keeping the dim. If
        ``self.bottleneck_size > 0``, the following steps run next:
        - right-pad the last dim by ``padding_amount(x, 2048)``;
        - reshape with ``reshape_internal_dim(x, 1, 512)``;
        - pass through ``self.conv_in``.

        Args:
            seq: Input tensor; assumed (batch, channels, length) — TODO confirm.
            adjs: Adjacency data, passed unchanged (with the layer index) to
                ``encoder_0_{i}`` and ``encoder_1_{i}``.

        Returns:
            The encoded tensor.
        """
        x = seq

        for i in range(self.n_layers):
            x = getattr(self, f"encoder_pre_{i}")(x)
            half = x.shape[1] // 2
            x_seq_half = x[:, :half, :]
            # Only the second channel half goes through encoder_0_{i}.
            x_adj_half = getattr(self, f"encoder_0_{i}")(x[:, half:, :], i, adjs)
            x = torch.cat([x_seq_half, x_adj_half], 1)
            x = getattr(self, f"encoder_1_{i}")(x, i, adjs)
            x = getattr(self, f"encoder_post_{i}")(x)

        # torch.Tensor has no ``maxpool`` method. Pool explicitly with a max
        # over dim 2, the same reduction ``forward_bak`` uses.
        # NOTE(review): confirm this is the intended pooling for this path.
        x = x.max(2, keepdim=True)[0]

        if self.bottleneck_size > 0:
            pad_amount = padding_amount(x, 2048)  # 4 x 512
            if pad_amount:
                x = F.pad(x, (0, pad_amount))
            x = reshape_internal_dim(x, 1, 512)
            x = self.conv_in(x)

        return x
예제 #2
0
    def forward_bak(self, seq, adjs):
        """Encoder-only forward pass followed by max-pooling and ``linear_in``.

        Each of the ``self.n_layers`` layers does the following:
        - applies ``encoder_pre_{i}``;
        - splits the channel axis (dim 1) at ``x.shape[1] // 2``;
        - transforms the halves with ``encoder_seq_{i}`` and ``encoder_adj_{i}``
          respectively, each called as ``module(half, i, adjs)``;
        - concatenates them back along dim 1 and applies ``encoder_post_{i}``.

        After the loop, the result is max-reduced over dim 2. That pooled dim
        is squeezed away, the tensor is passed through ``self.linear_in``, and
        a trailing dim of size 1 is re-added.

        Args:
            seq: Input tensor; assumed (batch, channels, length) — TODO confirm.
            adjs: Adjacency data passed to the per-half encoder modules.

        Returns:
            ``self.linear_in`` applied to the pooled tensor, with a trailing
            size-1 dim appended.

        Raises:
            NotImplementedError: If ``self.bottleneck_size`` is non-zero.
        """
        x = seq

        for i in range(self.n_layers):
            x = getattr(self, f"encoder_pre_{i}")(x)
            half = x.shape[1] // 2
            x_seq = getattr(self, f"encoder_seq_{i}")(x[:, :half, :], i, adjs)
            x_adj = getattr(self, f"encoder_adj_{i}")(x[:, half:, :], i, adjs)
            x = torch.cat([x_seq, x_adj], 1)
            x = getattr(self, f"encoder_post_{i}")(x)

        if self.bottleneck_size != 0:
            # Bottleneck mode is not supported on this path.
            raise NotImplementedError

        x = x.max(2, keepdim=True)[0]
        # Squeeze only the pooled dim. A bare ``squeeze()`` would also drop
        # the batch dim when the batch size is 1 (or any other size-1 dim).
        x = self.linear_in(x.squeeze(2)).unsqueeze(-1)
        return x
예제 #3
0
    def forward(self, seq, adjs):
        """Encode ``seq``, optionally squeeze it through a bottleneck, decode it.

        In every encoder and decoder layer, the channel axis (dim 1) is split
        at ``x.shape[1] // 2`` into a "seq" half and an "adj" half. Each half
        goes through its own per-layer module, called as
        ``module(half, i, adjs)``. The halves are then concatenated back along
        dim 1 before the layer's ``*_post_{i}`` module runs.

        If ``self.bottleneck_size > 0``, the encoded tensor goes through these
        steps in order:
        - right-pad the last dim by ``padding_amount(x, 2048)``;
        - reshape with ``reshape_internal_dim(x, 1, 512)``;
        - pass through ``self.conv_in``;
        - assert that the element count is within a factor (0.9, 1.1] of
          ``seq.shape[2] / 64 * self.bottleneck_size``;
        - pass through ``self.conv_out``;
        - reshape back to the original channel count and strip the padding.

        Args:
            seq: Input tensor; assumed (batch, channels, length) — TODO confirm.
            adjs: Adjacency data forwarded unchanged to every per-half module.

        Returns:
            The decoded tensor.
        """
        x = seq

        # Encoder: layers 0, 1, ..., n_layers - 1.
        for i in range(self.n_layers):
            x = getattr(self, f"encoder_pre_{i}")(x)
            split = x.shape[1] // 2
            seq_half = x[:, :split, :]
            adj_half = x[:, split:, :]
            down_seq = getattr(self, f"encoder_downsample_seq_{i}")
            seq_half = down_seq(seq_half, i, adjs)
            down_adj = getattr(self, f"encoder_downsample_adj_{i}")
            adj_half = down_adj(adj_half, i, adjs)
            joined = torch.cat([seq_half, adj_half], 1)
            x = getattr(self, f"encoder_post_{i}")(joined)
            logger.debug(f"{i}, {x.shape}")

        # Optional bottleneck around conv_in / conv_out.
        if self.bottleneck_size > 0:
            n_pad = padding_amount(x, 2048)  # 4 x 512
            if n_pad:
                x = F.pad(x, (0, n_pad))
            channels = x.shape[1]
            x = self.conv_in(reshape_internal_dim(x, 1, 512))

            expected = seq.shape[2] / 64 * self.bottleneck_size
            assert 0.9 < np.prod(x.shape) / expected <= 1.1, (x.shape[1:], expected)

            x = reshape_internal_dim(self.conv_out(x), 1, channels)
            if n_pad:
                x = x[:, :, :-n_pad]

        # Decoder: layers n_layers - 1, ..., 1, 0.
        for i in reversed(range(self.n_layers)):
            split = x.shape[1] // 2
            seq_half = x[:, :split, :]
            adj_half = x[:, split:, :]
            up_seq = getattr(self, f"decoder_upsample_seq_{i}")
            seq_half = up_seq(seq_half, i, adjs)
            up_adj = getattr(self, f"decoder_upsample_adj_{i}")
            adj_half = up_adj(adj_half, i, adjs)
            joined = torch.cat([seq_half, adj_half], 1)
            x = getattr(self, f"decoder_post_{i}")(joined)
            logger.debug(f"{i}, {x.shape}")

        return x
    def forward(self, seq, adjs):
        """Encode ``seq``, optionally squeeze it through a bottleneck, decode it.

        Encoder layer ``i`` runs in ascending order and applies, in sequence:
        - ``encoder_pre_{i}(x)``;
        - ``encoder_downsample_{i}(x, i, adjs)``;
        - ``encoder_post_{i}(x)``.

        Decoder layer ``i`` runs in descending order and applies, in sequence:
        - ``decoder_pre_{i}(x)``;
        - ``decoder_upsample_{i}(x, i, adjs)``;
        - ``decoder_post_{i}(x)``.

        If ``self.bottleneck_size > 0``, the encoded tensor goes through these
        steps between the encoder and decoder:
        - right-pad the last dim by ``padding_amount(x, 2048)``;
        - reshape with ``reshape_internal_dim(x, 1, 512)``;
        - pass through ``self.conv_in``;
        - assert that the element count is within a factor (0.9, 1.1] of
          ``seq.shape[2] / 64 * self.bottleneck_size``;
        - pass through ``self.conv_out``;
        - reshape back to the original channel count and strip the padding.

        Args:
            seq: Input tensor; assumed (batch, channels, length) — TODO confirm.
            adjs: Adjacency data forwarded, with the layer index, to each
                downsample/upsample module.

        Returns:
            The decoded tensor.
        """
        x = seq

        # Encoder: layers 0, 1, ..., n_layers - 1.
        for i in range(self.n_layers):
            pre = getattr(self, f"encoder_pre_{i}")
            x = pre(x)
            down = getattr(self, f"encoder_downsample_{i}")
            x = down(x, i, adjs)
            post = getattr(self, f"encoder_post_{i}")
            x = post(x)

        # Optional bottleneck around conv_in / conv_out.
        if self.bottleneck_size > 0:
            n_pad = padding_amount(x, 2048)  # 4 x 512
            if n_pad:
                x = F.pad(x, (0, n_pad))
            channels = x.shape[1]
            x = self.conv_in(reshape_internal_dim(x, 1, 512))

            expected = seq.shape[2] / 64 * self.bottleneck_size
            assert 0.9 < np.prod(x.shape) / expected <= 1.1, (x.shape[1:], expected)

            x = reshape_internal_dim(self.conv_out(x), 1, channels)
            if n_pad:
                x = x[:, :, :-n_pad]

        # Decoder: layers n_layers - 1, ..., 1, 0.
        for i in reversed(range(self.n_layers)):
            pre = getattr(self, f"decoder_pre_{i}")
            x = pre(x)
            up = getattr(self, f"decoder_upsample_{i}")
            x = up(x, i, adjs)
            post = getattr(self, f"decoder_post_{i}")
            x = post(x)

        return x
    def forward(self, seq, adjs):
        """Encode, optionally bottleneck, and decode ``seq`` segment by segment.

        The last dim of the running tensor is treated as the concatenation of
        one segment per entry of ``adjs``. Each segment is transformed
        separately, and the results are concatenated back along dim 2. The
        layer's ``*_post_{i}`` module is then applied to the whole tensor.

        * Encoder layer ``i`` (ascending): the segment for ``adj`` spans
          ``adj[i].shape[1]`` positions. ``encoder_{i}(segment,
          adj[i].shape[1])`` must return ``adj[i + 1].shape[1]`` positions.
        * Decoder layer ``i`` (descending): the segment spans
          ``adj[i + 1].shape[1]`` positions and is passed to
          ``decoder_{i}(segment, adj[i].shape[1])``.

        The segments must exactly cover dim 2. This, and the encoder output
        lengths, are checked with ``assert``.

        If ``self.bottleneck_size > 0``, the encoded tensor goes through these
        steps between the encoder and decoder:
        - right-pad the last dim by ``padding_amount(x, 2048)``;
        - reshape with ``reshape_internal_dim(x, 1, 512)``;
        - pass through ``self.conv_in``;
        - assert that the element count is within a factor (0.9, 1.1] of
          ``seq.shape[2] / 64 * self.bottleneck_size``;
        - pass through ``self.conv_out``;
        - reshape back to the original channel count and strip the padding.

        Args:
            seq: Input tensor; assumed (batch, channels, length) — TODO confirm.
            adjs: Iterable of per-segment objects indexable by layer depth.
                ``adj[k].shape[1]`` is read as the segment length at depth
                ``k`` (presumably adjacency matrices — confirm with callers).

        Returns:
            The decoded tensor.
        """
        x = seq

        # Encoder: layers 0, 1, ..., n_layers - 1.
        for i in range(self.n_layers):
            pieces = []
            offset = 0
            for adj in adjs:
                width = adj[i].shape[1]
                stop = offset + width
                assert stop <= x.shape[2]
                segment = x[:, :, offset:stop]
                encoded = getattr(self, f"encoder_{i}")(segment, width)
                assert encoded.shape[2] == adj[i + 1].shape[1]
                pieces.append(encoded)
                offset = stop
            assert stop == x.shape[2]
            joined = torch.cat(pieces, 2)
            x = getattr(self, f"encoder_post_{i}")(joined)
            logger.debug(f"{i}, {x.shape}")

        # Optional bottleneck around conv_in / conv_out.
        if self.bottleneck_size > 0:
            n_pad = padding_amount(x, 2048)  # 4 x 512
            if n_pad:
                x = F.pad(x, (0, n_pad))
            channels = x.shape[1]
            x = self.conv_in(reshape_internal_dim(x, 1, 512))

            expected = seq.shape[2] / 64 * self.bottleneck_size
            assert 0.9 < np.prod(x.shape) / expected <= 1.1, (x.shape[1:], expected)

            x = reshape_internal_dim(self.conv_out(x), 1, channels)
            if n_pad:
                x = x[:, :, :-n_pad]

        # Decoder: layers n_layers - 1, ..., 1, 0.
        for i in reversed(range(self.n_layers)):
            pieces = []
            offset = 0
            for adj in adjs:
                full_len = adj[i].shape[1]
                stop = offset + adj[i + 1].shape[1]
                assert stop <= x.shape[2]
                segment = x[:, :, offset:stop]
                pieces.append(getattr(self, f"decoder_{i}")(segment, full_len))
                offset = stop
            assert stop == x.shape[2]
            joined = torch.cat(pieces, 2)
            x = getattr(self, f"decoder_post_{i}")(joined)
            logger.debug(f"{i}, {x.shape}")

        return x