Example 1
import mxnet as mx
S = mx.symbol


def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
                    sampled_values, remove_accidental_hits=True):
    """Sampled softmax via importance sampling.

    This under-estimates the full softmax and is only used for training.
    """
    # inputs: (n, in_dim)
    sample, prob_sample, prob_target = sampled_values

    # Declare the sampled candidates as a named input variable
    # (this overrides the symbol unpacked above).
    # (num_samples,)
    sample = S.var('sample', shape=(num_samples,), dtype='float32')
    # (n,)
    label = S.var('label')
    label = S.reshape(label, shape=(-1,), name="label_reshape")
    # (num_samples + n,)
    sample_label = S.concat(sample, label, dim=0)
    # Look up weights and biases for samples and targets in one pass.
    # (num_samples + n, in_dim)
    sample_target_w = S.sparse.Embedding(data=sample_label, weight=weight,
                                         input_dim=num_classes, output_dim=in_dim,
                                         sparse_grad=True)
    # (num_samples + n, 1)
    sample_target_b = S.sparse.Embedding(data=sample_label, weight=bias,
                                         input_dim=num_classes, output_dim=1,
                                         sparse_grad=True)
    # (num_samples, in_dim) and (n, in_dim)
    sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
    target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
    # (num_samples, 1) and (n, 1)
    sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
    target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))

    # Logits for the true targets: (n, 1)
    true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
    # Logits for the negative samples: (n, num_samples)
    sample_b = S.reshape(sample_b, (-1,))
    sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
                                   num_hidden=num_samples)

    # Mask out sampled candidates that collide with the true label.
    if remove_accidental_hits:
        label_v = S.reshape(label, (-1, 1))
        sample_v = S.reshape(sample, (1, -1))
        neg = S.broadcast_equal(label_v, sample_v) * -1e37
        sample_pred = sample_pred + neg

    # Importance-sampling correction: subtract the log expected counts.
    prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
    p_target = true_pred - S.log(prob_target)
    p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))

    # Return logits and new labels: the true class always sits at index 0.
    # (n, 1 + num_samples)
    logits = S.concat(p_target, p_sample, dim=1)
    new_targets = S.zeros_like(label)
    return logits, new_targets
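A minimal usage sketch, not part of the original: encoder_out, decoder_w, decoder_b, vocab_size, k, hidden_size, and the sampled-values triple are assumed to exist upstream. Because the true class always lands at index 0, a plain softmax cross-entropy over the reduced 1 + num_samples vocabulary gives the training loss.

# Hypothetical wiring; all names below except sampled_softmax are assumptions.
logits, new_targets = sampled_softmax(
    num_classes=vocab_size, num_samples=k, in_dim=hidden_size,
    inputs=encoder_out, weight=decoder_w, bias=decoder_b,
    sampled_values=(sample, prob_sample, prob_target))
loss = S.SoftmaxOutput(data=logits, label=new_targets)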
Example 2
import mxnet as mx
mxs = mx.symbol


def siamese():
    # siamese_simp_net() (defined elsewhere) returns the flattened
    # embeddings of the two branches built from 'data_a' and 'data_b'.
    labels = mxs.Variable(name='label')
    flat_a, flat_b = siamese_simp_net()
    # Euclidean distance between the two embeddings.
    distance = mxs.sqrt(mxs.sum(mxs.square(flat_a - flat_b), axis=1))
    # Contrastive loss with margin 1: similar pairs (label=1) are pulled
    # together, dissimilar pairs (label=0) are pushed beyond the margin.
    cl1 = labels * mxs.square(distance)
    cl2 = (1 - labels) * mxs.square(mxs.maximum(1 - distance, 0))
    contrastive_loss = mxs.MakeLoss(mxs.mean(cl1 + cl2))
    distance_output = mxs.BlockGrad(distance, name='distance')
    flat_a_output = mxs.BlockGrad(flat_a)
    flat_b_output = mxs.BlockGrad(flat_b)
    sym = mx.sym.Group(
        [contrastive_loss, distance_output, flat_a_output, flat_b_output])
    mod = mx.mod.Module(symbol=sym,
                        context=mx.gpu(),
                        data_names=['data_a', 'data_b'],
                        label_names=['label'])
    return mod
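A minimal training-loop sketch, not from the original; train_iter is a hypothetical DataIter yielding 'data_a'/'data_b' batches with a binary 'label' (1 = similar pair).

mod = siamese()
mod.bind(data_shapes=train_iter.provide_data,
         label_shapes=train_iter.provide_label)
mod.init_params(mx.init.Xavier())
mod.init_optimizer(optimizer='adam', optimizer_params={'learning_rate': 1e-3})
for epoch in range(10):
    train_iter.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)
        mod.backward()
        mod.update()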
Example 3
import mxnet as mx
mxs = mx.symbol

# mod is the trained siamese Module from the previous example.
arg_params, aux_params = mod.get_params()

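The snippet uses emb1 names that are defined earlier in the original. A minimal symmetric sketch, mirroring the emb2 block below (the embedder helper and the infer_shape wrapper are assumed from the elided code):

d1 = mxs.var('data_a')
emb1 = embedder(d1, '_a')
infer_shape(emb1, data_a=(1, 3, 32, 32), data=None)
emb1_arguments = emb1.list_arguments()
emb1_arguments.pop(0)  # drop the 'data_a' input; only parameters remain
emb1_auxiliary = emb1.list_auxiliary_states()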
d2 = mxs.var('data_b')
emb2 = embedder(d2, '_b')
infer_shape(emb2, data_b=(1, 3, 32, 32), data=None)
emb2_arguments = emb2.list_arguments()
emb2_arguments.pop(0)  # drop the 'data_b' input; only parameters remain
emb2_auxiliary = emb2.list_auxiliary_states()


# Copy the trained parameters into a shared buffer so both branches of the
# inference graph reuse the same trained NDArrays; this assumes the _a and
# _b argument lists are index-aligned.
shared_buffer = {}
for i, name in enumerate(emb1_arguments):
    shared_buffer[emb1_arguments[i]] = arg_params[name].as_in_context(mx.cpu(0))
    shared_buffer[emb2_arguments[i]] = arg_params[name].as_in_context(mx.cpu(0))

for i, name in enumerate(emb1_auxiliary):
    shared_buffer[emb1_auxiliary[i]] = aux_params[name].as_in_context(mx.cpu(0))
    shared_buffer[emb2_auxiliary[i]] = aux_params[name].as_in_context(mx.cpu(0))


# Euclidean distance between the two embeddings, as in the training graph.
distance = mxs.sqrt(mxs.sum(mxs.square(emb1 - emb2), axis=1))
infer_shape(distance, data_a=(1, 3, 32, 32), data_b=(1, 3, 32, 32), data=None)

siamese_exe = distance.simple_bind(mx.cpu(0), data_a=(1, 3, 32, 32),
                                   data_b=(1, 3, 32, 32),
                                   shared_buffer=shared_buffer)
distance_out = siamese_exe.forward(is_train=False, data_a=im_tr_1, data_b=im_tr_2)
print(distance_out)
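A quick sanity check, not from the original: simple_bind with shared_buffer makes the executor reuse the NDArrays copied above instead of allocating fresh ones, so the inference graph sees the trained weights.

for name, arr in shared_buffer.items():
    if name in siamese_exe.arg_dict:
        assert (siamese_exe.arg_dict[name].asnumpy() == arr.asnumpy()).all()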
Example 4
import mxnet.symbol as symbols


def _kl(P, n_classes):
    # Mean log-probability over the class axis. Up to the additive constant
    # -log(n_classes), this equals -KL(U || P) for the uniform distribution U.
    return symbols.sum(symbols.log(P), axis=1) / float(n_classes)
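A small self-contained check of that identity in plain NumPy (not from the original): for row-stochastic P, the mean log-probability equals -log(n) - KL(U || P).

import numpy as np

n = 5
P = np.random.dirichlet(np.ones(n), size=3)         # each row is a distribution
mean_logp = np.log(P).sum(axis=1) / n               # what _kl computes per row
U = np.full(n, 1.0 / n)
kl_u_p = (U * (np.log(U) - np.log(P))).sum(axis=1)  # KL(U || P) per row
assert np.allclose(mean_logp, -np.log(n) - kl_u_p)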