def construct(self, *inputs):
    """Compute the loss and accumulate the gradients into self.grad_sum."""
    weights = self.weights
    loss = self.network(*inputs)
    # Sensitivity tensor: scales the gradient backpropagated from the loss.
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*inputs, sens)
    # Add the new gradients to the running sums instead of applying an optimizer step.
    return ops.depend(
        loss, self.hyper_map(ops.partial(_sum_op), self.grad_sum, grads))
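# The `_sum_op` hypermap helper referenced above is not shown in this snippet.
# A minimal sketch of one common MindSpore definition (an assumption, not the
# original code): a MultitypeFuncGraph that adds each new gradient into the
# matching accumulation buffer in place.
import mindspore.ops as ops

_sum_op = ops.MultitypeFuncGraph("grad_sum_op")

@_sum_op.register("Tensor", "Tensor")
def _cumulative_grad(grad_sum, grad):
    """Accumulate one parameter's gradient into its running-sum buffer."""
    return ops.AssignAdd()(grad_sum, grad)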
def __init__(self, threshold, value):
    super().__init__()
    self.threshold = threshold
    self.value = value
    self.greater = ops.Greater()
    self.fill = ops.Fill()
    self.select = ops.Select()
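# The ops declared above (Greater, Fill, Select) suggest a threshold-style
# forward pass. A hypothetical construct consistent with that __init__ (an
# illustrative sketch, not the original implementation): keep elements above
# the threshold and replace the rest with the fill value.
def construct(self, x):
    cond = self.greater(x, self.threshold)
    filler = self.fill(ops.DType()(x), ops.Shape()(x), self.value)
    return self.select(cond, x, filler)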
def construct(self, lr, hr, width_mult, tea_width_mult):
    """Run one training step: forward pass, sens-scaled gradients, optimizer update."""
    weights = self.weights
    loss = self.network(lr, hr, width_mult, tea_width_mult)
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(lr, hr, width_mult, tea_width_mult, sens)
    self.optimizer(grads)
    return loss
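# For context: the construct above relies on self.grad, self.sens, self.weights
# and self.optimizer. A hypothetical __init__ following the standard MindSpore
# TrainOneStepCell pattern (names and defaults are assumptions):
def __init__(self, network, optimizer, sens=1.0):
    super().__init__(auto_prefix=False)
    self.network = network
    self.network.set_grad()
    self.weights = optimizer.parameters
    self.optimizer = optimizer
    # Gradient w.r.t. the weight list, with an explicit sensitivity input.
    self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
    self.sens = sens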
def __init__(self, padding: Union[int, Tuple[int, int]], value):
    super().__init__()
    if isinstance(padding, int):
        self.padding = (padding, padding)
    else:
        self.padding = padding
    self.value = value
    self.concat = ops.Concat(-1)
    self.fill = ops.Fill()
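# The Concat(-1)/Fill pair above points to constant padding along the last axis.
# A hypothetical construct for this cell (a sketch under that assumption, not
# the original implementation):
def construct(self, x):
    pad_left, pad_right = self.padding
    shape = ops.Shape()(x)
    dtype = ops.DType()(x)
    # Build constant slabs on each side and join them to the input.
    left = self.fill(dtype, shape[:-1] + (pad_left,), self.value)
    right = self.fill(dtype, shape[:-1] + (pad_right,), self.value)
    return self.concat((left, x, right))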
def construct(self, realA, realB):
    """Define TrainOneStepCell."""
    d_loss = self.loss_netD(realA, realB)
    g_loss = self.loss_netG(realA, realB)
    d_sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
    d_grads = self.grad(self.loss_netD, self.weights_D)(realA, realB, d_sens)
    d_res = ops.depend(d_loss, self.optimizerD(d_grads))
    g_sens = ops.Fill()(ops.DType()(g_loss), ops.Shape()(g_loss), self.sens)
    g_grads = self.grad(self.loss_netG, self.weights_G)(realA, realB, g_sens)
    g_res = ops.depend(g_loss, self.optimizerG(g_grads))
    return d_res, g_res
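# Hypothetical usage of the cell whose construct is shown above (all names here
# are assumptions): one call per batch runs the discriminator and generator
# updates and returns both losses.
train_net = TrainOneStepPix2Pix(loss_netD, loss_netG, optimizerD, optimizerG)
train_net.set_train()
for realA, realB in dataset.create_tuple_iterator():
    d_loss, g_loss = train_net(realA, realB)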
def construct(self, img_A, img_B, fake_A, fake_B):
    weights = self.weights
    ld = self.D(img_A, img_B, fake_A, fake_B)
    sens_d = ops.Fill()(ops.DType()(ld), ops.Shape()(ld), self.sens)
    grads_d = self.grad(self.D, weights)(img_A, img_B, fake_A, fake_B, sens_d)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads_d = self.grad_reducer(grads_d)
    return ops.depend(ld, self.optimizer(grads_d))
def construct(self, img_A, img_B):
    weights = self.weights
    fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B)
    sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
    grads_g = self.grad(self.net, weights)(img_A, img_B, sens)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads_g = self.grad_reducer(grads_g)
    return fake_A, fake_B, ops.depend(
        lg, self.optimizer(grads_g)), lga, lgb, lca, lcb, lia, lib
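# The two constructs above consult self.reducer_flag / self.grad_reducer. A
# hypothetical helper following the usual MindSpore data-parallel pattern
# (an assumption, not taken from these snippets); it would be called from the
# cells' __init__.
from mindspore import context, nn
from mindspore.context import ParallelMode
from mindspore.communication import get_group_size

def _build_grad_reducer(cell, optimizer):
    """Attach reducer_flag / grad_reducer to `cell` for data-parallel training."""
    cell.reducer_flag = False
    cell.grad_reducer = None
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        cell.reducer_flag = True
        mean = context.get_auto_parallel_context("gradients_mean")
        degree = get_group_size()
        cell.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)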
def __init__(self, alpha=2, beta=4):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.beta = beta
    self.pow = ops.Pow()
    self.log = ops.Log()
    self.select = ops.Select()
    self.equal = ops.Equal()
    self.less = ops.Less()
    self.cast = ops.Cast()
    self.fill = ops.Fill()
    self.dtype = ops.DType()
    self.shape = ops.Shape()
    self.reduce_sum = ops.ReduceSum()
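# A hypothetical construct for the FocalLoss cell above, in the CornerNet/
# CenterNet heatmap style that its alpha/beta defaults and declared ops suggest
# (a sketch, not necessarily the original): `out` is the predicted heatmap in
# (0, 1), `target` the Gaussian-splatted ground-truth heatmap.
import mindspore as ms

def construct(self, out, target):
    pos_inds = self.cast(self.equal(target, 1.0), ms.float32)
    neg_inds = self.cast(self.less(target, 1.0), ms.float32)
    neg_weights = self.pow(1.0 - target, self.beta)
    # Positives are down-weighted by (1 - p)^alpha, negatives by p^alpha and by
    # how far the ground truth is from a peak.
    pos_loss = self.log(out) * self.pow(1.0 - out, self.alpha) * pos_inds
    neg_loss = self.log(1.0 - out) * self.pow(out, self.alpha) * neg_weights * neg_inds
    num_pos = self.reduce_sum(pos_inds)
    # Avoid dividing by zero when the batch contains no positive locations.
    num_pos = self.select(self.equal(num_pos, 0.0),
                          self.fill(self.dtype(num_pos), self.shape(num_pos), 1.0),
                          num_pos)
    return -(self.reduce_sum(pos_loss) + self.reduce_sum(neg_loss)) / num_pos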
def construct(self, *inputs):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(*inputs)
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*inputs, sens)
    if self.accumulation and self.accumulation_steps > 1:
        # Accumulation phase: add the new gradients to the buffers, skip the update.
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        loss = ops.depend(loss, accu_succ)
    if self.accumulation:
        succ = False
    else:
        # Update phase: all-reduce the gradients, apply them, then clear the buffers.
        grads = self.grad_reducer(grads)
        accu_grads = ops.depend(self.accu_grads, grads)
        accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
        loss = ops.depend(loss, accu_succ)
        succ = self.optimizer(grads)
    return ops.depend(loss, succ)
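# The hypermap helpers update_accu_grads / reset_accu_grads used above are not
# shown here. A minimal sketch of one common definition (an assumption, not the
# original code): add into the accumulation buffers, and zero them after the
# optimizer step.
import mindspore.ops as ops

update_accu_grads = ops.MultitypeFuncGraph("update_accu_grads")
reset_accu_grads = ops.MultitypeFuncGraph("reset_accu_grads")

@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    """Add the current micro-batch gradient into its accumulation buffer."""
    return ops.AssignAdd()(accu_grad, grad)

@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    """Clear one accumulation buffer once its gradients have been applied."""
    return ops.Assign()(accu_grad, ops.ZerosLike()(accu_grad))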