Exemplo n.º 1
0
    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):
        """Bert Finetune: run one loss-scaled, gradient-clipped training step.

        Args:
            input_ids, input_mask, token_type_id, label_ids: Forwarded
                unchanged to ``self.network`` and to its gradient function.
            sens: Optional loss-scale value. When ``None``,
                ``self.loss_scale`` is used and ``self.loss_scaling_manager``
                decides whether the step overflowed; otherwise ``sens`` is
                used as-is and the overflow decision is ``cond`` alone.

        Returns:
            ``(loss, cond)`` where ``cond`` is ``self.base <= flag_sum`` (the
            overflow indicator). The tuple is returned through ``P.depend``
            so it is ordered after the (possibly skipped) optimizer call.
        """

        weights = self.weights
        # Placeholder; replaced by a real status buffer only on non-GPU targets.
        init = False
        loss = self.network(input_ids, input_mask, token_type_id, label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        if not self.gpu_target:
            # Non-GPU path: allocate and clear the overflow-status buffer
            # before the backward pass (presumably the Ascend float-status
            # buffer -- confirm against the cell's __init__).
            init = self.alloc_status()
            clear_before_grad = self.clear_before_grad(init)
            # Make loss depend on init so the allocation is ordered with it.
            loss = P.depend(loss, init)
            self.depend_parameter_use(clear_before_grad, scaling_sens)
        # Backward pass seeded with the loss scale, cast to float32.
        grads = self.grad(self.network,
                          weights)(input_ids, input_mask, token_type_id,
                                   label_ids,
                                   self.cast(scaling_sens, ts.float32))
        # Rescale by scaling_sens (grad_scale is defined elsewhere; presumably
        # it undoes the loss scaling), then clip every gradient.
        grads = self.hyper_map(P.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(
            P.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
            grads)
        if not self.gpu_target:
            # Read the status buffer and sum it over axis 0; the reassigned
            # depends order grads and flag_sum after the status read.
            flag = self.get_status(init)
            flag_sum = self.reduce_sum(init, (0, ))
            grads = P.depend(grads, flag)
            flag_sum = P.depend(flag_sum, flag)
        else:
            # GPU path: per-gradient overflow check, summed with addn and
            # reshaped to a scalar.
            flag_sum = self.hyper_map(P.partial(_grad_overflow), grads)
            flag_sum = self.addn(flag_sum)
            flag_sum = self.reshape(flag_sum, (()))

        # cond is true once flag_sum reaches self.base, i.e. overflow.
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            # Defer to the loss-scale manager (presumably it also adjusts
            # self.loss_scale -- confirm against its implementation).
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            # Skip the parameter update on overflow.
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return P.depend(ret, succ)
Exemplo n.º 2
0
 def construct(self, img_A, img_B, fake_A, fake_B):
     """Discriminator training step.

     Evaluates ``self.D`` on the real and generated images, back-propagates
     a sensitivity tensor shaped like the loss and filled with ``self.sens``,
     and applies ``self.optimizer`` to the resulting gradients.

     Returns:
         The discriminator loss, made to depend on the optimizer update.
     """
     params = self.weights
     d_loss = self.D(img_A, img_B, fake_A, fake_B)
     seed = Fill()(DType()(d_loss), Shape()(d_loss), self.sens)
     d_grads = self.grad(self.D, params)(img_A, img_B, fake_A, fake_B,
                                         seed)
     update = self.optimizer(d_grads)
     return depend(d_loss, update)
Exemplo n.º 3
0
 def construct(self, img_A, img_B):
     """Generator training step.

     Runs ``self.G`` on the image pair, back-propagates a sensitivity tensor
     shaped like ``lg`` and filled with ``self.sens`` through ``self.net``,
     and applies ``self.optimizer`` to the gradients.

     Returns:
         The two generated images, ``lg`` (made to depend on the optimizer
         update) and the remaining loss terms produced by ``self.G``.
     """
     params = self.weights
     g_out = self.G(img_A, img_B)
     fake_A, fake_B, g_loss, lga, lgb, lca, lcb, lia, lib = g_out
     seed = Fill()(DType()(g_loss), Shape()(g_loss), self.sens)
     g_grads = self.grad(self.net, params)(img_A, img_B, seed)
     stepped_loss = depend(g_loss, self.optimizer(g_grads))
     return (fake_A, fake_B, stepped_loss,
             lga, lgb, lca, lcb, lia, lib)
Exemplo n.º 4
0
 def construct(self, ids, wts, labels):
     """Single training step over ``(ids, wts, labels)``.

     Computes the loss, back-propagates a sensitivity tensor shaped like the
     loss and filled with ``self.sens``, optionally passes the gradients
     through ``self.grad_reducer``, then applies ``self.optimizer``.

     Returns:
         The loss, made to depend on the optimizer update.
     """
     params = self.weights
     net_loss = self.network(ids, wts, labels)
     seed = Fill()(DType()(net_loss), Shape()(net_loss), self.sens)
     net_grads = self.grad(self.network, params)(ids, wts, labels, seed)
     if self.reducer_flag:
         # Reduce gradients (presumably across devices in a distributed run).
         net_grads = self.grad_reducer(net_grads)
     update = self.optimizer(net_grads)
     return depend(net_loss, update)
Exemplo n.º 5
0
 def construct(self,
               input_ids,
               input_mask,
               token_type_id,
               start_position,
               end_position,
               unique_id,
               is_impossible,
               sens=None):
     """BertSquad: run one loss-scaled, gradient-clipped training step.

     Args:
         input_ids, input_mask, token_type_id, start_position, end_position,
         unique_id, is_impossible: Forwarded unchanged to ``self.network``
             and to its gradient function.
         sens: Optional loss-scale value. When ``None``, ``self.loss_scale``
             is used and ``self.loss_scaling_manager`` decides whether the
             step overflowed; otherwise ``sens`` is used as-is and the
             overflow decision is ``cond`` alone.

     Returns:
         ``(loss, cond)`` where ``cond`` is ``self.base <= flag_sum`` (the
         overflow indicator), returned through ``P.depend`` so it is ordered
         after the (possibly skipped) optimizer call.
     """
     weights = self.weights
     init = self.alloc_status()
     loss = self.network(input_ids, input_mask, token_type_id,
                         start_position, end_position, unique_id,
                         is_impossible)
     if sens is None:
         scaling_sens = self.loss_scale
     else:
         scaling_sens = sens
     # Clear the status buffer BEFORE the backward pass; clearing afterwards
     # can wipe overflow flags raised while computing the gradients.
     clear_before_grad = self.clear_before_grad(init)
     loss = P.depend(loss, init)
     self.depend_parameter_use(clear_before_grad, scaling_sens)
     # Backward pass seeded with the loss scale, cast to float32.
     grads = self.grad(self.network,
                       weights)(input_ids, input_mask, token_type_id,
                                start_position, end_position, unique_id,
                                is_impossible,
                                self.cast(scaling_sens, ts.float32))
     # Rescale by scaling_sens (grad_scale is defined elsewhere; presumably
     # it undoes the loss scaling), then clip every gradient.
     grads = self.hyper_map(P.partial(grad_scale, scaling_sens), grads)
     grads = self.hyper_map(
         P.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
         grads)
     flag = self.get_status(init)
     flag_sum = self.reduce_sum(init, (0, ))
     # P.depend returns its first argument with the dependency attached; the
     # returned value must be used, otherwise the ordering is dropped.
     grads = P.depend(grads, flag)
     flag_sum = P.depend(flag_sum, flag)
     # cond is true once flag_sum reaches self.base, i.e. overflow.
     cond = self.less_equal(self.base, flag_sum)
     overflow = cond
     if sens is None:
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
     if overflow:
         # Skip the parameter update on overflow.
         succ = False
     else:
         succ = self.optimizer(grads)
     ret = (loss, cond)
     return P.depend(ret, succ)
Exemplo n.º 6
0
 def construct(self, *args):
     """Generic training step over arbitrary network inputs.

     Forwards ``*args`` to ``self.network``, back-propagates a sensitivity
     tensor with the loss's dtype and shape filled with ``self.sens``, and
     applies ``self.optimizer`` to the gradients.

     Returns:
         The loss, made to depend on the optimizer update.
     """
     params = self.weights
     net_loss = self.network(*args)
     seed = P.Fill()(P.DType()(net_loss), P.Shape()(net_loss), self.sens)
     net_grads = self.grad(self.network, params)(*args, seed)
     update = self.optimizer(net_grads)
     return P.depend(net_loss, update)