Ejemplo n.º 1
0
    def test_whc(self):
        """Check the first sequence steps produced by the 'whc' ordering.

        A single 3x2x2 image with one feature, filled with 0..11, is converted
        to a sequence.
        """
        image = torch.arange(12).float().reshape([1, 3, 2, 2, 1])
        seq = Image2Seq('whc', (3, 2, 2))(image)

        # Source (channel, row, col) expected at each of the first five
        # steps: width varies fastest, then height, then channel.
        expected_positions = [
            (0, 0, 0),
            (0, 0, 1),
            (0, 1, 0),
            (0, 1, 1),
            (1, 0, 0),
        ]
        for step, (c, h, w) in enumerate(expected_positions):
            self.assertEqual(seq[step, :, :], image[:, c, h, w, :])
Ejemplo n.º 2
0
    def test_layer_is_well_behaved(self):
        """Run the generic layer checks on both converters for every order."""
        image_shape = (3, 4, 4)
        channels, height, width = image_shape
        seq_len = channels * height * width
        batch_size, features = 10, 6

        # Random inputs are drawn sequence-first, then image, once for all orders.
        x_seq = torch.randn(seq_len, batch_size, features)
        x_image = torch.randn(batch_size, *image_shape, features)

        for ar_order in self.ar_orders:
            # Image -> sequence converter receives an image-shaped batch.
            image_to_seq = Image2Seq(ar_order, image_shape)
            self.assert_layer_is_well_behaved(image_to_seq, x_image)

            # Sequence -> image converter receives a sequence-shaped batch.
            seq_to_image = Seq2Image(ar_order, image_shape)
            self.assert_layer_is_well_behaved(seq_to_image, x_seq)
Ejemplo n.º 3
0
    def __init__(self,
                 image_shape,
                 output_dim,
                 num_bits,
                 autoregressive_order='cwh',
                 d_model=512,
                 nhead=8,
                 num_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 kdim=None,
                 vdim=None,
                 attn_bias=True,
                 output_bias=True,
                 checkpoint_blocks=False,
                 in_lambda=lambda x: x,
                 out_lambda=lambda x: x):
        """Build a decoder-only transformer that processes images as sequences.

        Args:
            image_shape: Shape of one image, stored as a ``torch.Size``. It is
                also passed to the positional encoding and to the
                image<->sequence converters. Presumably ``(C, H, W)``; confirm
                against ``Image2Seq``.
            output_dim: Number of output features of the final linear layer.
            num_bits: Inputs are embedded with a vocabulary of
                ``2**num_bits`` entries.
            autoregressive_order: Ordering string (default ``'cwh'``) handed
                to both ``Image2Seq`` and ``Seq2Image``.
            d_model: Width of the embeddings and of the transformer.
            nhead, num_layers, dim_feedforward, dropout, activation, kdim,
            vdim, attn_bias, checkpoint_blocks: Forwarded unchanged to
                ``DecoderOnlyTransformer``.
            output_bias: Whether the output ``nn.Linear`` has a bias term.
            in_lambda: Callable applied to the input before the embedding
                lookup (identity by default).
            out_lambda: Callable wrapped in a ``LambdaLayer`` and stored as
                ``self.out_lambda`` (identity by default).
        """
        super(DecoderOnlyTransformer2d, self).__init__()
        self.image_shape = torch.Size(image_shape)
        self.autoregressive_order = autoregressive_order
        self.d_model = d_model
        self.num_layers = num_layers

        # NOTE: keep the submodule construction order below. nn.Embedding and
        # nn.Linear draw random initial weights when created, and registration
        # order determines parameter / state_dict ordering.

        # Encoding layers
        # in_lambda -> token embedding -> positional encoding.
        # nn.Embedding needs integer indices; presumably in_lambda (or the
        # caller) yields integers in [0, 2**num_bits). TODO confirm.
        self.encode = nn.Sequential(
            LambdaLayer(in_lambda), nn.Embedding(2**num_bits, d_model),
            PositionalEncodingImage(image_shape=image_shape,
                                    embedding_dim=d_model))

        # Converters between image layout and sequence layout, built with the
        # same ordering so they are meant to be inverses of each other.
        self.im2seq = Image2Seq(autoregressive_order, image_shape)
        self.seq2im = Seq2Image(autoregressive_order, image_shape)
        # Presumably shifts the sequence by one step so each position only
        # sees earlier positions; confirm in AutoregressiveShift.
        self.ar_shift = AutoregressiveShift(d_model)

        self.transformer = DecoderOnlyTransformer(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            kdim=kdim,
            vdim=vdim,
            attn_bias=attn_bias,
            checkpoint_blocks=checkpoint_blocks)

        # Project transformer features (d_model) to output_dim.
        self.out_linear = nn.Linear(d_model, output_dim, bias=output_bias)
        self.out_lambda = LambdaLayer(out_lambda)

        # Parameter (re)initialisation; defined elsewhere in this class.
        self._reset_parameters()
Ejemplo n.º 4
0
    def test_seq2im2seq(self):
        """Seq2Image followed by Image2Seq must reproduce the input sequence.

        For every order in ``self.ar_orders`` this checks three things: the
        intermediate image has shape ``(batch, *image_shape, features)``, the
        recovered sequence has shape ``(seq_len, batch, features)``, and the
        recovered sequence equals the input element-for-element.
        """
        image_shape = (3, 4, 4)
        channels, height, width = image_shape
        seq_len = channels * height * width  # derived so it cannot drift
        batch_size = 10
        features = 6
        x_seq = torch.randn(seq_len, batch_size, features)

        for ar_order in self.ar_orders:
            im2seq = Image2Seq(ar_order, image_shape)
            seq2im = Seq2Image(ar_order, image_shape)
            x_image = seq2im(x_seq)
            x_seq2 = im2seq(x_image)

            self.assertEqual(x_image.shape,
                             torch.Size([batch_size, *image_shape, features]))
            self.assertEqual(x_seq2.shape,
                             torch.Size([seq_len, batch_size, features]))
            # Use an explicit exact tensor comparison. A bare assertEqual on
            # multi-element tensors raises ("Boolean value of Tensor with
            # more than one value is ambiguous") unless the TestCase base
            # class overrides assertEqual for tensors.
            self.assertTrue(
                torch.equal(x_seq, x_seq2),
                msg='seq -> image -> seq round trip changed values for '
                    'ar_order={!r}'.format(ar_order))
Ejemplo n.º 5
0
    def test_zigzag_cs(self):
        """Check the first sequence steps produced by the 'zigzag_cs' ordering.

        The asserted order visits spatial positions along alternating
        anti-diagonals: (0,0), (0,1), (1,0), (2,0), (1,1), (0,2), (0,3).
        All channels of one position are emitted consecutively.
        """
        image_shape = (2, 4, 4)
        num_channels = image_shape[0]
        # The input must match the (2, 4, 4) shape Image2Seq is built for.
        # Previously it had 3 channels, which is inconsistent with that shape.
        x_image = torch.arange(2 * 4 * 4).float().reshape([1, *image_shape, 1])

        im2seq = Image2Seq('zigzag_cs', image_shape)
        x_seq = im2seq(x_image)

        zigzag_positions = [(0, 0), (0, 1), (1, 0), (2, 0), (1, 1), (0, 2), (0, 3)]
        for pos_idx, (h, w) in enumerate(zigzag_positions):
            for c in range(num_channels):
                step = pos_idx * num_channels + c
                self.assertEqual(x_seq[step, :, :], x_image[:, c, h, w, :])