def train(dataloader, net, opt):
    logger.info("***** Running training *****")
    logger.info("batch size = %d", args.train_batch_size)
    sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
    gm = GradManager().attach(net.parameters())
    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
        input_ids, input_mask, segment_ids, label_ids = tuple(
            mge.tensor(t) for t in batch
        )
        batch_size = input_ids.shape[0]
        # forward and backward both happen inside net_train, via the GradManager
        loss, logits, label_ids = net_train(
            input_ids, segment_ids, input_mask, label_ids, gm=gm, net=net
        )
        opt.step().clear_grad()
        sum_loss += loss.mean().item()
        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
        total_examples += batch_size
        total_steps += 1

    result = {
        "train_loss": sum_loss / total_steps,
        "train_accuracy": sum_accuracy / total_examples,
    }
    logger.info("***** Train results *****")
    for key in sorted(result.keys()):
        logger.info("%s = %s", key, str(result[key]))
def eval(dataloader, net):
    logger.info("***** Running evaluation *****")
    logger.info("batch size = %d", args.eval_batch_size)
    sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0
    for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
        input_ids, input_mask, segment_ids, label_ids = tuple(
            mge.tensor(t) for t in batch
        )
        batch_size = input_ids.shape[0]
        # skip the final incomplete batch
        if batch_size != args.eval_batch_size:
            break
        loss, logits = net_eval(input_ids, segment_ids, input_mask, label_ids, net=net)
        sum_loss += loss.mean().item()
        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
        total_examples += batch_size
        total_steps += 1

    result = {
        "eval_loss": sum_loss / total_steps,
        "eval_accuracy": sum_accuracy / total_examples,
    }
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("%s = %s", key, str(result[key]))
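The two loops above delegate the forward and backward passes to net_train and net_eval, which are defined elsewhere in the script. A minimal sketch of what such helpers might look like, assuming the net returns (logits, loss) for the given inputs and labels; the actual definitions and return orders in the original script may differ:

# Hedged sketch: net_train / net_eval as assumed by the loops above.
def net_train(input_ids, segment_ids, input_mask, label_ids, gm=None, net=None):
    with gm:  # record the forward pass so gm.backward can fill in gradients
        logits, loss = net(input_ids, segment_ids, input_mask, label_ids)
        gm.backward(loss)
    return loss, logits, label_ids

def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
    logits, loss = net(input_ids, segment_ids, input_mask, label_ids)
    return loss, logits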
def train_step(image, label):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        acc1, acc5 = F.topk_accuracy(logits, label, topk=(1, 5))
        gm.backward(loss)
        opt.step().clear_grad()
    return loss, acc1, acc5
def train_func(image, label):
    with gm:
        model.train()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        gm.backward(loss)
        optimizer.step().clear_grad()
    return loss, acc1, acc5
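Both step functions rely on a GradManager and an optimizer created at module scope. A minimal setup sketch, assuming model is any megengine.module.Module instance already built; the hyperparameter values here are placeholders:

import megengine.autodiff as autodiff
import megengine.optimizer as optim

# `model` is assumed to exist already; attach its parameters for autodiff
gm = autodiff.GradManager().attach(model.parameters())
# train_step above refers to `opt`, train_func to `optimizer`
opt = optimizer = optim.SGD(
    model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
)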
def valid_func(image, label):
    model.eval()
    logits = model(image)
    loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
    acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
    if dist.is_distributed():  # all_reduce_mean
        loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
        acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
        acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
    return loss, acc1, acc5
def valid_step(image, label):
    logits = model(image)
    loss = F.nn.cross_entropy(logits, label)
    acc1, acc5 = F.topk_accuracy(logits, label, topk=(1, 5))
    # calculate mean values
    if world_size > 1:
        loss = F.distributed.all_reduce_sum(loss) / world_size
        acc1 = F.distributed.all_reduce_sum(acc1) / world_size
        acc5 = F.distributed.all_reduce_sum(acc5) / world_size
    return loss, acc1, acc5
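Both validation functions average their metrics across workers with all_reduce_sum divided by the world size. A self-contained sketch of that pattern under MegEngine's dist.launcher, which spawns one process per available GPU; the worker body here is illustrative only:

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F

@dist.launcher  # one process per GPU; collectives become available inside
def worker():
    world_size = dist.get_world_size()
    # every rank contributes a local value; the sum over ranks divided by
    # world_size gives the mean, exactly as in the snippets above
    local_value = mge.tensor(float(dist.get_rank()))
    mean_value = F.distributed.all_reduce_sum(local_value) / world_size
    print(dist.get_rank(), mean_value.item())

worker()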
def forward(self, embedding, target):
    origin_logits = self.fc(embedding)
    one_hot_target = F.one_hot(target, self.num_class).astype("bool")
    # additive angular margin: replace cos(theta) with cos(theta + m)
    large_margined_logit = F.cos(F.acos(origin_logits) + self.margin)
    small_margined_logit = origin_logits
    # easy-margin variant: apply the margin only where cos(theta) >= 0,
    # so the adjusted logit is always a penalty
    margined_logit = F.where(
        origin_logits >= 0, large_margined_logit, small_margined_logit
    )
    logits = F.where(one_hot_target, margined_logit, origin_logits)
    logits = logits * self.scale
    loss = F.loss.cross_entropy(logits, target)
    accuracy = F.topk_accuracy(origin_logits, target, topk=1)
    return loss, accuracy
def forward(self, embedding, target):
    origin_logits = self.fc(embedding)
    one_hot_target = F.one_hot(target, self.num_class)
    # get how much to decrease
    delta_one_hot_target = one_hot_target * self.margin
    # apply the decrease
    logits = origin_logits - delta_one_hot_target
    logits = logits * self.scale
    loss = F.loss.cross_entropy(logits, target)
    accuracy = F.topk_accuracy(origin_logits, target, topk=1)
    return loss, accuracy
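Both forward variants assume a module whose self.fc produces cosine logits: embeddings and class weights are L2-normalized, so origin_logits lies in [-1, 1] and F.acos is well-defined. A minimal sketch of such a module; all names and the margin/scale defaults are assumptions in the spirit of ArcFace/CosFace, not the original definition:

import numpy as np
import megengine as mge
import megengine.module as M
import megengine.functional as F

class NormLinear(M.Module):
    """Hypothetical cosine 'fc': normalized input times normalized weight."""
    def __init__(self, feature_dim, num_class):
        super().__init__()
        init = np.random.normal(0, 0.01, (num_class, feature_dim)).astype("float32")
        self.weight = mge.Parameter(init)

    def forward(self, embedding):
        w = self.weight / (F.sqrt(F.sum(self.weight ** 2, axis=1, keepdims=True)) + 1e-12)
        x = embedding / (F.sqrt(F.sum(embedding ** 2, axis=1, keepdims=True)) + 1e-12)
        return F.matmul(x, F.transpose(w, (1, 0)))  # cos(theta), in [-1, 1]

class MarginProduct(M.Module):
    """Hypothetical container providing the attributes used by forward() above."""
    def __init__(self, feature_dim, num_class, margin=0.5, scale=64.0):
        super().__init__()
        self.num_class = num_class
        self.margin = margin
        self.scale = scale
        self.fc = NormLinear(feature_dim, num_class)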
def valid_func(image, label):
    model.eval()
    logits = model(image)
    loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
    acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
    return loss, acc1, acc5
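Every snippet above ultimately reduces to the same call. A minimal, self-contained example of F.topk_accuracy by itself, on random data:

import numpy as np
import megengine as mge
import megengine.functional as F

logits = mge.tensor(np.random.rand(8, 10).astype("float32"))  # (batch, num_classes)
label = mge.tensor(np.random.randint(10, size=(8,)).astype("int32"))

acc1 = F.topk_accuracy(logits, label, topk=1)             # single k -> one tensor
acc1, acc5 = F.topk_accuracy(logits, label, topk=(1, 5))  # tuple of ks -> tuple
print(acc1.item(), acc5.item())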