def test_whc(self):
    """'whc' ordering should walk width fastest, then height, then channel."""
    x_image = torch.arange(12).float().reshape([1, 3, 2, 2, 1])
    im2seq = Image2Seq('whc', (3, 2, 2))
    x_seq = im2seq(x_image)
    # Expected (channel, height, width) source position for each of the
    # first five sequence steps.
    expected_positions = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0)]
    for t, (c, h, w) in enumerate(expected_positions):
        self.assertEqual(x_seq[t, :, :], x_image[:, c, h, w, :])
def test_layer_is_well_behaved(self):
    """Both reshaping directions should satisfy the generic layer contract
    for every autoregressive ordering under test."""
    batch_size, features = 10, 6
    image_shape = (3, 4, 4)
    seq_len = 3 * 4 * 4
    # Keep the randn call order (seq first, then image) so RNG state matches.
    x_seq = torch.randn(seq_len, batch_size, features)
    x_image = torch.randn(batch_size, *image_shape, features)
    for order in self.ar_orders:
        self.assert_layer_is_well_behaved(Image2Seq(order, image_shape), x_image)
        self.assert_layer_is_well_behaved(Seq2Image(order, image_shape), x_seq)
def __init__(self, image_shape, output_dim, num_bits,
             autoregressive_order='cwh', d_model=512, nhead=8,
             num_layers=6, dim_feedforward=2048, dropout=0.1,
             activation="relu", kdim=None, vdim=None, attn_bias=True,
             output_bias=True, checkpoint_blocks=False,
             in_lambda=lambda x: x, out_lambda=lambda x: x):
    """Decoder-only transformer over 2d images flattened to token sequences.

    Pixels are embedded as discrete tokens (vocabulary size 2**num_bits),
    given image positional encodings, flattened in `autoregressive_order`,
    shifted for autoregressive prediction, run through the transformer,
    and projected to `output_dim` per position.
    NOTE(review): parameter semantics inferred from usage here — confirm
    against the Image2Seq / DecoderOnlyTransformer definitions.
    """
    super().__init__()
    self.image_shape = torch.Size(image_shape)
    self.autoregressive_order = autoregressive_order
    self.d_model = d_model
    self.num_layers = num_layers

    # Token encoding: optional input transform, discrete-value embedding,
    # then per-position encodings over the image grid.
    embedding = nn.Embedding(2 ** num_bits, d_model)
    positional = PositionalEncodingImage(image_shape=image_shape,
                                         embedding_dim=d_model)
    self.encode = nn.Sequential(LambdaLayer(in_lambda), embedding, positional)

    # Image <-> sequence reshaping in the chosen autoregressive order,
    # plus the one-step shift that enforces causal prediction.
    self.im2seq = Image2Seq(autoregressive_order, image_shape)
    self.seq2im = Seq2Image(autoregressive_order, image_shape)
    self.ar_shift = AutoregressiveShift(d_model)

    self.transformer = DecoderOnlyTransformer(d_model=d_model,
                                              nhead=nhead,
                                              num_layers=num_layers,
                                              dim_feedforward=dim_feedforward,
                                              dropout=dropout,
                                              activation=activation,
                                              kdim=kdim,
                                              vdim=vdim,
                                              attn_bias=attn_bias,
                                              checkpoint_blocks=checkpoint_blocks)

    # Output head: linear projection followed by an optional user transform.
    self.out_linear = nn.Linear(d_model, output_dim, bias=output_bias)
    self.out_lambda = LambdaLayer(out_lambda)

    self._reset_parameters()
def test_seq2im2seq(self):
    """A sequence mapped to an image and back must round-trip exactly."""
    batch_size, features = 10, 6
    image_shape = (3, 4, 4)
    seq_len = 3 * 4 * 4
    x_seq = torch.randn(seq_len, batch_size, features)
    for order in self.ar_orders:
        to_seq = Image2Seq(order, image_shape)
        to_image = Seq2Image(order, image_shape)
        x_image = to_image(x_seq)
        x_roundtrip = to_seq(x_image)
        self.assertEqual(x_image.shape,
                         torch.Size([batch_size, *image_shape, features]))
        self.assertEqual(x_roundtrip.shape,
                         torch.Size([seq_len, batch_size, features]))
        self.assertEqual(x_seq, x_roundtrip)
def test_zigzag_cs(self):
    """'zigzag_cs' ordering: diagonal (zigzag) scan over the spatial grid,
    with the channel dimension varying fastest at each spatial site.
    """
    # Fix: the layer is configured with image_shape (2, 4, 4), so the input
    # must have 2 channels. The original built arange(3 * 4 * 4) reshaped to
    # [1, 3, 4, 4, 1] (3 channels), mismatching the layer's configuration;
    # all assertions below only ever reference channels 0 and 1.
    x_image = torch.arange(2 * 4 * 4).float().reshape([1, 2, 4, 4, 1])
    im2seq = Image2Seq('zigzag_cs', (2, 4, 4))
    x_seq = im2seq(x_image)
    # Expected (channel, height, width) source position for each of the
    # first 14 sequence steps: both channels at (0,0), then the zigzag
    # walks diagonals (0,1),(1,0) then (2,0),(1,1),(0,2) then (0,3),...
    expected_positions = [
        (0, 0, 0), (1, 0, 0),
        (0, 0, 1), (1, 0, 1),
        (0, 1, 0), (1, 1, 0),
        (0, 2, 0), (1, 2, 0),
        (0, 1, 1), (1, 1, 1),
        (0, 0, 2), (1, 0, 2),
        (0, 0, 3), (1, 0, 3),
    ]
    for t, (c, h, w) in enumerate(expected_positions):
        self.assertEqual(x_seq[t, :, :], x_image[:, c, h, w, :])