def construct(self, input_ids, input_mask, token_type_id, label_ids, sens=None):
    """Bert Finetune: one loss-scaled training step with overflow detection.

    Runs ``self.network`` forward, back-propagates with the loss-scale value
    as sensitivity, applies ``grad_scale`` and ``clip_grad`` to the gradients,
    derives an overflow flag and calls ``self.optimizer`` only when no
    overflow is reported.

    Args:
        input_ids, input_mask, token_type_id, label_ids: forwarded unchanged
            to ``self.network`` and to its gradient function.
        sens: optional loss-scale value. When ``None``, ``self.loss_scale`` is
            used and ``self.loss_scaling_manager`` decides whether the update
            is skipped; otherwise the raw ``cond`` decides.

    Returns:
        ``(loss, cond)`` where ``cond = self.base <= flag_sum``, made dependent
        on the (possibly skipped) optimizer step.
    """
    weights = self.weights
    # Placeholder so `init` is bound on every path; only the non-GPU branch
    # assigns the real status buffer.
    init = False
    loss = self.network(input_ids, input_mask, token_type_id, label_ids)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    if not self.gpu_target:
        # `P.depend(a, b)` returns `a` and orders `b` before the operations
        # that consume the returned `a`; ordering must therefore be attached
        # to the value the later operation actually reads.
        init = self.alloc_status()
        init = P.depend(init, loss)  # allocate after the forward pass
        clear_before_grad = self.clear_before_grad(init)
        # Kept from the original; its definition is not visible here.
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        # The gradient call reads scaling_sens, so this forces the clear to
        # happen before back-propagation starts.
        scaling_sens = P.depend(scaling_sens, clear_before_grad)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             label_ids,
                                             self.cast(scaling_sens, ts.float32))
    grads = self.hyper_map(P.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(
        P.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    if not self.gpu_target:
        # Read the status only after all gradients exist, and reduce the
        # buffer only after get_status has filled it.
        init = P.depend(init, grads)
        flag = self.get_status(init)
        init = P.depend(init, flag)
        flag_sum = self.reduce_sum(init, (0,))
    else:
        flag_sum = self.hyper_map(P.partial(_grad_overflow), grads)
        flag_sum = self.addn(flag_sum)
        flag_sum = self.reshape(flag_sum, ())
    cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return P.depend(ret, succ)
def construct(self, img_A, img_B, fake_A, fake_B):
    """Run a single training step for the network held in ``self.D``.

    The loss produced by ``self.D`` is back-propagated with a sensitivity
    tensor that has the loss's dtype and shape and is filled with
    ``self.sens``. The gradients with respect to ``self.weights`` are handed
    to ``self.optimizer``.

    Returns:
        The ``self.D`` loss, tied via ``depend`` to the optimizer update.
    """
    d_loss = self.D(img_A, img_B, fake_A, fake_B)
    # Passed as the trailing sens argument of the gradient function
    # (assumes self.grad was built with sens_param=True — confirm in __init__).
    d_sens = Fill()(DType()(d_loss), Shape()(d_loss), self.sens)
    d_grad_fn = self.grad(self.D, self.weights)
    d_grads = d_grad_fn(img_A, img_B, fake_A, fake_B, d_sens)
    d_update = self.optimizer(d_grads)
    return depend(d_loss, d_update)
def construct(self, img_A, img_B):
    """Run a single training step for the network held in ``self.G``.

    ``self.G`` is run forward to obtain two output images and seven loss
    terms. Gradients with respect to ``self.weights`` are taken through
    ``self.net`` (not ``self.G``), seeded with a sensitivity tensor shaped
    like the third output and filled with ``self.sens``, then applied by
    ``self.optimizer``.

    Returns:
        The nine outputs of ``self.G`` in their original order, with the third
        one tied via ``depend`` to the optimizer update.
    """
    # NOTE(review): outputs come from self.G while gradients come from
    # self.net; presumably self.net wraps self.G and returns only the third
    # output — confirm in __init__.
    (fake_a, fake_b, loss_g, loss_ga, loss_gb,
     loss_ca, loss_cb, loss_ia, loss_ib) = self.G(img_A, img_B)
    g_sens = Fill()(DType()(loss_g), Shape()(loss_g), self.sens)
    g_grad_fn = self.grad(self.net, self.weights)
    g_grads = g_grad_fn(img_A, img_B, g_sens)
    g_update = self.optimizer(g_grads)
    tied_loss_g = depend(loss_g, g_update)
    return (fake_a, fake_b, tied_loss_g, loss_ga, loss_gb,
            loss_ca, loss_cb, loss_ia, loss_ib)
def construct(self, ids, wts, labels):
    """Run one training step of ``self.network`` on ``(ids, wts, labels)``.

    The loss is back-propagated with a sensitivity tensor matching its dtype
    and shape, filled with ``self.sens``. When ``self.reducer_flag`` is set,
    the gradients are passed through ``self.grad_reducer`` before
    ``self.optimizer`` consumes them.

    Returns:
        The loss, tied via ``depend`` to the optimizer update.
    """
    step_loss = self.network(ids, wts, labels)
    loss_sens = Fill()(DType()(step_loss), Shape()(step_loss), self.sens)
    backward = self.grad(self.network, self.weights)
    gradients = backward(ids, wts, labels, loss_sens)
    if self.reducer_flag:
        # Route gradients through the configured reducer before the update.
        gradients = self.grad_reducer(gradients)
    update = self.optimizer(gradients)
    return depend(step_loss, update)
def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible, sens=None):
    """BertSquad: one loss-scaled training step with NPU overflow detection.

    Runs ``self.network`` forward, clears the float-status buffer, then
    back-propagates with the loss-scale value as sensitivity. It applies
    ``grad_scale`` and ``clip_grad``, reads the status buffer to build an
    overflow flag and calls ``self.optimizer`` only when no overflow is
    reported.

    Args:
        input_ids, input_mask, token_type_id, start_position, end_position,
        unique_id, is_impossible: forwarded unchanged to ``self.network`` and
            to its gradient function.
        sens: optional loss-scale value. When ``None``, ``self.loss_scale`` is
            used and ``self.loss_scaling_manager`` decides whether the update
            is skipped; otherwise the raw ``cond`` decides.

    Returns:
        ``(loss, cond)`` where ``cond = self.base <= flag_sum``, made dependent
        on the (possibly skipped) optimizer step.
    """
    weights = self.weights
    loss = self.network(input_ids,
                        input_mask,
                        token_type_id,
                        start_position,
                        end_position,
                        unique_id,
                        is_impossible)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # `P.depend(a, b)` returns `a` and orders `b` before the operations that
    # consume the returned `a`. Its result must be used, and the ordering
    # must be attached to the value the later operation reads.
    init = self.alloc_status()
    init = P.depend(init, loss)  # allocate after the forward pass
    clear_before_grad = self.clear_before_grad(init)
    # Kept from the original; its definition is not visible here.
    self.depend_parameter_use(clear_before_grad, scaling_sens)
    # The gradient call reads scaling_sens, so the buffer is cleared before
    # back-propagation (previously the grads were built before the clear).
    scaling_sens = P.depend(scaling_sens, clear_before_grad)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             start_position,
                                             end_position,
                                             unique_id,
                                             is_impossible,
                                             self.cast(scaling_sens, ts.float32))
    grads = self.hyper_map(P.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(
        P.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    # Read the status only after all gradients exist, and reduce the buffer
    # only after get_status has filled it; cond is derived afterwards.
    init = P.depend(init, grads)
    flag = self.get_status(init)
    init = P.depend(init, flag)
    flag_sum = self.reduce_sum(init, (0,))
    cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return P.depend(ret, succ)
def construct(self, *args):
    """Run one training step of ``self.network`` on the given inputs.

    All positional inputs are forwarded to ``self.network``. The resulting
    loss is back-propagated with a sensitivity tensor of the loss's dtype and
    shape filled with ``self.sens``. The gradients with respect to
    ``self.weights`` are then passed to ``self.optimizer``.

    Returns:
        The loss, tied via ``P.depend`` to the optimizer update.
    """
    net_loss = self.network(*args)
    loss_dtype = P.DType()(net_loss)
    loss_shape = P.Shape()(net_loss)
    loss_sens = P.Fill()(loss_dtype, loss_shape, self.sens)
    grad_fn = self.grad(self.network, self.weights)
    # The sensitivity is appended after the forward inputs.
    grads = grad_fn(*args, loss_sens)
    update = self.optimizer(grads)
    return P.depend(net_loss, update)