Example #1
from mindspore import ops

def gumbel_softmax(logits, temperature, hard, axis=-1, eps=1e-20):
    uniform_samples = ops.UniformReal()(logits.shape)
    gumbels = -ops.log(-ops.log(uniform_samples + eps) + eps) # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / temperature
    y_soft = ops.Softmax(axis)(gumbels)

    if hard:
        # Straight through
        index = y_soft.argmax(axis)
        y_hard = ops.OneHot(axis)(index, y_soft.shape[axis], ops.scalar_to_array(1.0), ops.scalar_to_array(0.0))
        ret = ops.stop_gradient(y_hard - y_soft) + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
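
A minimal usage sketch for the function above, assuming MindSpore is installed and the snippet is defined in the current module (in newer MindSpore releases `ops.scalar_to_array` may be deprecated in favour of `ops.scalar_to_tensor`, so that call may need adjusting):

import numpy as np
import mindspore as ms
from mindspore import Tensor

logits = Tensor(np.array([[1.0, 2.0, 0.5]]), ms.float32)
# hard=True returns a one-hot sample in the forward pass while gradients
# flow through the soft sample in the backward pass (straight-through).
sample = gumbel_softmax(logits, temperature=1.0, hard=True)
print(sample.shape)  # (1, 3)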
Example #2
    def construct(self, weights, biases, labels, inputs):
        _check_label_dtype(self.dtype(labels), self.cls_name)

        logits, labels = self._compute_sampled_logits(
            weights=weights,
            biases=biases,
            labels=labels,
            inputs=inputs,
            num_true=self.num_true,
            sampled_values=self.sampled_values,
            subtract_log_q=True)

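        # Detach the constructed soft labels: they are fixed targets built by
        # _compute_sampled_logits, so no gradient should flow through them.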
        labels = ops.stop_gradient(labels)
        x = self._softmax_cross_entropy(logits, labels)
        return x
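
This construct appears to belong to mindspore.nn.SampledSoftmaxLoss. A hedged usage sketch, assuming the MindSpore 1.x constructor SampledSoftmaxLoss(num_sampled, num_classes, num_true=1, ...) and a backend that supports the candidate sampler:

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

loss_fn = nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, num_true=1)

weights = Tensor(np.random.randn(7, 10), ms.float32)  # [num_classes, dim]
biases = Tensor(np.zeros(7), ms.float32)              # [num_classes]
labels = Tensor(np.array([[0], [1], [2]]), ms.int64)  # [batch_size, num_true]
inputs = Tensor(np.random.randn(3, 10), ms.float32)   # [batch_size, dim]

loss = loss_fn(weights, biases, labels, inputs)  # per-example loss with the default reduction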
Example #3
    def construct(self, logits):
        uniform_samples = self.uniform(logits.shape)
        gumbels = -ops.log(-ops.log(uniform_samples))  # ~Gumbel(0, 1)
        gumbels = (logits + gumbels) / self.temperature
        y_soft = self.softmax(gumbels)

        if self.hard:
            # Straight through
            index = y_soft.argmax(self.axis)
            y_hard = ops.OneHot(self.axis)(index, y_soft.shape[self.axis],
                                           self.on_value, self.off_value)
            ret = ops.stop_gradient(y_hard - y_soft) + y_soft
        else:
            # Reparametrization trick.
            ret = y_soft
        return ret
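
The line `ret = ops.stop_gradient(y_hard - y_soft) + y_soft` is the straight-through trick: the forward value equals y_hard, while the backward pass differentiates only y_soft, because the difference term is blocked. A minimal sketch of the identity, assuming the functional interfaces ops.softmax, ops.one_hot and mindspore.grad of recent MindSpore versions:

import mindspore as ms
from mindspore import Tensor, ops

def straight_through(x):
    y_soft = ops.softmax(x, axis=-1)  # differentiable surrogate
    y_hard = ops.one_hot(y_soft.argmax(-1), x.shape[-1],
                         Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
    # Forward: y_hard; backward: the gradient of y_soft only.
    return ops.stop_gradient(y_hard - y_soft) + y_soft

x = Tensor([[0.2, 1.5, -0.3]], ms.float32)
w = Tensor([[1.0, 2.0, 3.0]], ms.float32)
grad_fn = ms.grad(lambda t: (straight_through(t) * w).sum())
print(grad_fn(x))  # matches the gradient of (softmax(x) * w).sum()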
Example #4
def fn_aux(*args):
    outputs = fn(*args)
    no_grad_outputs = ()
    for out in outputs[1:]:
        no_grad_outputs += (stop_gradient(out),)
    return outputs[0], no_grad_outputs
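
This wrapper is the pattern behind has_aux-style gradients: every output after the first is detached with stop_gradient, so only the first output contributes to differentiation while the auxiliary outputs are passed through unchanged. A sketch of the same behavior via the public API, assuming mindspore.grad with has_aux is available (recent MindSpore versions):

import mindspore as ms
from mindspore import Tensor

def fn(x):
    loss = (x ** 2).sum()  # differentiated
    aux = x.mean()         # auxiliary output, excluded from differentiation
    return loss, aux

grad_fn = ms.grad(fn, has_aux=True)
grads, aux = grad_fn(Tensor([1.0, 2.0, 3.0], ms.float32))
print(grads)  # [2. 4. 6.]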
Example #5
    def _compute_sampled_logits(self,
                                weights,
                                biases,
                                labels,
                                inputs,
                                num_true=1,
                                sampled_values=None,
                                subtract_log_q=True):
        """Helper function for SampledSoftmaxLoss functions.

        Computes sampled output training logits and labels suitable for use in sampled softmax.

        Note: In the case where num_true > 1, we assign to each target class
        the target probability 1 / num_true so that the target probabilities
        sum to 1 per-example.

        Args:
            weights (Tensor): Tensor of shape `[num_classes, dim]`.
            biases (Tensor): Tensor of shape `[num_classes]`.
            labels (Tensor): Tensor of shape `[batch_size, num_true]`. The target classes.
            inputs (Tensor): Tensor of shape `[batch_size, dim]`.  The forward
                activations of the input network.
            num_true (int): The number of target classes per training example.
            sampled_values (tuple): A tuple of (`sampled_candidates`, `true_expected_count`,
                `sampled_expected_count`) returned by a `UniformCandidateSampler` function.
            subtract_log_q (bool): Whether to subtract the log expected count of
                the labels in the sample to get the logits of the true labels.
                Default: True.
        Returns:
            out_logits (Tensor): Tensor of shape `[batch_size, num_true + num_sampled]`.
            out_labels (Tensor): Tensor with the same shape as `out_logits`.
        """

        if labels.dtype != mstype.int32:
            labels = self.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1, num_true))
        labels_flat = self.reshape(labels, (-1, ))

        # Sample the negative labels.
        #   sampled shape: [num_sampled] tensor
        #   true_expected_count shape is [batch_size, 1] tensor
        #   sampled_expected_count shape is [num_sampled] tensor
        if sampled_values is None:
            labels = self.cast(labels, mstype.int64)
            sampled_values = self.sampler(labels)

        (sampled, true_expected_count, sampled_expected_count) = sampled_values
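        # The sampler outputs are constants of the training graph: block their
        # gradients so that backprop only flows through weights, biases and inputs.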
        sampled = ops.stop_gradient(sampled)
        true_expected_count = ops.stop_gradient(true_expected_count)
        sampled_expected_count = ops.stop_gradient(sampled_expected_count)

        if sampled.dtype != mstype.int32:
            sampled = self.cast(sampled, mstype.int32)
        all_ids = self.concat_dim0((labels_flat, sampled))
        all_w = self.gather_v2(weights, all_ids, 0)

        n_true = self.shape(labels_flat)[0]
        n_sampled = self.shape(sampled)[0]
        n_dim = self.shape(all_w)[1]

        # true_w shape is [batch_size * num_true, dim]
        true_w = self.slice_op(all_w, [0, 0], [n_true, n_dim])
        sampled_w = self.slice_op(all_w, [n_true, 0], [n_sampled, n_dim])
        sampled_logits = self.matmul(inputs, sampled_w)

        all_b = self.gather_v2(biases, all_ids, 0)
        true_b = self.slice_op(all_b, [0], [n_true])
        sampled_b = self.slice_op(all_b, [n_true], [n_sampled])

        # inputs shape is [batch_size, dim]
        # true_w shape is [batch_size * num_true, dim]
        # row_wise_dots is [batch_size, num_true, dim]
        new_true_w_shape = (-1, num_true, n_dim)
        row_wise_dots = self.mul(self.expand_dims(inputs, 1),
                                 self.reshape(true_w, new_true_w_shape))

        # We want the row-wise dot plus biases which yields a
        # [batch_size, num_true] tensor of true_logits.
        dots_as_matrix = self.reshape(row_wise_dots, (-1, n_dim))
        true_logits = self.reshape(self.reduce_sum(dots_as_matrix, 1),
                                   (-1, num_true))
        true_b = self.reshape(true_b, (-1, num_true))
        true_logits += true_b
        sampled_logits += sampled_b

        if self.remove_accidental_hits:
            acc_hits = self.compute_accidental_hits(labels, sampled)
            acc_indices, acc_ids, acc_weights = acc_hits
            acc_weights_length = acc_weights.shape[0]

            # This is how SparseToDense expects the indices.
            acc_indices_2d = self.reshape(acc_indices[:acc_weights_length],
                                          (-1, 1))
            acc_ids_2d_int32 = self.reshape(acc_ids[:acc_weights_length],
                                            (-1, 1))
            sparse_indices = self.concat_dim1(
                (acc_indices_2d, acc_ids_2d_int32))
            # sparse_indices = self.cast(sparse_indices, mstype.int32)
            # Create sampled_logits_shape = [batch_size, num_sampled]
            sampled_logits_shape = sampled_logits.shape

            # if self.dtype(sampled_logits) != self.dtype(acc_weights):
            #     acc_weights = self.cast(acc_weights, self.dtype(sampled_logits))

            # sampled_logits += self.sparse_to_dense(
            #    sparse_indices,
            #    acc_weights,
            #    sampled_logits_shape)

        if subtract_log_q:
            # Subtract log of Q(l), prior probability that l appears in sampled.
            true_logits -= self.log(true_expected_count)
            sampled_logits -= self.log(sampled_expected_count)

        # Construct output logits and labels. The true labels/logits start at col 0.
        out_logits = self.concat_dim1((true_logits, sampled_logits))

        # true_logits is a float tensor, ones_like(true_logits) is a float
        # tensor of ones. We then divide by num_true to ensure the per-example
        # labels sum to 1.0, i.e. form a proper probability distribution.
        out_labels = self.concat_dim1((self.ones_like(true_logits) / num_true,
                                       self.zeros_like(sampled_logits)))
        return out_logits, out_labels
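
The final concatenations build the training pair: the true logits occupy the leading columns, the sampled logits follow, and the target distribution places 1 / num_true on each true column and 0 on the sampled columns. A small NumPy sketch of the resulting shapes (the values are placeholders):

import numpy as np

batch_size, num_true, num_sampled = 3, 2, 4
true_logits = np.zeros((batch_size, num_true))
sampled_logits = np.zeros((batch_size, num_sampled))

out_logits = np.concatenate([true_logits, sampled_logits], axis=1)    # (3, 6)
out_labels = np.concatenate([np.ones_like(true_logits) / num_true,
                             np.zeros_like(sampled_logits)], axis=1)  # (3, 6)
# Each row of out_labels sums to 1: [0.5, 0.5, 0.0, 0.0, 0.0, 0.0]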