def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Run one training step of the wrapped network and return its loss.

        The seven inputs are forwarded unchanged to ``self.network`` and to the
        gradient function built from ``self.grad``.

        Steps:
        1. Compute the loss.
        2. Back-propagate it with respect to ``self.weights`` using a
           sensitivity built from ``self.sens`` and cast to float32.
        3. Clip every gradient with ``clip_grad`` using ``GRADIENT_CLIP_TYPE`` /
           ``GRADIENT_CLIP_VALUE``.
        4. Pass the gradients through ``self.grad_reducer`` when
           ``self.reducer_flag`` is set.
        5. Hand the gradients to ``self.optimizer``.

        Returns:
            The loss, tied to the optimizer result through ``ops.depend`` so the
            update is part of the returned computation.
        """
        batch = (input_ids,
                 input_mask,
                 token_type_id,
                 next_sentence_labels,
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights)

        loss = self.network(*batch)

        # Constant backward sensitivity: (self.sens,) as a float32 tensor.
        sens_tensor = self.cast(ops.tuple_to_array((self.sens,)),
                                mstype.float32)
        backward = self.grad(self.network, self.weights)
        grads = backward(*batch, sens_tensor)

        clipper = ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        grads = self.hyper_map(clipper, grads)

        if self.reducer_flag:
            # Apply the configured gradient reducer before the update.
            grads = self.grad_reducer(grads)

        update = self.optimizer(grads)
        return ops.depend(loss, update)
Beispiel #2
0
 def construct(self, *inputs):
     """Compute the loss and combine this step's gradients with ``self.grad_sum``.

     The backward sensitivity is a tensor filled with ``self.sens`` that has
     the loss's dtype and shape. Each entry of ``self.grad_sum`` is combined
     with the matching gradient through ``_sum_op``, which is defined elsewhere
     and presumably accumulates. No optimizer step is taken here.

     Returns:
         The loss, made dependent on the ``_sum_op`` results via ``ops.depend``.
     """
     loss = self.network(*inputs)

     loss_dtype = ops.DType()(loss)
     loss_shape = ops.Shape()(loss)
     sens = ops.Fill()(loss_dtype, loss_shape, self.sens)

     backward = self.grad(self.network, self.weights)
     grads = backward(*inputs, sens)

     summed = self.hyper_map(ops.partial(_sum_op), self.grad_sum, grads)
     return ops.depend(loss, summed)
Beispiel #3
0
    def construct(self,
                  input_ids,
                  token_type_id,
                  pad_mask,
                  sens=None):
        """Run one loss-scaled training step of the BertPoetry network.

        Args:
            input_ids, token_type_id, pad_mask: forwarded to ``self.network``
                and to the gradient function.
            sens: scale to use instead of ``self.loss_scale``. When it is None,
                ``self.loss_scale`` is used, and ``self.loss_scaling_manager``
                decides whether the step counts as an overflow.

        Returns:
            ``(loss, cond)``, where ``cond`` is the result of
            ``get_overflow_status``. The tuple is tied to the optimizer result
            via ``ops.depend``. The optimizer is skipped when an overflow is
            reported.
        """
        loss = self.network(input_ids, token_type_id, pad_mask)

        scaling_sens = self.loss_scale if sens is None else sens
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        backward = self.grad(self.network, self.weights)
        grads = backward(input_ids,
                         token_type_id,
                         pad_mask,
                         self.cast(scaling_sens, mstype.float32))

        # grad_scale with scaling_sens (defined elsewhere; presumably removes
        # the loss scale), followed by clipping.
        unscale = ops.partial(grad_scale, scaling_sens)
        grads = self.hyper_map(unscale, grads)
        clipper = ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
        grads = self.hyper_map(clipper, grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)

        cond = self.get_overflow_status(status, grads)
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        else:
            overflow = cond

        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        return ops.depend((loss, cond), succ)
Beispiel #4
0
    def construct(self, *inputs):
        """Run one loss-scaled training micro-step with gradient accumulation.

        Computes the loss and its gradients. Back-propagation uses a
        sensitivity tensor that has the loss's shape and dtype and is filled
        with the current loss scale. This step's overflow status is combined
        with the one carried in ``self.accu_overflow``.

        While ``self.accumulation`` is truthy, the optimizer is skipped and the
        combined overflow flag is stored back into ``self.accu_overflow``.
        Otherwise ``self.accu_overflow`` is reset, and the gradients are
        reduced and rescaled. The accumulation buffers are reset,
        ``process_loss_scale`` is called, and the optimizer runs unless an
        overflow is detected.

        Args:
            *inputs: forwarded unchanged to ``self.network`` and to the
                gradient function.

        Returns:
            ``(loss, overflow, scaling_sens)``, made dependent on the optimizer
            result (or ``False``) via ``ops.depend``.

        NOTE(review): the update branch applies ``grads`` from the current
        micro-step, not ``self.accu_grads`` (which is only reset there).
        Confirm this is intended for this cell variant.
        NOTE(review): accumulation into ``self.accu_grads`` is gated on
        ``self.accumulation and self.accumulation_steps > 1``, while the
        skip-update branch below checks ``self.accumulation`` alone. Confirm
        these two conditions are meant to differ.
        """
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        # Sensitivity with the loss's shape and dtype, every element equal to
        # the loss scale.
        scaling_sens_filled = ops.ones_like(loss) * ops.cast(
            scaling_sens, ops.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        # accumulate gradients
        if self.accumulation and self.accumulation_steps > 1:
            # update_accu_grads is defined elsewhere; presumably adds grads into
            # self.accu_grads. ops.depend keeps that update attached to loss.
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                       grads)
            loss = ops.depend(loss, accu_succ)
        # Overflow if this step overflowed OR the carried flag is non-zero.
        overflow = self.get_overflow_status(status, grads)
        overflow = self.logical_or(
            self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)

        if self.accumulation:
            # Still accumulating: skip the optimizer, carry the overflow flag.
            succ = False
            self.accu_overflow = accu_overflow
        else:
            self.accu_overflow = self.zero
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
            # _grad_scale is defined elsewhere; presumably removes the loss scale.
            grads = self.hyper_map(ops.partial(_grad_scale, scaling_sens),
                                   grads)
            # self.allreduce presumably sums the flag across workers; any result
            # >= self.base is treated as an overflow.
            accu_overflow = self.allreduce(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            # Reading accu_grads depends on grads, so the reset
            # (reset_accu_grads, defined elsewhere) is ordered after grads.
            accu_grads = ops.depend(self.accu_grads, grads)
            accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
            overflow = ops.depend(overflow, accu_succ)
            # Reshape the flag to a 0-d (scalar) tensor.
            overflow = self.reshape(overflow, (()))
            overflow = self.process_loss_scale(overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (loss, overflow, scaling_sens)
        return ops.depend(ret, succ)
Beispiel #5
0
 def construct(self):
     """Apply ``_clear_op`` to each ``self.grad_sum`` entry with ``self.zeros``.

     ``_clear_op`` is defined elsewhere; presumably it resets the accumulated
     gradients.

     Returns:
         The result of the element-wise ``hyper_map`` application.
     """
     clear = ops.partial(_clear_op)
     return self.hyper_map(clear, self.grad_sum, self.zeros)
Beispiel #6
0
 def construct(self, scale, grads):
     """Apply ``grad_scale`` with ``scale`` to every gradient in ``grads``.

     Returns:
         The transformed gradients, as produced by ``self.hyper_map``.
     """
     scale_fn = ops.partial(grad_scale, scale)
     return self.hyper_map(scale_fn, grads)
Beispiel #7
0
 def construct(self, grads):
     """Apply ``clip_grad`` with ``self.clip_norm`` to every gradient in ``grads``.

     Returns:
         The clipped gradients, as produced by ``self.hyper_map``.
     """
     clip_fn = ops.partial(clip_grad, self.clip_norm)
     return self.hyper_map(clip_fn, grads)