Exemple #1
0
    def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
                  hm_hp, hp_offset, hp_ind, hp_mask):
        """Defines the computation performed.

        One loss-scaled training step: forward pass, scaled backward pass,
        NPU float-status overflow check, then optimizer update.

        Returns:
            Tuple (loss, overflow condition, scaling sensitivity), ordered
            after the optimizer update via ops.depend.
        """
        image = self.image(image)
        weights = self.weights
        # Forward pass producing the (unscaled) training loss.
        loss = self.network(image, hm, reg_mask, ind, wh, kps, kps_mask, reg,
                            hm_hp, hp_offset, hp_ind, hp_mask)
        # NOTE(review): "* 2.0 / 2.0" is arithmetically a no-op; presumably it
        # forces a fresh tensor copy of the loss scale — confirm intent.
        scaling_sens = self.cast(self.loss_scale, mstype.float32) * 2.0 / 2.0
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        # The depend edges pin execution order in graph mode: the status
        # register is cleared before the backward pass runs.
        init = ops.depend(init, scaling_sens)
        clear_status = self.clear_status(init)
        scaling_sens = ops.depend(scaling_sens, clear_status)
        grads = self.grad(self.network,
                          weights)(image, hm, reg_mask, ind, wh, kps, kps_mask,
                                   reg, hm_hp, hp_offset, hp_ind, hp_mask,
                                   scaling_sens)
        # Cross-device gradient reduction (identity when not distributed).
        grads = self.grad_reducer(grads)
        # Undo the loss scaling (and distributed degree) on the gradients.
        grads = self.grad_scale(scaling_sens * self.degree, grads)
        # Read back the float status after the backward pass; a non-zero
        # flag_sum indicates overflow/NaN somewhere in the gradients.
        init = ops.depend(init, grads)
        get_status = self.get_status(init)
        init = ops.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0, ))
        if self.is_distributed:
            # Overflow on any device invalidates the step everywhere.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)

        # NOTE(review): the optimizer is applied even when cond signals
        # overflow (other cells in this file skip the update) — confirm.
        succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        # Make the returned tuple depend on the update so it is not pruned.
        return ops.depend(ret, succ)
Exemple #2
0
 def construct(self, *inputs):
     """Accumulate this step's gradients into self.grad_sum; return the loss."""
     loss = self.network(*inputs)
     # Sensitivity tensor matching the loss's dtype and shape, filled with
     # the constant self.sens.
     grad_sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
     gradients = self.grad(self.network, self.weights)(*inputs, grad_sens)
     # Fold the fresh gradients into the running sums; tie the accumulation
     # to the returned loss so graph mode keeps it in the graph.
     accumulate = self.hyper_map(ops.partial(_sum_op), self.grad_sum, gradients)
     return ops.depend(loss, accumulate)
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """One pre-training step: forward, backward, clip, reduce, update."""
        # Forward pass producing the training loss.
        loss = self.network(input_ids, input_mask, token_type_id,
                            next_sentence_labels, masked_lm_positions,
                            masked_lm_ids, masked_lm_weights)
        # Constant sensitivity used to seed the backward pass.
        sens_tensor = self.cast(ops.tuple_to_array((self.sens,)),
                                mstype.float32)
        gradients = self.grad(self.network, self.weights)(
            input_ids, input_mask, token_type_id, next_sentence_labels,
            masked_lm_positions, masked_lm_ids, masked_lm_weights, sens_tensor)
        # Clip each gradient before any cross-device reduction.
        gradients = self.hyper_map(
            ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
            gradients)
        if self.reducer_flag:
            # apply grad reducer on grads
            gradients = self.grad_reducer(gradients)
        update = self.optimizer(gradients)
        # Order the returned loss after the parameter update.
        return ops.depend(loss, update)
Exemple #4
0
 def construct(self, image, hm, reg_mask, ind, wh, reg):
     """One unscaled training step: forward, backward, optimizer update."""
     image = self.image(image)
     # Forward pass for the detection loss.
     loss = self.network(image, hm, reg_mask, ind, wh, reg)
     gradients = self.grad(self.network, self.weights)(image, hm, reg_mask,
                                                       ind, wh, reg)
     # Return the loss ordered after the parameter update.
     return ops.depend(loss, self.optimizer(gradients))
Exemple #5
0
    def construct(self, realA, realB):
        """Run one GAN step: update the discriminator, then the generator."""
        loss_d = self.loss_netD(realA, realB)
        loss_g = self.loss_netG(realA, realB)

        # Discriminator backward pass seeded with a constant sensitivity,
        # followed by its optimizer update.
        sens_d = ops.Fill()(ops.DType()(loss_d), ops.Shape()(loss_d),
                            self.sens)
        grads_d = self.grad(self.loss_netD, self.weights_D)(realA, realB,
                                                            sens_d)
        out_d = ops.depend(loss_d, self.optimizerD(grads_d))

        # Generator backward pass and update, same pattern.
        sens_g = ops.Fill()(ops.DType()(loss_g), ops.Shape()(loss_g),
                            self.sens)
        grads_g = self.grad(self.loss_netG, self.weights_G)(realA, realB,
                                                            sens_g)
        out_g = ops.depend(loss_g, self.optimizerG(grads_g))
        return out_d, out_g
Exemple #6
0
 def construct(self, *inputs):
     """Defines the computation performed.

     One step with optional gradient accumulation: while self.accumulation
     is set, gradients are added into self.accu_grads and no optimizer step
     runs; otherwise the accumulators are reset and the optimizer is applied
     to the reduced gradients.
     """
     weights = self.weights
     loss = self.network(*inputs)
     # Sensitivity tensor with the loss's dtype/shape, filled with self.sens.
     sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
     grads = self.grad(self.network, weights)(*inputs, sens)
     if self.accumulation and self.accumulation_steps > 1:
         # Fold this step's gradients into the accumulation buffers; the
         # depend on loss keeps the update in the graph.
         accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                    grads)
         loss = ops.depend(loss, accu_succ)
     if self.accumulation:
         # Accumulation-only step: no parameter update.
         succ = False
     else:
         grads = self.grad_reducer(grads)
         # Reset the accumulators, ordered after grads have been produced.
         accu_grads = ops.depend(self.accu_grads, grads)
         accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
         loss = ops.depend(loss, accu_succ)
         succ = self.optimizer(grads)
     return ops.depend(loss, succ)
Exemple #7
0
 def construct(self, img_A, img_B, fake_A, fake_B):
     """One discriminator step; returns the D loss after its update."""
     loss_d = self.D(img_A, img_B, fake_A, fake_B)
     # Backward pass seeded with a constant sensitivity.
     sens = ops.Fill()(ops.DType()(loss_d), ops.Shape()(loss_d), self.sens)
     gradients = self.grad(self.D, self.weights)(img_A, img_B, fake_A,
                                                 fake_B, sens)
     if self.reducer_flag:
         # apply grad reducer on grads
         gradients = self.grad_reducer(gradients)
     return ops.depend(loss_d, self.optimizer(gradients))
 def construct(self,
               input_ids,
               input_mask,
               token_type_id,
               next_sentence_labels,
               masked_lm_positions,
               masked_lm_ids,
               masked_lm_weights,
               sens=None):
     """Defines the computation performed.

     One loss-scaled pre-training step with overflow detection: the backward
     pass runs under `scaling_sens`, gradients are reduced, unscaled and
     clipped, and the optimizer update is skipped when overflow is detected.

     Returns:
         Tuple (loss, overflow condition, scaling sensitivity), ordered
         after the (possibly skipped) optimizer update.
     """
     weights = self.weights
     loss = self.network(input_ids,
                         input_mask,
                         token_type_id,
                         next_sentence_labels,
                         masked_lm_positions,
                         masked_lm_ids,
                         masked_lm_weights)
     if sens is None:
         # No explicit sensitivity given: use the managed loss scale.
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     # alloc status and clear should be right before gradoperation
     init = self.alloc_status()
     self.clear_before_grad(init)
     grads = self.grad(self.network, weights)(input_ids,
                                              input_mask,
                                              token_type_id,
                                              next_sentence_labels,
                                              masked_lm_positions,
                                              masked_lm_ids,
                                              masked_lm_weights,
                                              self.cast(scaling_sens,
                                                        mstype.float32))
     # apply grad reducer on grads
     grads = self.grad_reducer(grads)
     # Unscale by scaling_sens * degree, then clip each gradient.
     grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
     grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
     # Read the float status registers; a non-zero flag_sum means overflow.
     self.get_status(init)
     flag_sum = self.reduce_sum(init, (0,))
     if self.is_distributed:
         # sum overflow flag over devices
         flag_reduce = self.allreduce(flag_sum)
         cond = self.less_equal(self.base, flag_reduce)
     else:
         cond = self.less_equal(self.base, flag_sum)
     overflow = cond
     if sens is None:
         # Dynamic loss scale: the manager adjusts the scale and reports
         # the effective overflow decision.
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     if overflow:
         # Skip the parameter update on overflow.
         succ = False
     else:
         succ = self.optimizer(grads)
     ret = (loss, cond, scaling_sens)
     return ops.depend(ret, succ)
Exemple #9
0
    def construct(self, *inputs):
        """Defines the computation performed.

        One loss-scaled step with gradient accumulation and overflow
        tracking: overflow state is carried across accumulation steps in
        self.accu_overflow, and the optimizer only runs on the final
        (non-accumulation) step when no overflow occurred.

        Returns:
            Tuple (loss, overflow flag, scaling sensitivity), ordered after
            the (possibly skipped) optimizer update.
        """
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        # Begin overflow detection; returns the status handle and a
        # sensitivity tied to it so ordering is preserved in graph mode.
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        # Broadcast the scalar sensitivity to the loss's shape and dtype.
        scaling_sens_filled = ops.ones_like(loss) * ops.cast(
            scaling_sens, ops.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        # accumulate gradients
        if self.accumulation and self.accumulation_steps > 1:
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                       grads)
            loss = ops.depend(loss, accu_succ)
        overflow = self.get_overflow_status(status, grads)
        # Sticky overflow: once any accumulation step overflows, the whole
        # accumulated update counts as overflowed.
        overflow = self.logical_or(
            self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)

        if self.accumulation:
            # Mid-accumulation step: record overflow, skip the update.
            succ = False
            self.accu_overflow = accu_overflow
        else:
            # Final step: clear the sticky flag and apply the update.
            self.accu_overflow = self.zero
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(ops.partial(_grad_scale, scaling_sens),
                                   grads)
            # Agree on overflow across devices before deciding to update.
            accu_overflow = self.allreduce(accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            # Reset the accumulators, ordered after grads were produced.
            accu_grads = ops.depend(self.accu_grads, grads)
            accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
            overflow = ops.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            # Let the loss-scale manager adjust the scale on overflow.
            overflow = self.process_loss_scale(overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (loss, overflow, scaling_sens)
        return ops.depend(ret, succ)
Exemple #10
0
    def construct(self, img_A, img_B):
        """One generator step; returns the fakes and losses with G updated."""
        fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B)
        # Seed the backward pass with a constant sensitivity on the total loss.
        sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
        # NOTE(review): the forward pass runs self.G but gradients come from
        # self.net — presumably self.net wraps self.G; confirm.
        grads_g = self.grad(self.net, self.weights)(img_A, img_B, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads_g = self.grad_reducer(grads_g)

        # Order the returned total loss after the generator update.
        lg_updated = ops.depend(lg, self.optimizer(grads_g))
        return fake_A, fake_B, lg_updated, lga, lgb, lca, lcb, lia, lib
Exemple #11
0
    def construct(self,
                  input_ids,
                  token_type_id,
                  pad_mask,
                  sens=None):
        """construct BertPoetryCell

        One loss-scaled training step: the backward pass is seeded with the
        (cast) loss scale, gradients are unscaled and clipped, and the
        optimizer update is skipped when overflow is detected.

        Returns:
            Tuple (loss, overflow condition), ordered after the update.
        """

        weights = self.weights
        loss = self.network(input_ids,
                            token_type_id,
                            pad_mask)
        if sens is None:
            # No explicit sensitivity: fall back to the managed loss scale.
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # Start overflow detection before the backward pass.
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        grads = self.grad(self.network, weights)(input_ids,
                                                 token_type_id,
                                                 pad_mask,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # Unscale, then clip each gradient.
        grads = self.hyper_map(ops.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            # Dynamic loss scale: the manager adjusts the scale and reports
            # the effective overflow decision.
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            # Skip the parameter update on overflow.
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return ops.depend(ret, succ)
Exemple #12
0
def _clear_grad_sum(grad_sum, zero):
    """Reset the accumulated gradient `grad_sum` to `zero`; return True."""
    # Tie the boolean result to the assignment so graph mode executes it.
    return ops.depend(True, ops.assign(grad_sum, zero))
Exemple #13
0
def _update_accu_grads(accu_grad, grad):
    """Add `grad` (cast to float32) into the accumulator `accu_grad`."""
    updated = ops.assign_add(accu_grad, cast(grad, mstype.float32))
    # Return True ordered after the in-place accumulation.
    return ops.depend(True, updated)
Exemple #14
0
def _reset_accu_grads(accu_grad):
    """Zero out the accumulated gradient buffer `accu_grad`; return True."""
    cleared = ops.assign(accu_grad, zeroslike(accu_grad))
    # Return True ordered after the in-place reset.
    return ops.depend(True, cleared)