Example #1
0
 def construct(self, *inputs):
     """Run forward and backward passes and fold the gradients into ``self.grad_sum``.

     No optimizer is called here. ``_sum_op`` (defined elsewhere) is mapped
     element-wise over ``(self.grad_sum, grads)`` -- presumably an in-place
     accumulate; verify against its definition.

     Returns:
         The network loss, tied via ``ops.depend`` to the ``_sum_op`` results.
     """
     params = self.weights
     loss = self.network(*inputs)
     # Backward-pass seed: loss's dtype and shape, filled with self.sens.
     loss_dtype = ops.DType()(loss)
     loss_shape = ops.Shape()(loss)
     sens = ops.Fill()(loss_dtype, loss_shape, self.sens)
     grads = self.grad(self.network, params)(*inputs, sens)
     summed = self.hyper_map(ops.partial(_sum_op), self.grad_sum, grads)
     return ops.depend(loss, summed)
Example #2
0
 def __init__(self, threshold, value):
     """Store the threshold/value pair and create the primitives used later.

     Args:
         threshold: kept as ``self.threshold``.
         value: kept as ``self.value``.
     """
     super().__init__()
     self.threshold, self.value = threshold, value
     # Primitive operator instances, created once at construction time.
     self.greater, self.fill, self.select = ops.Greater(), ops.Fill(), ops.Select()
Example #3
0
    def construct(self, lr, hr, width_mult, tea_width_mult):
        """Run one training step and return the loss.

        Args:
            lr, hr, width_mult, tea_width_mult: forwarded unchanged to
                ``self.network`` and to its gradient function.

        Returns:
            The loss from ``self.network``, tied via ``ops.depend`` to the
            optimizer update.
        """
        weights = self.weights
        loss = self.network(lr, hr, width_mult, tea_width_mult)

        # Backward-pass seed: same dtype/shape as ``loss``, filled with self.sens.
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(lr, hr, width_mult, tea_width_mult, sens)
        # Tie the optimizer result to the returned loss (standard MindSpore
        # train-step idiom). Graph compilation in MindSpore builds without
        # automatic side-effect tracking may otherwise drop the unused update.
        return ops.depend(loss, self.optimizer(grads))
Example #4
0
 def __init__(self, padding: Union[int, Tuple[int, int]], value):
     """Record the padding amounts and fill value; build the primitives used.

     Args:
         padding: a single int, expanded to ``(padding, padding)``, or a
             pair that is stored as given.
         value: kept as ``self.value``.
     """
     super().__init__()
     self.padding = (padding, padding) if isinstance(padding, int) else padding
     self.value = value
     # Concatenation along the last axis, plus a Fill primitive.
     self.concat = ops.Concat(-1)
     self.fill = ops.Fill()
Example #5
0
    def construct(self, realA, realB):
        """Run one discriminator update followed by one generator update.

        Both losses are computed up front on ``(realA, realB)``. The
        discriminator gradients (w.r.t. ``self.weights_D``) go to
        ``self.optimizerD``. The generator gradients (w.r.t.
        ``self.weights_G``) go to ``self.optimizerG``.

        Returns:
            ``(d_loss, g_loss)``, each tied via ``ops.depend`` to its
            optimizer call.
        """
        loss_d = self.loss_netD(realA, realB)
        loss_g = self.loss_netG(realA, realB)

        # Discriminator update.
        seed_d = ops.Fill()(ops.DType()(loss_d), ops.Shape()(loss_d), self.sens)
        grads_d = self.grad(self.loss_netD, self.weights_D)(
            realA, realB, seed_d)
        out_d = ops.depend(loss_d, self.optimizerD(grads_d))

        # Generator update.
        seed_g = ops.Fill()(ops.DType()(loss_g), ops.Shape()(loss_g), self.sens)
        grads_g = self.grad(self.loss_netG, self.weights_G)(
            realA, realB, seed_g)
        out_g = ops.depend(loss_g, self.optimizerG(grads_g))

        return out_d, out_g
Example #6
0
 def construct(self, img_A, img_B, fake_A, fake_B):
     """Compute the ``self.D`` loss and apply one optimizer step for it.

     Returns:
         The loss, tied via ``ops.depend`` to the optimizer update.
     """
     d_loss = self.D(img_A, img_B, fake_A, fake_B)
     d_sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
     d_grads = self.grad(self.D, self.weights)(
         img_A, img_B, fake_A, fake_B, d_sens)
     if self.reducer_flag:
         # Optional gradient reduction (presumably distributed all-reduce).
         d_grads = self.grad_reducer(d_grads)
     update = self.optimizer(d_grads)
     return ops.depend(d_loss, update)
Example #7
0
    def construct(self, img_A, img_B):
        """Run ``self.G`` forward, then one optimizer step on ``self.net``.

        Returns:
            The nine outputs of ``self.G`` in their original order. The
            third (``lg``) is tied via ``ops.depend`` to the optimizer update.
        """
        fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B)
        sens_g = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
        grads_g = self.grad(self.net, self.weights)(img_A, img_B, sens_g)
        if self.reducer_flag:
            # Optional gradient reduction (presumably distributed all-reduce).
            grads_g = self.grad_reducer(grads_g)
        lg_out = ops.depend(lg, self.optimizer(grads_g))
        return fake_A, fake_B, lg_out, lga, lgb, lca, lcb, lia, lib
Example #8
0
 def __init__(self, alpha=2, beta=4):
     """Store the focal-loss exponents and create the primitives used.

     Args:
         alpha: kept as ``self.alpha`` (default 2).
         beta: kept as ``self.beta`` (default 4).
     """
     # Zero-argument super(): the explicit ``super(FocalLoss, self)`` form is
     # a Python 2 leftover and breaks silently if the class is renamed.
     super().__init__()
     self.alpha = alpha
     self.beta = beta
     # Primitive operator instances, created once at construction time.
     self.pow = ops.Pow()
     self.log = ops.Log()
     self.select = ops.Select()
     self.equal = ops.Equal()
     self.less = ops.Less()
     self.cast = ops.Cast()
     self.fill = ops.Fill()
     self.dtype = ops.DType()
     self.shape = ops.Shape()
     self.reduce_sum = ops.ReduceSum()
Example #9
0
 def construct(self, *inputs):
     """Run forward/backward, then either accumulate gradients or apply them.

     Behaviour depends on ``self.accumulation``:
       * truthy: if ``self.accumulation_steps > 1`` the gradients are mapped
         with ``update_accu_grads`` over ``self.accu_grads``; the optimizer is
         NOT called.
       * falsy: the gradients go through ``self.grad_reducer``,
         ``self.accu_grads`` is mapped with ``reset_accu_grads``, and
         ``self.optimizer`` is called with the reduced gradients.

     Returns:
         The loss, tied via ``ops.depend`` to whichever updates ran.
     """
     weights = self.weights
     loss = self.network(*inputs)
     # Backward-pass seed: same dtype/shape as ``loss``, filled with self.sens.
     sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
     grads = self.grad(self.network, weights)(*inputs, sens)
     if self.accumulation and self.accumulation_steps > 1:
         # Presumably adds ``grads`` into ``self.accu_grads`` in place
         # (helper defined elsewhere -- verify).
         accu_succ = self.hyper_map(update_accu_grads, self.accu_grads,
                                    grads)
         loss = ops.depend(loss, accu_succ)
     if self.accumulation:
         # Accumulation step: no optimizer update this call.
         succ = False
     else:
         # NOTE(review): the optimizer receives only this step's ``grads``;
         # ``self.accu_grads`` is reset below but never fed into the update.
         # If the intent is to apply the accumulated sum, this looks wrong --
         # confirm against update_accu_grads/reset_accu_grads and whatever
         # toggles ``self.accumulation``.
         grads = self.grad_reducer(grads)
         # Make the reset depend on the reduced ``grads``.
         accu_grads = ops.depend(self.accu_grads, grads)
         # Presumably zeroes the accumulators (helper defined elsewhere -- verify).
         accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
         loss = ops.depend(loss, accu_succ)
         succ = self.optimizer(grads)
     return ops.depend(loss, succ)