def consistency_loss(logits_w, logits_s, name="ce", T=1.0, p_cutoff=0.0, use_hard_labels=True): assert name in ["ce", "L2"] logits_w = logits_w.detach() if name == "L2": assert logits_w.size() == logits_s.size() return F.mse_loss(logits_s, logits_w, reduction="mean") elif name == "L2_mask": pass elif name == "ce": pseudo_label = torch.softmax(logits_w, dim=-1) max_probs, max_idx = torch.max(pseudo_label, dim=-1) mask = max_probs.ge(p_cutoff).float() if use_hard_labels: masked_loss = ( ce_loss(logits_s, max_idx, use_hard_labels, reduction="none") * mask) else: pseudo_label = torch.softmax(logits_w / T, dim=-1) masked_loss = ce_loss(logits_s, pseudo_label, use_hard_labels) * mask return masked_loss.mean(), mask.mean() else: assert Exception("Not Implemented consistency_loss")
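# `consistency_loss` (and the training loop below) relies on a `ce_loss` helper
# defined elsewhere in the repository. The sketch below is a minimal, assumed
# implementation that matches both call sites: hard integer targets with a
# `reduction` argument, and soft probability targets returning a per-sample loss.
# The repository's actual helper may differ in detail.
def ce_loss(logits, targets, use_hard_labels=True, reduction='none'):
    if use_hard_labels:
        # Hard labels: targets are class indices.
        return F.cross_entropy(logits, targets, reduction=reduction)
    # Soft labels: targets are a probability distribution over classes.
    assert logits.shape == targets.shape
    log_pred = F.log_softmax(logits, dim=-1)
    return torch.sum(-targets * log_pred, dim=-1)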
def train(self, args, logger=None):
    """
    Train function of FixMatch.
    From the data loaders, it runs inference on the training data, computes the losses,
    and updates the networks.
    """
    ngpus_per_node = torch.cuda.device_count()

    # lb: labeled, ulb: unlabeled
    self.train_model.train()

    # CUDA events for GPU profiling of data prefetching and the forward/backward pass.
    start_batch = torch.cuda.Event(enable_timing=True)
    end_batch = torch.cuda.Event(enable_timing=True)
    start_run = torch.cuda.Event(enable_timing=True)
    end_run = torch.cuda.Event(enable_timing=True)

    start_batch.record()
    best_eval_acc, best_it = 0.0, 0

    scaler = GradScaler()
    amp_cm = autocast if args.amp else contextlib.nullcontext

    for (x_lb, y_lb), (x_ulb_w, x_ulb_s, _) in zip(self.loader_dict['train_lb'],
                                                   self.loader_dict['train_ulb']):

        # Prevent the training iterations from exceeding args.num_train_iter.
        if self.it > args.num_train_iter:
            break

        end_batch.record()
        torch.cuda.synchronize()
        start_run.record()

        num_lb = x_lb.shape[0]
        num_ulb = x_ulb_w.shape[0]
        assert num_ulb == x_ulb_s.shape[0]

        x_lb, x_ulb_w, x_ulb_s = x_lb.cuda(args.gpu), x_ulb_w.cuda(args.gpu), x_ulb_s.cuda(args.gpu)
        y_lb = y_lb.cuda(args.gpu)

        # Single forward pass over labeled, weakly augmented, and strongly augmented batches.
        inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))

        # Inference and calculation of sup/unsup losses.
        with amp_cm():
            logits = self.train_model(inputs)
            logits_x_lb = logits[:num_lb]
            logits_x_ulb_w, logits_x_ulb_s = logits[num_lb:].chunk(2)
            del logits

            # Hyper-parameters for this update (possibly scheduled over iterations).
            T = self.t_fn(self.it)
            p_cutoff = self.p_fn(self.it)

            sup_loss = ce_loss(logits_x_lb, y_lb, reduction='mean')
            unsup_loss, mask = consistency_loss(logits_x_ulb_w, logits_x_ulb_s,
                                                'ce', T, p_cutoff,
                                                use_hard_labels=args.hard_label)

            total_loss = sup_loss + self.lambda_u * unsup_loss

        # Parameter updates (with gradient scaling when mixed precision is enabled).
        if args.amp:
            scaler.scale(total_loss).backward()
            scaler.step(self.optimizer)
            scaler.update()
        else:
            total_loss.backward()
            self.optimizer.step()

        self.scheduler.step()
        self.train_model.zero_grad()

        with torch.no_grad():
            self._eval_model_update()

        end_run.record()
        torch.cuda.synchronize()

        # Tensorboard dict update.
        tb_dict = {}
        tb_dict['train/sup_loss'] = sup_loss.detach()
        tb_dict['train/unsup_loss'] = unsup_loss.detach()
        tb_dict['train/total_loss'] = total_loss.detach()
        tb_dict['train/mask_ratio'] = 1.0 - mask.detach()
        tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
        tb_dict['train/prefetch_time'] = start_batch.elapsed_time(end_batch) / 1000.
        tb_dict['train/run_time'] = start_run.elapsed_time(end_run) / 1000.

        if self.it % self.num_eval_iter == 0:
            eval_dict = self.evaluate(args=args)
            tb_dict.update(eval_dict)

            save_path = os.path.join(args.save_dir, args.save_name)

            if tb_dict['eval/top-1-acc'] > best_eval_acc:
                best_eval_acc = tb_dict['eval/top-1-acc']
                best_it = self.it

            self.print_fn(f"{self.it} iteration, USE_EMA: {hasattr(self, 'eval_model')}, "
                          f"{tb_dict}, BEST_EVAL_ACC: {best_eval_acc}, at {best_it} iters")

        # Only the main process saves checkpoints and writes tensorboard logs.
        if not args.multiprocessing_distributed or \
                (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):

            if self.it == best_it:
                self.save_model('model_best.pth', save_path)

            if self.tb_log is not None:
                self.tb_log.update(tb_dict, self.it)

        self.it += 1
        del tb_dict
        start_batch.record()

        if self.it > 2 ** 19:
            self.num_eval_iter = 1000

    eval_dict = self.evaluate(args=args)
    eval_dict.update({'eval/best_acc': best_eval_acc, 'eval/best_it': best_it})
    return eval_dict
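# The loop above calls self._eval_model_update() under torch.no_grad() to keep a
# separate evaluation model in sync with the training model. That method is not
# shown in this section; the sketch below is an assumed exponential-moving-average
# (EMA) update, where `self.eval_model` mirrors `self.train_model` and
# `self.ema_m` is a hypothetical momentum coefficient (e.g. 0.999).
@torch.no_grad()
def _eval_model_update(self):
    for param_train, param_eval in zip(self.train_model.parameters(),
                                       self.eval_model.parameters()):
        # EMA of parameters: eval <- m * eval + (1 - m) * train
        param_eval.copy_(param_eval * self.ema_m + param_train.detach() * (1 - self.ema_m))
    for buffer_train, buffer_eval in zip(self.train_model.buffers(),
                                         self.eval_model.buffers()):
        # Buffers (e.g. BatchNorm running statistics) are copied directly.
        buffer_eval.copy_(buffer_train)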