def sampled_softmax(num_classes, num_samples, in_dim, inputs, weight, bias,
                    sampled_values, remove_accidental_hits=True):
    """ Sampled softmax via importance sampling.
        This under-estimates the full softmax and is only used for training.
    """
    # inputs = (n, in_dim)
    sample, prob_sample, prob_target = sampled_values

    # (num_samples, )
    sample = S.var('sample', shape=(num_samples,), dtype='float32')
    # (n, )
    label = S.var('label')
    label = S.reshape(label, shape=(-1,), name="label_reshape")
    # (num_samples+n, )
    sample_label = S.concat(sample, label, dim=0)
    # lookup weights and biases
    # (num_samples+n, dim)
    sample_target_w = S.sparse.Embedding(data=sample_label, weight=weight,
                                         input_dim=num_classes, output_dim=in_dim,
                                         sparse_grad=True)
    # (num_samples+n, 1)
    sample_target_b = S.sparse.Embedding(data=sample_label, weight=bias,
                                         input_dim=num_classes, output_dim=1,
                                         sparse_grad=True)
    # (num_samples, dim)
    sample_w = S.slice(sample_target_w, begin=(0, 0), end=(num_samples, None))
    target_w = S.slice(sample_target_w, begin=(num_samples, 0), end=(None, None))
    sample_b = S.slice(sample_target_b, begin=(0, 0), end=(num_samples, None))
    target_b = S.slice(sample_target_b, begin=(num_samples, 0), end=(None, None))

    # target
    # (n, 1)
    true_pred = S.sum(target_w * inputs, axis=1, keepdims=True) + target_b
    # samples
    # (n, num_samples)
    sample_b = S.reshape(sample_b, (-1,))
    sample_pred = S.FullyConnected(inputs, weight=sample_w, bias=sample_b,
                                   num_hidden=num_samples)

    # remove accidental hits
    if remove_accidental_hits:
        label_v = S.reshape(label, (-1, 1))
        sample_v = S.reshape(sample, (1, -1))
        neg = S.broadcast_equal(label_v, sample_v) * -1e37
        sample_pred = sample_pred + neg

    prob_sample = S.reshape(prob_sample, shape=(1, num_samples))
    p_target = true_pred - S.log(prob_target)
    p_sample = S.broadcast_sub(sample_pred, S.log(prob_sample))

    # return logits and new_labels
    # (n, 1+num_samples)
    logits = S.concat(p_target, p_sample, dim=1)
    new_targets = S.zeros_like(label)
    return logits, new_targets
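A construction-only sketch of how the returned (logits, new_targets) pair might be fed into a softmax cross-entropy head; the aliases, placeholder variable names, and sizes below are assumptions for illustration, not taken from the listing above.

import mxnet as mx
S = mx.sym  # assumed alias for the symbol API used in sampled_softmax

# placeholder sizes, chosen only for illustration
vocab_size, n_samples, hidden_dim = 10000, 64, 256

inputs = S.var('data')            # (n, hidden_dim) hidden states
weight = S.var('decoder_weight')  # (vocab_size, hidden_dim)
bias = S.var('decoder_bias')      # (vocab_size, 1)
# (sampled ids, q(sampled), q(target)) from a candidate sampler, as symbols
sampled_values = (S.var('sample_ids'), S.var('prob_sample'), S.var('prob_target'))

logits, new_targets = sampled_softmax(vocab_size, n_samples, hidden_dim,
                                      inputs, weight, bias, sampled_values)
# the true class always sits at column 0 of logits, so the label is all zeros
loss = S.SoftmaxOutput(data=logits, label=new_targets, name='sampled_softmax')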
def siamese():
    labels = mxs.Variable(name='label')
    flat_a, flat_b = siamese_simp_net()
    # Euclidean distance between the two embeddings
    distance = mxs.sqrt(mxs.sum(mxs.square(flat_a - flat_b), axis=1))
    # contrastive loss with margin 1: y * d^2 + (1 - y) * max(1 - d, 0)^2
    cl1 = labels * mxs.square(distance)
    cl2 = (1 - labels) * mxs.square(mxs.maximum(1 - distance, 0))
    contrastive_loss = mxs.MakeLoss(mxs.mean(cl1 + cl2))
    # expose the distance and both embeddings as gradient-free outputs
    distance_output = mxs.BlockGrad(distance, name='distance')
    flat_a_output = mxs.BlockGrad(flat_a)
    flat_b_output = mxs.BlockGrad(flat_b)
    sym = mx.sym.Group(
        [contrastive_loss, distance_output, flat_a_output, flat_b_output])
    mod = mx.mod.Module(symbol=sym, context=mx.gpu(),
                        data_names=['data_a', 'data_b'], label_names=['label'])
    return mod
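A minimal forward/backward sketch for the module returned by siamese(); the batch contents, shapes, and optimizer settings are placeholders (assumptions), and a GPU is required because the Module above is created on mx.gpu().

import mxnet as mx

mod = siamese()
mod.bind(data_shapes=[('data_a', (8, 3, 32, 32)), ('data_b', (8, 3, 32, 32))],
         label_shapes=[('label', (8,))])
mod.init_params(initializer=mx.init.Xavier())
mod.init_optimizer(optimizer='adam', optimizer_params={'learning_rate': 1e-3})

# one dummy all-zeros batch, shapes matching the bind call above
batch = mx.io.DataBatch(data=[mx.nd.zeros((8, 3, 32, 32)),
                              mx.nd.zeros((8, 3, 32, 32))],
                        label=[mx.nd.zeros((8,))])
mod.forward(batch, is_train=True)
mod.backward()   # only the MakeLoss head produces gradients
mod.update()     # the BlockGrad outputs are for monitoring only
loss, dist, emb_a, emb_b = mod.get_outputs()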
arg_params, aux_params = mod.get_params()

d2 = mxs.var('data_b')
emb2 = embedder(d2, '_b')
infer_shape(emb2, data_b=(1, 3, 32, 32), data=None)
emb2_arguments = emb2.list_arguments()
emb2_arguments.pop(0)  # drop the input data variable, keep only parameters
emb2_auxiliary = emb2.list_auxiliary_states()

# copy the trained parameters into a shared buffer so both branches reuse them
shared_buffer = {}
for i, name in enumerate(emb1_arguments):
    shared_buffer[emb1_arguments[i]] = arg_params[name].as_in_context(mx.cpu(0))
    shared_buffer[emb2_arguments[i]] = arg_params[name].as_in_context(mx.cpu(0))
for i, name in enumerate(emb1_auxiliary):
    shared_buffer[emb1_auxiliary[i]] = aux_params[name].as_in_context(mx.cpu(0))
    shared_buffer[emb2_auxiliary[i]] = aux_params[name].as_in_context(mx.cpu(0))

# Euclidean distance between the two embeddings
distance = mxs.sqrt(mxs.sum(mxs.pow(emb1 - emb2, 2), axis=1))
infer_shape(distance, data_a=(1, 3, 32, 32), data_b=(1, 3, 32, 32), data=None)

siamese_exe = distance.simple_bind(mx.cpu(0),
                                   data_a=(1, 3, 32, 32),
                                   data_b=(1, 3, 32, 32),
                                   shared_buffer=shared_buffer)
distance_out = siamese_exe.forward(False, data_a=im_tr_1, data_b=im_tr_2)
print(distance_out)
def _kl(P, n_classes):
    # mean log-probability per row; equal to -(KL(uniform || P) + log(n_classes))
    return symbols.sum(symbols.log(P), axis=1) / float(n_classes)
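For a quick numeric sense of what _kl returns, here is an NDArray analogue (assuming symbols is the symbol-API alias); a uniform row attains the maximum value -log(n_classes), and more peaked rows score lower.

import mxnet as mx

def _kl_nd(P, n_classes):
    # NDArray version of _kl, for checking values eagerly
    return mx.nd.sum(mx.nd.log(P), axis=1) / float(n_classes)

P_uniform = mx.nd.full((1, 4), 0.25)
P_peaked = mx.nd.array([[0.97, 0.01, 0.01, 0.01]])
print(_kl_nd(P_uniform, 4))  # ~ -log(4) ≈ -1.386
print(_kl_nd(P_peaked, 4))   # ≈ -3.46, i.e. lower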