# Imports assumed for this test; TransformerEncoder and TransformerDecoder are
# taken to be GluonNLP's implementations (gluonnlp.model.transformer).
import mxnet as mx
import numpy as np
from mxnet.test_utils import assert_almost_equal
from gluonnlp.model.transformer import TransformerEncoder, TransformerDecoder


def test_transformer_encoder_decoder():
    ctx = mx.current_context()
    units = 16
    encoder = TransformerEncoder(num_layers=3, units=units, hidden_size=32, num_heads=8,
                                 max_length=10, dropout=0.0, use_residual=True,
                                 prefix='transformer_encoder_')
    encoder.initialize(ctx=ctx)
    encoder.hybridize()
    for output_attention in [True, False]:
        for use_residual in [True, False]:
            decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, num_heads=8,
                                         max_length=10, dropout=0.0,
                                         output_attention=output_attention,
                                         use_residual=use_residual,
                                         prefix='transformer_decoder_')
            decoder.initialize(ctx=ctx)
            decoder.hybridize()
            for batch_size in [4]:
                for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]:
                    src_seq_nd = mx.nd.random.normal(
                        0, 1, shape=(batch_size, src_seq_length, units), ctx=ctx)
                    tgt_seq_nd = mx.nd.random.normal(
                        0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx)
                    src_valid_length_nd = mx.nd.array(
                        np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx)
                    tgt_valid_length_nd = mx.nd.array(
                        np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx)
                    src_valid_length_npy = src_valid_length_nd.asnumpy()
                    tgt_valid_length_npy = tgt_valid_length_nd.asnumpy()
                    encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd)
                    decoder_states = decoder.init_state_from_encoder(encoder_outputs,
                                                                     src_valid_length_nd)
                    # Test multi-step forwarding
                    output, new_states, additional_outputs = decoder.decode_seq(
                        tgt_seq_nd, decoder_states, tgt_valid_length_nd)
                    assert output.shape == (batch_size, tgt_seq_length, units)
                    output_npy = output.asnumpy()
                    for i in range(batch_size):
                        tgt_v_len = int(tgt_valid_length_npy[i])
                        if tgt_v_len < tgt_seq_length - 1:
                            assert (output_npy[i, tgt_v_len:, :] == 0).all()
                    if output_attention:
                        assert len(additional_outputs) == 3
                        attention_out = additional_outputs[0][1].asnumpy()
                        assert attention_out.shape == (batch_size, 8, tgt_seq_length,
                                                       src_seq_length)
                        for i in range(batch_size):
                            mem_v_len = int(src_valid_length_npy[i])
                            if mem_v_len < src_seq_length - 1:
                                assert (attention_out[i, :, :, mem_v_len:] == 0).all()
                            if mem_v_len > 0:
                                assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1),
                                                    np.ones(attention_out.shape[1:3]))
                    else:
                        assert len(additional_outputs) == 0
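# The test above exercises multi-step decoding via decode_seq. The sketch below
# (not part of the original test) illustrates single-step decoding with the same
# encoder/decoder pair, assuming the pre-0.9 GluonNLP Seq2SeqDecoder interface in
# which calling the decoder with a (batch_size, units) step input performs one
# decoding step and returns (step_output, new_states, step_additional_outputs).
def one_step_decode_sketch(encoder, decoder, src_seq_nd, src_valid_length_nd,
                           begin_step, num_steps):
    # begin_step: (batch_size, units) array used as the first decoder input.
    encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd)
    states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd)
    step_input = begin_step
    outputs = []
    for _ in range(num_steps):
        step_output, states, _ = decoder(step_input, states)
        outputs.append(step_output)
        step_input = step_output  # feed the previous step's output back in (illustration only)
    return mx.nd.stack(*outputs, axis=1)  # (batch_size, num_steps, units)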
# Imports assumed for this class; the decoder blocks are taken to be GluonNLP's
# (gluonnlp.model.transformer), whose TransformerOneStepDecoder can share
# parameters with a full TransformerDecoder.
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import Block, nn
from gluonnlp.model.transformer import TransformerDecoder, TransformerOneStepDecoder


class AudioWordDecoder(Block):
    def __init__(self, output_size, hidden_size):
        super(AudioWordDecoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        with self.name_scope():
            self.embedding = nn.Embedding(output_size, hidden_size)
            self.dropout = gluon.nn.Dropout(.15)
            # Full-sequence decoder used with teacher forcing during training.
            self.transformer = TransformerDecoder(
                units=self.hidden_size, num_layers=8,
                hidden_size=self.hidden_size * 2, max_length=100,
                num_heads=8, dropout=.15)
            # One-step decoder for incremental inference; shares its weights with
            # the full-sequence decoder via params=...collect_params().
            self.one_step_transformer = TransformerOneStepDecoder(
                units=self.hidden_size, num_layers=8,
                hidden_size=self.hidden_size * 2, max_length=100,
                num_heads=8, dropout=.15,
                params=self.transformer.collect_params())
            self.out = nn.Dense(output_size, in_units=self.hidden_size, flatten=False)

    def forward(self, inputs, enc_outs, enc_valid_lengths, dec_valid_lengths):
        # Embed the target tokens, initialize the decoder state from the encoder
        # outputs, decode the full sequence, and project back to the vocabulary.
        dec_input = self.dropout(self.embedding(inputs))
        dec_states = self.transformer.init_state_from_encoder(
            enc_outs, enc_valid_lengths)
        output, _, _ = self.transformer(dec_input, dec_states, dec_valid_lengths)
        output = self.out(output)
        return output
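# Minimal usage sketch for AudioWordDecoder (not from the original source).
# The shapes are illustrative: enc_outs stands in for the output of an upstream
# audio encoder with hidden_size features per frame, and the token ids are random.
batch_size, enc_len, tgt_len = 2, 50, 12
vocab_size, hidden = 1000, 256
dec = AudioWordDecoder(output_size=vocab_size, hidden_size=hidden)
dec.initialize()
tokens = mx.nd.random.randint(0, vocab_size, shape=(batch_size, tgt_len)).astype('float32')
enc_outs = mx.nd.random.normal(shape=(batch_size, enc_len, hidden))
enc_valid = mx.nd.array([enc_len] * batch_size)
dec_valid = mx.nd.array([tgt_len] * batch_size)
logits = dec(tokens, enc_outs, enc_valid, dec_valid)  # (batch_size, tgt_len, vocab_size)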
# Note: the latent distribution classes, InverseEmbed, and this project's
# TransformerEncoder/TransformerDecoder used below are assumed to be defined or
# imported elsewhere in the original module.
class ARTransformerVAE(Block):
    def __init__(self, vocabulary, emb_dim, latent_distrib='vmf', num_units=512,
                 hidden_size=512, num_heads=4, n_latent=256, max_sent_len=64,
                 transformer_layers=6, label_smoothing_epsilon=0.0, kappa=100.0,
                 batch_size=16, kld=0.1, wd_temp=0.01, ctx=mx.cpu(),
                 prefix=None, params=None):
        super(ARTransformerVAE, self).__init__(prefix=prefix, params=params)
        self.kld_wt = kld
        self.n_latent = n_latent
        self.model_ctx = ctx
        self.max_sent_len = max_sent_len
        self.vocabulary = vocabulary
        self.batch_size = batch_size
        self.wd_embed_dim = emb_dim
        self.vocab_size = len(vocabulary.idx_to_token)
        self.latent_distrib = latent_distrib
        self.num_units = num_units
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.transformer_layers = transformer_layers
        self.label_smoothing_epsilon = label_smoothing_epsilon
        self.kappa = kappa
        with self.name_scope():
            if latent_distrib == 'logistic_gaussian':
                self.latent_dist = LogisticGaussianLatentDistribution(n_latent, ctx, dr=0.0)
            elif latent_distrib == 'vmf':
                self.latent_dist = HyperSphericalLatentDistribution(
                    n_latent, kappa=kappa, dr=0.0, ctx=self.model_ctx)
            elif latent_distrib == 'gaussian':
                self.latent_dist = GaussianLatentDistribution(n_latent, ctx, dr=0.0)
            elif latent_distrib == 'gaussian_unitvar':
                self.latent_dist = GaussianUnitVarLatentDistribution(n_latent, ctx, dr=0.0)
            else:
                raise Exception("Invalid distribution ==> {}".format(latent_distrib))
            self.embedding = nn.Embedding(self.vocab_size, self.wd_embed_dim)
            self.encoder = TransformerEncoder(self.wd_embed_dim, self.num_units,
                                              hidden_size=hidden_size, num_heads=num_heads,
                                              n_layers=transformer_layers, n_latent=n_latent,
                                              sent_size=max_sent_len, batch_size=batch_size,
                                              ctx=ctx)
            #self.decoder = TransformerDecoder(units=num_units, hidden_size=hidden_size,
            #                                  num_layers=transformer_layers, n_latent=n_latent,
            #                                  max_length=max_sent_len, tx=ctx)
            self.decoder = TransformerDecoder(num_layers=transformer_layers,
                                              num_heads=num_heads, max_length=max_sent_len,
                                              units=self.num_units, hidden_size=hidden_size,
                                              dropout=0.1, scaled=True, use_residual=True,
                                              weight_initializer=None, bias_initializer=None,
                                              prefix='transformer_dec_', params=params)
            self.inv_embed = InverseEmbed(batch_size, max_sent_len, self.wd_embed_dim,
                                          temp=wd_temp, ctx=self.model_ctx,
                                          params=self.embedding.params)
            self.ce_loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=-1, from_logits=True)
            self.embedding.initialize(mx.init.Xavier(magnitude=2.34), ctx=ctx)
            if self.vocabulary.embedding:
                self.embedding.weight.set_data(self.vocabulary.embedding.idx_to_vec)

    def decode_seq(self, inputs, states, valid_length=None):
        outputs, states, additional_outputs = self.decoder.decode_seq(
            inputs=self.embedding(inputs), states=states, valid_length=valid_length)
        outputs = self.inv_embed(outputs)
        return outputs, states, additional_outputs

    def decode_step(self, step_input, states):
        step_output, states, step_additional_outputs = \
            self.decoder(self.embedding(step_input), states)
        step_output = self.inv_embed(step_output)
        return step_output, states, step_additional_outputs

    def encode(self, inputs):
        embedded = self.embedding(inputs)  # was `toks`, which is undefined in this method
        enc = self.encoder(embedded)
        z, KL = self.latent_dist(enc, self.batch_size)
        return z, KL

    def forward(self, toks):
        z, KL = self.encode(toks)
        decoder_states = self.decoder.init_state_from_encoder(z)
        # Was `self.decoder_seq(...)`, which does not exist; `decode_seq` is intended.
        outputs, _, _ = self.decode_seq(toks, decoder_states)
        #y = self.decoder(z)
        #prob_logits = self.inv_embed(y)
        #log_prob = mx.nd.log_softmax(prob_logits)
        #recon_loss = self.ce_loss_fn(log_prob, toks)
        #kl_loss = (KL * self.kld_wt)
        #loss = recon_loss + kl_loss
        #return loss, recon_loss, kl_loss, log_prob
        return outputs
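# Sketch (not part of the original class) of how the commented-out lines in
# forward() would combine the reconstruction and KL terms, assuming `outputs`
# are unnormalized logits over the vocabulary and `KL` is the value returned by
# the latent distribution.
def vae_loss_sketch(ce_loss_fn, outputs, toks, KL, kld_wt):
    log_prob = mx.nd.log_softmax(outputs)    # ce_loss_fn was built with from_logits=True
    recon_loss = ce_loss_fn(log_prob, toks)  # per-example reconstruction loss
    kl_loss = KL * kld_wt                    # weighted KL divergence term
    return recon_loss + kl_loss, recon_loss, kl_loss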