Example #1
import mxnet as mx
import mxnet.symbol as S


def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
                    sampled_values, remove_accidental_hits=True):
    """ Sampled softmax via importance sampling.
        This under-estimates the full softmax and is only used for training.
    """
    # inputs = (n, in_dim)
    sample, prob_sample, prob_target = sampled_values

    # (num_samples, )
    # re-declare 'sample' with an explicit shape and dtype (same variable name as above)
    sample = S.var('sample', shape=(num_samples,), dtype='float32')
    # (n, )
    label = S.var('label')
    label = S.reshape(label, shape=(-1,), name="label_reshape")
    # (num_samples+n, )
    sample_label = S.concat(sample, label, dim=0)
    # lookup weights and biases
    # (num_samples+n, dim)
    sample_target_w = S.sparse.Embedding(data=sample_label, weight=weight,
                                         input_dim=num_classes, output_dim=in_dim,
                                         sparse_grad=True)
    # (num_samples+n, 1)
    sample_target_b = S.sparse.Embedding(data=sample_label, weight=bias,
                                         input_dim=num_classes, output_dim=1,
                                         sparse_grad=True)
    # (num_samples, dim)
    sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
    target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
    sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
    target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))

    # target
    # (n, 1)
    true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
    # samples
    # (n, num_samples)
    sample_b = S.reshape(sample_b, (-1,))
    sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
                                   num_hidden=num_samples)

    # remove accidental hits: sampled classes that equal the true label get a
    # large negative logit so they are effectively ignored by the softmax
    if remove_accidental_hits:
        label_v = S.reshape(label, (-1, 1))
        sample_v = S.reshape(sample, (1, -1))
        neg = S.broadcast_equal(label_v, sample_v) * -1e37
        sample_pred = sample_pred + neg

    prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
    p_target = true_pred - S.log(prob_target)
    p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))

    # return logits and new_labels
    # (n, 1+num_samples)
    logits = S.concat(p_target, p_sample, dim=1)
    new_targets = S.zeros_like(label)
    return logits, new_targets
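
The correction in the last few lines is the heart of sampled softmax: both the true logit and the sampled logits are shifted down by the log of their proposal probability, and because the true class always ends up in column 0 of the concatenated logits, the new training target is simply 0 for every row. Below is a minimal NumPy sketch of that correction; the array names and sizes are illustrative assumptions, not part of the example.

import numpy as np

n, num_samples = 4, 3
true_pred = np.random.randn(n, 1)                                  # raw logit of the true class
sample_pred = np.random.randn(n, num_samples)                      # raw logits of the sampled classes
prob_target = np.random.uniform(0.1, 1.0, size=(n, 1))             # proposal prob. of the true class
prob_sample = np.random.uniform(0.1, 1.0, size=(1, num_samples))   # proposal prob. of each sample

# importance-sampling correction: subtract the log proposal probability
p_target = true_pred - np.log(prob_target)
p_sample = sample_pred - np.log(prob_sample)            # broadcasts over the batch

logits = np.concatenate([p_target, p_sample], axis=1)   # (n, 1 + num_samples)
new_targets = np.zeros(n, dtype=np.int64)               # the true class is always column 0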
Example #2
def rnn(bptt, vocab_size, num_embed, nhid, num_layers, dropout, num_proj, batch_size):
    """ word embedding + LSTM Projected """
    state_names = []
    data = S.var('data')
    weight = S.var("encoder_weight", stype='row_sparse')
    embed = S.sparse.Embedding(data=data, weight=weight, input_dim=vocab_size,
                               output_dim=num_embed, name='embed', sparse_grad=True)
    states = []
    outputs = S.Dropout(embed, p=dropout)
    for i in range(num_layers):
        prefix = 'lstmp%d_' % i
        init_h = S.var(prefix + 'init_h', shape=(batch_size, num_proj), init=mx.init.Zero())
        init_c = S.var(prefix + 'init_c', shape=(batch_size, nhid), init=mx.init.Zero())
        state_names += [prefix + 'init_h', prefix + 'init_c']
        lstmp = mx.gluon.contrib.rnn.LSTMPCell(nhid, num_proj)
        outputs, next_states = lstmp.unroll(bptt, outputs, begin_state=[init_h, init_c],
                                            layout='NTC', merge_outputs=True)
        outputs = S.Dropout(outputs, p=dropout)
        states += [S.stop_gradient(s) for s in next_states]
    outputs = S.reshape(outputs, shape=(-1, num_proj))

    trainable_lstm_args = []
    for arg in outputs.list_arguments():
        if 'lstmp' in arg and 'init' not in arg:
            trainable_lstm_args.append(arg)
    return outputs, states, trainable_lstm_args, state_names
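
rnn() returns the initial-state variable names so a training loop can feed each BPTT segment's final states back in as the next segment's initial states, while stop_gradient keeps gradients from flowing across segment boundaries. A small sketch of building the graph and inspecting it; the hyper-parameter values below are assumptions for illustration only.

outputs, states, lstm_args, state_names = rnn(bptt=20, vocab_size=10000, num_embed=200,
                                              nhid=400, num_layers=2, dropout=0.2,
                                              num_proj=128, batch_size=32)
print(state_names)   # ['lstmp0_init_h', 'lstmp0_init_c', 'lstmp1_init_h', 'lstmp1_init_c']
print(lstm_args)     # the LSTMP weights and biases that receive gradient updates
# outputs is flattened to (batch_size * bptt, num_proj), ready for the decoder layer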
Example #3
    def __init__(self, ntokens, rescale_loss, bptt, emsize,
                 nhid, nlayers, dropout, num_proj, batch_size, k):
        out = rnn(bptt, ntokens, emsize, nhid, nlayers,
                  dropout, num_proj, batch_size)
        rnn_out, self.last_states, self.lstm_args, self.state_names = out
        # decoder weight and bias
        decoder_w = S.var("decoder_weight", stype='row_sparse')
        decoder_b = S.var("decoder_bias", shape=(ntokens, 1), stype='row_sparse')

        # sampled softmax for training
        sample = S.var('sample', shape=(k,))
        prob_sample = S.var("prob_sample", shape=(k,))
        prob_target = S.var("prob_target")
        self.sample_names = ['sample', 'prob_sample', 'prob_target']
        logits, new_targets = sampled_softmax(ntokens, k, num_proj,
                                              rnn_out, decoder_w, decoder_b,
                                              [sample, prob_sample, prob_target])
        self.train_loss = cross_entropy_loss(logits, new_targets, rescale_loss=rescale_loss)

        # full softmax for testing
        # flatten the (ntokens, 1) decoder bias so it matches the (ntokens,)
        # bias shape expected by FullyConnected
        eval_logits = S.FullyConnected(data=rnn_out, weight=decoder_w,
                                       num_hidden=ntokens, name='decode_fc',
                                       bias=S.reshape(decoder_b, shape=(-1,)))
        label = S.var('label')
        label = S.reshape(label, shape=(-1,))
        self.eval_loss = cross_entropy_loss(eval_logits, label)
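
At training time the 'sample', 'prob_sample' and 'prob_target' inputs have to be fed next to the data and labels. One possible way to produce them is a log-uniform (Zipfian) candidate sampler such as mx.nd.contrib.rand_zipfian; this is an assumption for illustration (the upstream example may ship its own sampler, and rand_zipfian returns expected counts rather than normalized probabilities), shown only to make the shapes concrete.

import mxnet as mx

k, ntokens = 5, 1000                # assumed values, not taken from the snippet
labels = mx.nd.array([1, 42, 7])    # flattened next-word targets, shape (n,)
sample, p_target, p_sample = mx.nd.contrib.rand_zipfian(labels, k, ntokens)
# sample:   (k,)  sampled candidate classes
# p_target: (n,)  expected count of each true class under the proposal
# p_sample: (k,)  expected count of each sampled class under the proposal
# these arrays are bound to the 'sample', 'prob_target' and 'prob_sample' variables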
Example #4
def cross_entropy_loss(inputs, labels, rescale_loss=1):
    """ cross entropy loss with a mask """
    criterion = mx.gluon.loss.SoftmaxCrossEntropyLoss(weight=rescale_loss)
    loss = criterion(inputs, labels)
    mask = S.var('mask')
    loss = loss * S.reshape(mask, shape=(-1,))
    return S.make_loss(loss.mean())
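
The 'mask' variable declared here is an extra input that zeroes selected positions out of the loss, for example padding at the end of a batch. A minimal sketch of the array that would be bound to it; the shapes and the masked positions are assumptions.

import numpy as np
import mxnet as mx

batch_size, bptt = 32, 20
mask_np = np.ones(batch_size * bptt, dtype='float32')   # 1.0 for real tokens
mask_np[-5:] = 0.0                                      # 0.0 for positions excluded from the loss
mask = mx.nd.array(mask_np)
# this array is fed for the 'mask' variable declared in cross_entropy_loss
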
from mxnet import symbol


def _attention_pooling(source, scores):
    # source: (batch_size, seq_len, encoder_num_hidden)
    # scores: (batch_size, seq_len, 1)
    probs = symbol.softmax(scores, axis=1)
    # weighted sum of encoder states: (batch_size, encoder_num_hidden, 1)
    output = symbol.batch_dot(source, probs, transpose_a=True)
    # shape=(0, 0) keeps the first two axes -> (batch_size, encoder_num_hidden)
    return symbol.reshape(output, shape=(0, 0))
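
_attention_pooling normalizes the scores over the sequence axis and then takes the probability-weighted sum of the encoder states. A NumPy sketch of the same computation, with assumed shapes:

import numpy as np

batch_size, seq_len, num_hidden = 2, 5, 4
source = np.random.randn(batch_size, seq_len, num_hidden)
scores = np.random.randn(batch_size, seq_len, 1)

# softmax over the sequence axis
e = np.exp(scores - scores.max(axis=1, keepdims=True))
probs = e / e.sum(axis=1, keepdims=True)                 # (batch_size, seq_len, 1)

# batch_dot(source, probs, transpose_a=True): (batch_size, num_hidden, 1)
context = np.matmul(source.transpose(0, 2, 1), probs)
context = context.reshape(batch_size, num_hidden)        # matches reshape(shape=(0, 0))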