Esempio n. 1
0
 def construct(self, category_b, subcategory_b, title_b, abstract_b,
               category_c, subcategory_c, title_c, abstract_c, label):
     """Compute the loss for one batch.

     Runs ``self.network`` on the two feature groups (the ``*_b`` and
     ``*_c`` inputs; their meaning, presumably two item sets such as
     browsed vs. candidate, should be confirmed against the network). Then
     casts ``label`` to the prediction's dtype and reshapes it to the
     prediction's shape, and returns ``self.loss(prediction, label)``.
     """
     logits = self.network(category_b, subcategory_b, title_b, abstract_b,
                           category_c, subcategory_c, title_c, abstract_c)
     # Align the label with the prediction so the loss sees matching
     # dtype and shape.
     logits_dtype = ops.dtype(logits)
     logits_shape = ops.shape(logits)
     typed_label = ops.cast(label, logits_dtype)
     aligned_label = ops.reshape(typed_label, logits_shape)
     return self.loss(logits, aligned_label)
Esempio n. 2
0
 def construct(self, x):
     """Combine the vectors of a 3-D input using learned attention weights.

     ``x`` is unpacked as (batch_size, n_input_vector, input_vector_dim).
     Each vector is scored with ``self.dense2(self.dense1(...))``. The scores
     are normalised by ``self.softmax`` and cast back to ``x``'s dtype. They
     then weight the vectors, which are reduced over axis 1 by ``self.sum``.

     NOTE(review): reshaping the scores to (batch_size, n_input_vector)
     only works if dense2 emits one value per vector; confirm. The softmax
     axis is whatever ``self.softmax`` was built with (presumably the last
     axis).
     """
     batch_size, n_vectors, vector_dim = ops.shape(x)
     in_dtype = ops.dtype(x)
     # Score every vector independently by flattening batch and vector axes.
     flat = ops.reshape(x, (-1, vector_dim))
     raw_scores = self.dense2(self.dense1(flat))
     scores = ops.reshape(raw_scores, (batch_size, n_vectors))
     weights = ops.cast(self.softmax(scores), in_dtype)
     # Broadcast one weight per vector across its feature dimension.
     return self.sum(x * ops.expand_dims(weights, 2), 1)
Esempio n. 3
0
def _clip_grad(clip_value, grad):
    """
    Clip a gradient by norm with nn.ClipByNorm.

    Inputs:
        clip_value (float): Norm limit passed to nn.ClipByNorm.
        grad: Gradient to clip. NOTE(review): ops.dtype(grad) is read
            directly, so this presumably receives a single Tensor (e.g. when
            hyper-mapped over a tuple of gradients) rather than a tuple.
            Confirm.

    Outputs:
        The result of nn.ClipByNorm applied to ``grad``.
    """
    grad_dtype = ops.dtype(grad)
    # Wrap the Python scalar as a 1-element tensor in the gradient's dtype.
    max_norm = ops.cast(ops.tuple_to_array((clip_value,)), grad_dtype)
    clip_by_norm = nn.ClipByNorm()
    return clip_by_norm(grad, max_norm)
Esempio n. 4
0
    def construct(self, *inputs):
        """Run one training step with loss scaling and optional gradient accumulation.

        Runs the forward pass through ``self.network``, then computes
        gradients w.r.t. ``self.weights`` using a sensitivity tensor filled
        with the current loss-scale value.

        While ``self.accumulation`` is truthy, the step only records whether
        an overflow occurred (in ``self.accu_overflow``); the optimizer is not
        called.

        Otherwise the step:
        - passes the gradients through ``self.grad_reducer``;
        - maps ``_grad_scale`` with ``scaling_sens`` over them (presumably
          undoing the loss scale; confirm);
        - all-reduces the overflow flag and resets the accumulation buffers;
        - calls the optimizer only if no overflow was reported.

        Returns:
            tuple: ``(loss, overflow, scaling_sens)``, tied via ``ops.depend``
            to the optimizer result so the update stays in the execution order.
        """
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        # Start overflow detection; `status` is consumed later by
        # get_overflow_status.
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        # Sensitivity: a tensor shaped like the loss, filled with the loss-scale
        # value cast to the loss dtype.
        scaling_sens_filled = ops.ones_like(loss) * ops.cast(
            scaling_sens, ops.dtype(loss))
        # Gradients w.r.t. the weights; assumes self.grad takes the sensitivity
        # as the trailing argument (sens_param-style). TODO confirm.
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        # accumulate gradients: fold this step's grads into self.accu_grads via
        # update_accu_grads, only when accumulating over more than one step.
        if self.accumulation and self.accumulation_steps > 1:
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                       grads)
            # Make `loss` depend on the accumulation so it is not pruned/reordered.
            loss = ops.depend(loss, accu_succ)
        # Merge this step's overflow with any overflow already recorded in
        # self.accu_overflow (non-zero means an earlier step overflowed).
        overflow = self.get_overflow_status(status, grads)
        overflow = self.logical_or(
            self.not_equal(self.accu_overflow, self.zero), overflow)
        # Encode the combined flag as self.one / self.zero.
        accu_overflow = self.select(overflow, self.one, self.zero)

        if self.accumulation:
            # Mid-accumulation step: remember the overflow flag, skip the optimizer.
            succ = False
            self.accu_overflow = accu_overflow
        else:
            # Final (or non-accumulating) step: clear the stored flag for the
            # next window.
            # NOTE(review): the optimizer below receives only this step's `grads`;
            # self.accu_grads is reset without being read in this method.
            # Confirm the accumulated gradients are consumed elsewhere.
            self.accu_overflow = self.zero
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(ops.partial(_grad_scale, scaling_sens),
                                   grads)
            # Combine overflow flags across devices; a reduced value >= self.base
            # is treated as overflow (self.base presumably 1; confirm).
            accu_overflow = self.allreduce(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            # Reset the accumulation buffers only after `grads` is computed;
            # ops.depend enforces that ordering.
            accu_grads = ops.depend(self.accu_grads, grads)
            accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
            overflow = ops.depend(overflow, accu_succ)
            # Collapse the flag to a scalar for the loss-scale update and branch.
            overflow = self.reshape(overflow, (()))
            # Presumably updates the loss scale and returns the final overflow flag.
            overflow = self.process_loss_scale(overflow)
            if overflow:
                # Skip the parameter update on overflow.
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (loss, overflow, scaling_sens)
        # Tie the outputs to the optimizer result so the update is executed.
        return ops.depend(ret, succ)
Esempio n. 5
0
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a gradient either element-wise or by norm.

    Inputs:
        clip_type (int): 0 clips every element into [-clip_value, clip_value]
            with ops.clip_by_value; 1 clips by norm via nn.ClipByNorm with
            clip_value as the limit. Any other value returns ``grad``
            unchanged.
        clip_value (float): Clipping threshold.
        grad: Gradient to clip. NOTE(review): ops.dtype(grad) is read
            directly, so this presumably receives a single Tensor (e.g. when
            hyper-mapped over a tuple of gradients) rather than a tuple.
            Confirm.

    Outputs:
        The clipped gradient, or ``grad`` itself for an unsupported clip_type.
    """
    if clip_type != 0 and clip_type != 1:
        # Unsupported mode: pass the gradient through untouched.
        return grad
    grad_dtype = ops.dtype(grad)
    # Threshold as a 1-element tensor in the gradient's dtype.
    upper = ops.cast(ops.tuple_to_array((clip_value,)), grad_dtype)
    if clip_type == 1:
        return nn.ClipByNorm()(grad, upper)
    # clip_type == 0: symmetric element-wise clipping.
    lower = ops.cast(ops.tuple_to_array((-clip_value,)), grad_dtype)
    return ops.clip_by_value(grad, lower, upper)
Esempio n. 6
0
 def construct(self, predict, target):
     """Compute ``self.loss`` between ``predict`` and a dtype-matched target.

     ``target`` is cast to ``predict``'s dtype. It is then multiplied by
     ``self.ones(predict)``, presumably a ones-like tensor, which would
     broadcast ``target`` to ``predict``'s shape (confirm), before being
     passed to the loss.
     """
     pred_dtype = ops.dtype(predict)
     typed_target = ops.cast(target, pred_dtype)
     expanded_target = self.ones(predict) * typed_target
     return self.loss(predict, expanded_target)