Example #2
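The snippets below appear to assume roughly the following imports (gluonnlp 0.x API; exact module paths vary between releases, and the ARTransformerVAE example further down additionally relies on its own TransformerEncoder, InverseEmbed, and latent-distribution classes that are not shown here):

# assumed imports; in some gluonnlp releases the decoder classes live alongside
# the machine-translation scripts rather than under gluonnlp.model.transformer
import numpy as np
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import Block, nn
from mxnet.test_utils import assert_almost_equal
from gluonnlp.model.transformer import TransformerEncoder, TransformerDecoder, TransformerOneStepDecoder
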
def test_transformer_encoder_decoder():
    ctx = mx.current_context()
    units = 16
    encoder = TransformerEncoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10,
                                 dropout=0.0, use_residual=True, prefix='transformer_encoder_')
    encoder.initialize(ctx=ctx)
    encoder.hybridize()
    for output_attention in [True, False]:
        for use_residual in [True, False]:
            decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10, dropout=0.0,
                                         output_attention=output_attention, use_residual=use_residual, prefix='transformer_decoder_')
            decoder.initialize(ctx=ctx)
            decoder.hybridize()
            for batch_size in [4]:
                for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]:
                    src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, units), ctx=ctx)
                    tgt_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx)
                    src_valid_length_nd = mx.nd.array(np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx)
                    tgt_valid_length_nd = mx.nd.array(np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx)
                    src_valid_length_npy = src_valid_length_nd.asnumpy()
                    tgt_valid_length_npy = tgt_valid_length_nd.asnumpy()
                    encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd)
                    decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd)

                    # Test multi step forwarding
                    output, new_states, additional_outputs = decoder.decode_seq(tgt_seq_nd,
                                                                                decoder_states,
                                                                                tgt_valid_length_nd)
                    assert(output.shape == (batch_size, tgt_seq_length, units))
                    output_npy = output.asnumpy()
                    for i in range(batch_size):
                        tgt_v_len = int(tgt_valid_length_npy[i])
                        if tgt_v_len < tgt_seq_length - 1:
                            assert((output_npy[i, tgt_v_len:, :] == 0).all())
                    if output_attention:
                        assert(len(additional_outputs) == 3)
                        attention_out = additional_outputs[0][1].asnumpy()
                        assert(attention_out.shape == (batch_size, 8, tgt_seq_length, src_seq_length))
                        for i in range(batch_size):
                            mem_v_len = int(src_valid_length_npy[i])
                            if mem_v_len < src_seq_length - 1:
                                assert((attention_out[i, :, :, mem_v_len:] == 0).all())
                            if mem_v_len > 0:
                                assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1),
                                                    np.ones(attention_out.shape[1:3]))
                    else:
                        assert(len(additional_outputs) == 0)
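
The zero checks above follow from valid-length masking: decode_seq is expected to zero out every target position at or beyond a sequence's valid length, and the encoder-decoder attention weights are expected to put no mass on masked source positions. A tiny standalone numpy sketch of the output-masking rule the assertions encode (values are made up for illustration):

import numpy as np

batch_size, tgt_seq_length, units = 2, 5, 4
out = np.random.normal(size=(batch_size, tgt_seq_length, units))
tgt_valid_length = np.array([3, 5])
for i, v in enumerate(tgt_valid_length):
    out[i, v:, :] = 0.0                      # mask positions past the valid length
assert (out[0, 3:, :] == 0).all()            # mirrors the output_npy check above
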
class AudioWordDecoder(Block):
    def __init__(self, output_size, hidden_size):
        super(AudioWordDecoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size

        with self.name_scope():
            self.embedding = nn.Embedding(output_size, hidden_size)
            self.dropout = gluon.nn.Dropout(.15)
            self.transformer = TransformerDecoder(
                units=self.hidden_size,
                num_layers=8,
                hidden_size=self.hidden_size * 2,
                max_length=100,
                num_heads=8,
                dropout=.15)
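            # reuse the parameters of the full-sequence decoder above so that
            # step-wise decoding at inference time runs with the same weights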
            self.one_step_transformer = TransformerOneStepDecoder(
                units=self.hidden_size,
                num_layers=8,
                hidden_size=self.hidden_size * 2,
                max_length=100,
                num_heads=8,
                dropout=.15,
                params=self.transformer.collect_params())
            self.out = nn.Dense(output_size,
                                in_units=self.hidden_size,
                                flatten=False)

    def forward(self, inputs, enc_outs, enc_valid_lengths, dec_valid_lengths):
        dec_input = self.dropout(self.embedding(inputs))
        dec_states = self.transformer.init_state_from_encoder(
            enc_outs, enc_valid_lengths)
        output, _, _ = self.transformer(dec_input, dec_states,
                                        dec_valid_lengths)
        output = self.out(output)
        return output
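
A minimal usage sketch for AudioWordDecoder, assuming a gluonnlp version whose TransformerDecoder matches the calls in forward above; the sizes and tensors below are made up for illustration:

import mxnet as mx

output_size, hidden_size = 1000, 256          # hypothetical vocabulary and model sizes
batch_size, src_len, tgt_len = 2, 30, 12
dec = AudioWordDecoder(output_size, hidden_size)
dec.initialize()

tokens = mx.nd.random.randint(0, output_size, shape=(batch_size, tgt_len)).astype('float32')
enc_outs = mx.nd.random.normal(shape=(batch_size, src_len, hidden_size))
enc_valid = mx.nd.array([src_len, src_len - 5])
dec_valid = mx.nd.array([tgt_len, tgt_len - 3])

logits = dec(tokens, enc_outs, enc_valid, dec_valid)   # (batch_size, tgt_len, output_size)
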
Example #6
class ARTransformerVAE(Block):
    def __init__(self,
                 vocabulary,
                 emb_dim,
                 latent_distrib='vmf',
                 num_units=512,
                 hidden_size=512,
                 num_heads=4,
                 n_latent=256,
                 max_sent_len=64,
                 transformer_layers=6,
                 label_smoothing_epsilon=0.0,
                 kappa=100.0,
                 batch_size=16,
                 kld=0.1,
                 wd_temp=0.01,
                 ctx=mx.cpu(),
                 prefix=None,
                 params=None):

        super(ARTransformerVAE, self).__init__(prefix=prefix, params=params)
        self.kld_wt = kld
        self.n_latent = n_latent
        self.model_ctx = ctx
        self.max_sent_len = max_sent_len
        self.vocabulary = vocabulary
        self.batch_size = batch_size
        self.wd_embed_dim = emb_dim
        self.vocab_size = len(vocabulary.idx_to_token)
        self.latent_distrib = latent_distrib
        self.num_units = num_units
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.transformer_layers = transformer_layers
        self.label_smoothing_epsilon = label_smoothing_epsilon
        self.kappa = kappa

        with self.name_scope():
            if latent_distrib == 'logistic_gaussian':
                self.latent_dist = LogisticGaussianLatentDistribution(n_latent,
                                                                      ctx,
                                                                      dr=0.0)
            elif latent_distrib == 'vmf':
                self.latent_dist = HyperSphericalLatentDistribution(
                    n_latent, kappa=kappa, dr=0.0, ctx=self.model_ctx)
            elif latent_distrib == 'gaussian':
                self.latent_dist = GaussianLatentDistribution(n_latent,
                                                              ctx,
                                                              dr=0.0)
            elif latent_distrib == 'gaussian_unitvar':
                self.latent_dist = GaussianUnitVarLatentDistribution(n_latent,
                                                                     ctx,
                                                                     dr=0.0)
            else:
                raise Exception(
                    "Invalid distribution ==> {}".format(latent_distrib))
            self.embedding = nn.Embedding(self.vocab_size, self.wd_embed_dim)
            self.encoder = TransformerEncoder(self.wd_embed_dim,
                                              self.num_units,
                                              hidden_size=hidden_size,
                                              num_heads=num_heads,
                                              n_layers=transformer_layers,
                                              n_latent=n_latent,
                                              sent_size=max_sent_len,
                                              batch_size=batch_size,
                                              ctx=ctx)

            #self.decoder = TransformerDecoder(units=num_units, hidden_size=hidden_size,
            #                                  num_layers=transformer_layers, n_latent=n_latent, max_length = max_sent_len,
            #                                  tx = ctx)
            self.decoder = TransformerDecoder(num_layers=transformer_layers,
                                              num_heads=num_heads,
                                              max_length=max_sent_len,
                                              units=self.num_units,
                                              hidden_size=hidden_size,
                                              dropout=0.1,
                                              scaled=True,
                                              use_residual=True,
                                              weight_initializer=None,
                                              bias_initializer=None,
                                              prefix='transformer_' + 'dec_',
                                              params=params)
            self.inv_embed = InverseEmbed(batch_size,
                                          max_sent_len,
                                          self.wd_embed_dim,
                                          temp=wd_temp,
                                          ctx=self.model_ctx,
                                          params=self.embedding.params)
            self.ce_loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss(
                axis=-1, from_logits=True)
        self.embedding.initialize(mx.init.Xavier(magnitude=2.34), ctx=ctx)
        if self.vocabulary.embedding:
            self.embedding.weight.set_data(
                self.vocabulary.embedding.idx_to_vec)

    def decode_seq(self, inputs, states, valid_length=None):
        outputs, states, additional_outputs = self.decoder.decode_seq(
            inputs=self.embedding(inputs),
            states=states,
            valid_length=valid_length)
        outputs = self.inv_embed(outputs)
        return outputs, states, additional_outputs

    def decode_step(self, step_input, states):
        step_output, states, step_additional_outputs =\
            self.decoder(self.embedding(step_input), states)
        step_output = self.inv_embed(step_output)
        return step_output, states, step_additional_outputs

    def encode(self, inputs):
        # embed the token ids, run the transformer encoder, then draw a latent
        # code and its KL term from the configured latent distribution
        embedded = self.embedding(inputs)
        enc = self.encoder(embedded)
        z, KL = self.latent_dist(enc, self.batch_size)
        return z, KL

    def forward(self, toks):
        z, KL = self.encode(toks)
        decoder_states = self.decoder.init_state_from_encoder(z)
        outputs, _, _ = self.decode_seq(toks, decoder_states)
        #y = self.decoder(z)
        #prob_logits = self.inv_embed(y)
        #log_prob = mx.nd.log_softmax(prob_logits)
        #recon_loss = self.ce_loss_fn(log_prob, toks)
        #kl_loss = (KL * self.kld_wt)
        #loss = recon_loss + kl_loss
        #return loss, recon_loss, kl_loss, log_prob
        return outputs
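
The commented-out lines in forward hint at how the training loss was meant to be assembled from the reconstruction and KL terms. A hedged sketch of that composition outside the class, using only members defined above (here model is an initialized ARTransformerVAE and toks a batch of token ids; this reconstructs the intent of the commented code, not the original training loop):

z, KL = model.encode(toks)
decoder_states = model.decoder.init_state_from_encoder(z)
outputs, _, _ = model.decode_seq(toks, decoder_states)

log_prob = mx.nd.log_softmax(outputs)            # ce_loss_fn was built with from_logits=True
recon_loss = model.ce_loss_fn(log_prob, toks)    # reconstruction term
kl_loss = KL * model.kld_wt                      # weighted KL term
loss = recon_loss + kl_loss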