import mxnet as mx
from mxnet import symbol
S = symbol  # shorthand alias used throughout this file


def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
                    sampled_values, remove_accidental_hits=True):
    """ Sampled softmax via importance sampling.
        This under-estimates the full softmax and is only used for training.
    """
    # inputs = (n, in_dim)
    sample, prob_sample, prob_target = sampled_values

    # (num_samples, )
    sample = S.var('sample', shape=(num_samples,), dtype='float32')
    # (n, )
    label = S.var('label')
    label = S.reshape(label, shape=(-1,), name="label_reshape")
    # (num_samples+n, )
    sample_label = S.concat(sample, label, dim=0)
    # lookup weights and biases
    # (num_samples+n, dim)
    sample_target_w = S.sparse.Embedding(data=sample_label, weight=weight,
                                         input_dim=num_classes, output_dim=in_dim,
                                         sparse_grad=True)
    # (num_samples+n, 1)
    sample_target_b = S.sparse.Embedding(data=sample_label, weight=bias,
                                         input_dim=num_classes, output_dim=1,
                                         sparse_grad=True)
    # (num_samples, dim)
    sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
    target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
    sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
    target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))

    # target
    # (n, 1)
    true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
    # samples
    # (n, num_samples)
    sample_b = S.reshape(sample_b, (-1,))
    sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
                                   num_hidden=num_samples)

    # remove accidental hits
    if remove_accidental_hits:
        label_v = S.reshape(label, (-1, 1))
        sample_v = S.reshape(sample, (1, -1))
        neg = S.broadcast_equal(label_v, sample_v) * -1e37
        sample_pred = sample_pred + neg

    # correct logits by the log proposal probabilities (importance sampling)
    prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
    p_target = true_pred - S.log(prob_target)
    p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))

    # return logits and new_labels
    # (n, 1+num_samples)
    logits = S.concat(p_target, p_sample, dim=1)
    new_targets = S.zeros_like(label)
    return logits, new_targets
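# A minimal sketch, not part of the original model: it assumes negatives are drawn with
# MXNet's Zipfian (log-uniform) sampler, mx.nd.contrib.rand_zipfian, and that it returns
# (sampled_candidates, expected_count_true, expected_count_sampled). The results map onto
# the 'sample', 'prob_target' and 'prob_sample' inputs consumed by sampled_softmax above.
def _draw_sampled_values(label_batch, num_samples, num_classes):
    """ Hypothetical helper: draw negative classes and their expected counts. """
    sampled, exp_cnt_true, exp_cnt_sampled = mx.nd.contrib.rand_zipfian(
        label_batch, num_sampled=num_samples, range_max=num_classes)
    # order matches the (sample, prob_sample, prob_target) unpacking above
    return sampled, exp_cnt_sampled, exp_cnt_true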
def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_size):
    """ word embedding + LSTM Projected """
    state_names = []
    data = S.var('data')
    weight = S.var("encoder_weight", stype='row_sparse')
    embed = S.sparse.Embedding(data=data, weight=weight, input_dim=vocab_size,
                               output_dim=num_embed, name='embed', sparse_grad=True)
    states = []
    outputs = S.Dropout(embed, p=dropout)
    for i in range(num_layers):
        prefix = 'lstmp%d_' % i
        init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
        init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
        state_names += [prefix + 'init_h', prefix + 'init_c']
        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
        outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c],
                                            layout='NTC', merge_outputs=True)
        outputs = S.Dropout(outputs, p=dropout)
        states += [S.stop_gradient(s) for s in next_states]
    outputs = S.reshape(outputs, shape=(-1, num_proj))

    trainable_lstm_args = []
    for arg in outputs.list_arguments():
        if 'lstmp' in arg and 'init' not in arg:
            trainable_lstm_args.append(arg)
    return outputs, states, trainable_lstm_args, state_names
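# A minimal usage sketch, not part of the original model; all hyper-parameter values are
# hypothetical placeholders. 'data' is fed with layout NTC, i.e. shape (batch_size, bptt).
def _example_unrolled_lstmp():
    """ Hypothetical helper: build the unrolled LSTMP graph and inspect its inputs. """
    out, states, lstm_args, state_names = rnn(bptt=35, vocab_size=10000, num_embed=256,
                                              nhid=1024, num_layers=2, dropout=0.2,
                                              num_proj=256, batch_size=32)
    # out is flattened to (batch_size * bptt, num_proj) by the final reshape
    print(state_names)   # per-layer initial-state variable names
    print(lstm_args)     # trainable LSTMP weights and biases
    return out, states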
def __init__(self, ntokens, rescale_loss, bptt, emsize, nhid, nlayers,
             dropout, num_proj, batch_size, k):
    out = rnn(bptt, ntokens, emsize, nhid, nlayers, dropout, num_proj, batch_size)
    rnn_out, self.last_states, self.lstm_args, self.state_names = out
    # decoder weight and bias
    decoder_w = S.var("decoder_weight", stype='row_sparse')
    decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')

    # sampled softmax for training
    sample = S.var('sample', shape=(k,))
    prob_sample = S.var("prob_sample", shape=(k,))
    prob_target = S.var("prob_target")
    self.sample_names = ['sample', 'prob_sample', 'prob_target']
    logits, new_targets = sampled_softmax(ntokens, k, num_proj, rnn_out,
                                          decoder_w, decoder_b,
                                          [sample, prob_sample, prob_target])
    self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)

    # full softmax for testing
    eval_logits = S.FullyConnected(data=rnn_out, weight=decoder_w,
                                   num_hidden=ntokens, name='decode_fc', bias=decoder_b)
    label = S.Variable('label')
    label = S.reshape(label, shape=(-1,))
    self.eval_loss = cross_entropy_loss(eval_logits, label)
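# A minimal construction sketch, not part of the original file: the class that owns the
# __init__ above is not shown in this excerpt, so it is passed in as the hypothetical
# 'model_cls' argument, and every hyper-parameter value below is a placeholder.
def _example_build_model(model_cls):
    """ Hypothetical helper: instantiate the language model with a sampled-softmax head. """
    model = model_cls(ntokens=10000, rescale_loss=1.0, bptt=35, emsize=256, nhid=1024,
                      nlayers=2, dropout=0.2, num_proj=256, batch_size=32, k=8192)
    # training uses the sampled-softmax loss, evaluation the full-softmax loss
    return model.train_loss, model.eval_loss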
def cross_entropy_loss(inputs, labels, rescale_loss=1):
    """ cross entropy loss with a mask """
    criterion = mx.gluon.loss.SoftmaxCrossEntropyLoss(weight=rescale_loss)
    loss = criterion(inputs, labels)
    mask = S.var('mask')
    loss = loss * S.reshape(mask, shape=(-1,))
    return S.make_loss(loss.mean())
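# A minimal sketch, not part of the original file: the 'mask' variable lets padded
# positions contribute zero loss (feed 1 for real tokens, 0 for padding). The symbol
# names used here are hypothetical.
def _example_masked_loss():
    """ Hypothetical helper: build the masked cross-entropy symbol and list its inputs. """
    logits = S.var('logits')   # (n, num_classes)
    labels = S.var('labels')   # (n,)
    loss = cross_entropy_loss(logits, labels)
    print(loss.list_arguments())   # includes 'logits', 'labels' and the extra 'mask' input
    return loss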
def _attention_pooling(source, scores):
    # source: (batch_size, seq_len, encoder_num_hidden)
    # scores: (batch_size, seq_len, 1)
    probs = symbol.softmax(scores, axis=1)
    # (batch_size, encoder_num_hidden, 1)
    output = symbol.batch_dot(source, probs, transpose_a=True)
    # (batch_size, encoder_num_hidden)
    return symbol.reshape(output, shape=(0, 0))
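# A minimal sanity-check sketch, not part of the original file, assuming 'symbol' refers
# to mxnet.symbol. It verifies the shape contract: source (batch, seq_len, hidden) and
# scores (batch, seq_len, 1) pool down to (batch, hidden); the sizes are placeholders.
def _example_attention_pooling_shapes():
    """ Hypothetical helper: infer the pooled output shape for small placeholder sizes. """
    source = symbol.var('source')
    scores = symbol.var('scores')
    pooled = _attention_pooling(source, scores)
    _, out_shapes, _ = pooled.infer_shape(source=(4, 10, 8), scores=(4, 10, 1))
    return out_shapes   # expected: [(4, 8)]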