def gumbel_softmax(logits, temperature, hard, axis=-1, eps=1e-20):
    uniform_samples = ops.UniformReal()(logits.shape)
    gumbels = -ops.log(-ops.log(uniform_samples + eps) + eps)  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / temperature
    y_soft = ops.Softmax(axis)(gumbels)
    if hard:
        # Straight through.
        index = y_soft.argmax(axis)
        y_hard = ops.OneHot(axis)(index, y_soft.shape[axis],
                                  ops.scalar_to_array(1.0), ops.scalar_to_array(0.0))
        ret = ops.stop_gradient(y_hard - y_soft) + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
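# A minimal usage sketch for the functional gumbel_softmax above (assumes the function is
# in scope alongside MindSpore; the sample values are illustrative only). `hard=False`
# returns the differentiable relaxation, `hard=True` the straight-through one-hot samples.
import numpy as np
import mindspore as ms
from mindspore import Tensor

logits = Tensor(np.random.randn(4, 10), ms.float32)
y_soft = gumbel_softmax(logits, temperature=1.0, hard=False)  # rows sum to 1, fully differentiable
y_hard = gumbel_softmax(logits, temperature=1.0, hard=True)   # one-hot forward, soft gradients backward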
def construct(self, weights, biases, labels, inputs):
    _check_label_dtype(self.dtype(labels), self.cls_name)

    logits, labels = self._compute_sampled_logits(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        num_true=self.num_true,
        sampled_values=self.sampled_values,
        subtract_log_q=True)
    labels = ops.stop_gradient(labels)
    x = self._softmax_cross_entropy(logits, labels)
    return x
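# A minimal usage sketch (this construct appears to belong to mindspore.nn.SampledSoftmaxLoss;
# shapes follow the _compute_sampled_logits docstring below, and the values are illustrative only).
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

loss = nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, num_true=1)
weights = Tensor(np.random.randint(0, 9, [7, 10]), ms.float32)  # [num_classes, dim]
biases = Tensor(np.random.randint(0, 9, [7]), ms.float32)       # [num_classes]
labels = Tensor([0, 1, 2])                                       # [batch_size] target class ids
inputs = Tensor(np.random.randint(0, 9, [3, 10]), ms.float32)    # [batch_size, dim]
output = loss(weights, biases, labels, inputs)                   # per-example sampled-softmax loss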
def construct(self, logits):
    uniform_samples = self.uniform(logits.shape)
    gumbels = -ops.log(-ops.log(uniform_samples))  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / self.temperature
    y_soft = self.softmax(gumbels)
    if self.hard:
        # Straight through.
        index = y_soft.argmax(self.axis)
        y_hard = ops.OneHot(self.axis)(index, y_soft.shape[self.axis],
                                       self.on_value, self.off_value)
        ret = ops.stop_gradient(y_hard - y_soft) + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
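# A hypothetical wrapper sketch (not from the source) showing the attributes the construct
# above expects its Cell's __init__ to define: the uniform sampler, softmax, temperature,
# hard flag, axis, and the OneHot on/off values.
import mindspore as ms
from mindspore import Tensor, nn, ops

class GumbelSoftmaxCell(nn.Cell):
    def __init__(self, temperature=1.0, hard=False, axis=-1):
        super().__init__()
        self.temperature = temperature
        self.hard = hard
        self.axis = axis
        self.uniform = ops.UniformReal()    # draws U(0, 1) samples of a given shape
        self.softmax = ops.Softmax(axis)    # softmax along the chosen axis
        self.on_value = Tensor(1.0, ms.float32)
        self.off_value = Tensor(0.0, ms.float32)
    # construct(self, logits) would be the method shown above.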
def fn_aux(*args):
    outputs = fn(*args)
    no_grad_outputs = ()
    for out in outputs[1:]:
        no_grad_outputs += (stop_gradient(out),)
    return outputs[0], no_grad_outputs
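# A hedged usage sketch: a wrapper like fn_aux above is what has_aux-style gradient APIs
# (e.g. mindspore.grad / mindspore.value_and_grad with has_aux=True) build around fn, so that
# only the first output drives the gradient and the remaining outputs are returned detached.
import mindspore as ms
from mindspore import Tensor

def fn(x):
    loss = (x * x).sum()
    aux = x + 1.0                      # auxiliary output; should not receive gradients
    return loss, aux

grad_fn = ms.grad(fn, has_aux=True)
grads, aux = grad_fn(Tensor([1.0, 2.0], ms.float32))  # gradients of loss only; aux is detached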
def _compute_sampled_logits(self, weights,
                            biases,
                            labels,
                            inputs,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True):
    """Helper function for SampledSoftmaxLoss.

    Computes sampled output training logits and labels suitable for sampled
    softmax training.

    Note: In the case where num_true > 1, we assign to each target class
    the target probability 1 / num_true so that the target probabilities
    sum to 1 per-example.

    Args:
        weights (Tensor): Tensor of shape `[num_classes, dim]`.
        biases (Tensor): Tensor of shape `[num_classes]`.
        labels (Tensor): Tensor of shape `[batch_size, num_true]`. The target classes.
        inputs (Tensor): Tensor of shape `[batch_size, dim]`. The forward
            activations of the input network.
        num_true (int): The number of target classes per training example.
        sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
            `sampled_expected_count`) returned by a `UniformCandidateSampler` function.
        subtract_log_q: A `bool`. Whether to subtract the log expected count of
            the labels in the sample to get the logits of the true labels.
            Default is True.

    Returns:
        out_logits: `Tensor` object with shape `[batch_size, num_true + num_sampled]`.
        out_labels: A Tensor object with the same shape as `out_logits`.
    """
    if not labels.dtype == mstype.int32:
        labels = self.cast(labels, mstype.int32)
    labels = self.reshape(labels, (-1, num_true))
    labels_flat = self.reshape(labels, (-1,))

    # Sample the negative labels.
    #   sampled shape: [num_sampled] tensor
    #   true_expected_count shape is [batch_size, 1] tensor
    #   sampled_expected_count shape is [num_sampled] tensor
    if sampled_values is None:
        labels = self.cast(labels, mstype.int64)
        sampled_values = self.sampler(labels)

    (sampled, true_expected_count, sampled_expected_count) = sampled_values
    sampled = ops.stop_gradient(sampled)
    true_expected_count = ops.stop_gradient(true_expected_count)
    sampled_expected_count = ops.stop_gradient(sampled_expected_count)

    if not sampled.dtype == mstype.int32:
        sampled = self.cast(sampled, mstype.int32)
    all_ids = self.concat_dim0((labels_flat, sampled))
    all_w = self.gather_v2(weights, all_ids, 0)

    n_true = self.shape(labels_flat)[0]
    n_sampled = self.shape(sampled)[0]
    n_dim = self.shape(all_w)[1]

    # true_w shape is [batch_size * num_true, dim]
    true_w = self.slice_op(all_w, [0, 0], [n_true, n_dim])
    sampled_w = self.slice_op(all_w, [n_true, 0], [n_sampled, n_dim])
    sampled_logits = self.matmul(inputs, sampled_w)

    all_b = self.gather_v2(biases, all_ids, 0)
    true_b = self.slice_op(all_b, [0], [n_true])
    sampled_b = self.slice_op(all_b, [n_true], [n_sampled])

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    new_true_w_shape = (-1, num_true, n_dim)
    row_wise_dots = self.mul(self.expand_dims(inputs, 1),
                             self.reshape(true_w, new_true_w_shape))

    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = self.reshape(row_wise_dots, (-1, n_dim))
    true_logits = self.reshape(self.reduce_sum(dots_as_matrix, 1), (-1, num_true))
    true_b = self.reshape(true_b, (-1, num_true))
    true_logits += true_b
    sampled_logits += sampled_b

    if self.remove_accidental_hits:
        acc_hits = self.compute_accidental_hits(labels, sampled)
        acc_indices, acc_ids, acc_weights = acc_hits
        acc_weights_length = acc_weights.shape[0]

        # This is how SparseToDense expects the indices.
        acc_indices_2d = self.reshape(acc_indices[:acc_weights_length], (-1, 1))
        acc_ids_2d_int32 = self.reshape(acc_ids[:acc_weights_length], (-1, 1))
        sparse_indices = self.concat_dim1((acc_indices_2d, acc_ids_2d_int32))
        # sparse_indices = self.cast(sparse_indices, mstype.int32)

        # Create sampled_logits_shape = [batch_size, num_sampled]
        sampled_logits_shape = sampled_logits.shape
        # if self.dtype(sampled_logits) != self.dtype(acc_weights):
        #     acc_weights = self.cast(acc_weights, self.dtype(sampled_logits))
        # sampled_logits += self.sparse_to_dense(
        #     sparse_indices,
        #     acc_weights,
        #     sampled_logits_shape)

    if subtract_log_q:
        # Subtract log of Q(l), prior probability that l appears in sampled.
        true_logits -= self.log(true_expected_count)
        sampled_logits -= self.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = self.concat_dim1((true_logits, sampled_logits))

    # true_logits is a float tensor, ones_like(true_logits) is a float
    # tensor of ones. We then divide by num_true to ensure the per-example
    # labels sum to 1.0, i.e. form a proper probability distribution.
    out_labels = self.concat_dim1((
        self.ones_like(true_logits) / num_true,
        self.zeros_like(sampled_logits)
    ))

    return out_logits, out_labels
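# A shape sketch (illustrative values, not from the source) of what _compute_sampled_logits
# returns: with batch_size=3, num_true=1 and num_sampled=4,
#   out_logits has shape [3, 5]: column 0 holds the true-class logit, columns 1..4 the sampled ones;
#   out_labels has shape [3, 5]: each row is [1, 0, 0, 0, 0], i.e. mass 1/num_true on each true column.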