def train():
    model, criteria_x, criteria_u = set_model()
    n_iters_per_epoch = n_imgs_per_epoch // batchsize
    dltrain_x, dltrain_u = get_train_loader(
        batchsize, n_iters_per_epoch, L=250, K=n_guesses
    )
    lb_guessor = LabelGuessor(model, T=temperature)
    mixuper = MixUp(mixup_alpha)
    ema = EMA(model, ema_alpha)

    optim = torch.optim.Adam(model.parameters(), lr=lr)

    lam_u_epoch = float(lam_u) / n_epoches
    lam_u_once = lam_u_epoch / n_iters_per_epoch

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        ema=ema,
        wd=1 - weight_decay * lr,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lb_guessor=lb_guessor,
        mixuper=mixuper,
        lambda_u=0,
        lambda_u_once=lam_u_once,
    )
    best_acc = -1
    print('start to train')
    for e in range(n_epoches):
        model.train()
        print('epoch: {}'.format(e))
        # ramp the unlabeled-loss weight up linearly over the epochs
        train_args['lambda_u'] = e * lam_u_epoch
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        best_acc = acc if best_acc < acc else best_acc
        log_msg = [
            'epoch: {}'.format(e),
            'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc)]
        print(', '.join(log_msg))
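# NOTE: EMA(model, ema_alpha) is imported from elsewhere in the repo and is not shown
# in this file. The class below is only a minimal sketch of the interface the training
# loops here rely on (update_params during training, apply_shadow/restore around
# evaluation); the name EMASketch is hypothetical and the repo's implementation may
# track buffers and non-float state differently.
import torch


class EMASketch:
    """Keeps an exponential moving average (shadow copy) of the model parameters."""

    def __init__(self, model, alpha):
        self.model = model
        self.alpha = alpha
        # shadow copies of every entry in the state dict
        self.shadow = {k: v.clone().detach()
                       for k, v in model.state_dict().items()}
        self.backup = {}

    def update_params(self):
        # shadow = alpha * shadow + (1 - alpha) * current value
        with torch.no_grad():
            for name, value in self.model.state_dict().items():
                if value.dtype.is_floating_point:
                    self.shadow[name].mul_(self.alpha).add_(value, alpha=1 - self.alpha)
                else:
                    self.shadow[name].copy_(value)

    def apply_shadow(self):
        # swap the EMA weights in for evaluation, remembering the raw weights
        self.backup = {k: v.clone() for k, v in self.model.state_dict().items()}
        self.model.load_state_dict(self.shadow)

    def restore(self):
        # put the raw training weights back after evaluation
        self.model.load_state_dict(self.backup)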
def main():
    parser = argparse.ArgumentParser(description='FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='dataset name')
    # parser.add_argument('--n-classes', type=int, default=100,
    #                     help='number of classes in dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epochs')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negative')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()
    # args.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # global settings
    # torch.multiprocessing.set_sharing_strategy('file_system')
    if args.seed > 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        # torch.backends.cudnn.deterministic = True

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
    n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * 1024

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    # logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u = get_train_loader(args.dataset,
                                            args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    # exclude batch-norm parameters (and the classifier bias) from weight decay
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name:
            non_wd_params.append(param)  # bn.weight, bn.bias and classifier.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim,
                                        max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim,
                      lr_schdlr=lr_schdlr,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      n_iters=n_iters_per_epoch,
                      logger=logger)
    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    for epoch in range(args.n_epoches):
        train_loss, loss_x, loss_u, loss_u_real, loss_simclr, mask_mean = train_one_epoch(
            epoch, **train_args)
        # torch.cuda.empty_cache()

        top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

        writer.add_scalars('train/1.loss', {
            'train': train_loss,
            'test': valid_loss
        }, epoch)
        writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
        writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
        writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
        writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
        writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
        writer.add_scalars('test/1.test_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)
        # writer.add_scalar('test/2.test_loss', loss, epoch)

        # best_acc = top1 if best_acc < top1 else best_acc
        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch {}"
            .format(epoch, top1, top5, best_acc, best_epoch))

    writer.close()
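# WarmupCosineLrScheduler is imported from the repo's own scheduler module. The helper
# below is a minimal sketch of an equivalent warmup + cosine-decay schedule with the
# same constructor arguments (optimizer, max_iter, warmup_iter), built on
# torch.optim.lr_scheduler.LambdaLR and stepped once per optimizer step; the repo's
# class may differ in warmup shape and final learning rate. The name
# warmup_cosine_lr_sketch is hypothetical.
import math
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine_lr_sketch(optimizer, max_iter, warmup_iter=0):
    def lr_lambda(step):
        if step < warmup_iter:
            # linear warmup from 0 to the base learning rate
            return step / max(1, warmup_iter)
        # cosine decay from the base learning rate down to 0 over the remaining steps
        progress = (step - warmup_iter) / max(1, max_iter - warmup_iter)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)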
def train():
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs  # / args.mu_c
    epsilon = 0.000001

    model, criteria_x, criteria_u = set_model()
    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if len(param.size()) == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
                            momentum=args.momentum, nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all, warmup_iter=0)

    dltrain_x, dltrain_u, dltrain_all = get_train_loader(
        args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
        L=args.n_labeled, seed=args.seed)

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        dltrain_all=dltrain_all,
        lb_guessor=lb_guessor,
    )
    n_labeled = int(args.n_labeled / args.n_classes)
    best_acc, top1 = -1, -1
    results = {'top 1 acc': [], 'best_acc': []}

    b_schedule = [args.n_epochs / 2, 3 * args.n_epochs / 4]
    if args.boot_schedule == 1:
        step = int(args.n_epochs / 3)
        b_schedule = [step, 2 * step]
    elif args.boot_schedule == 2:
        step = int(args.n_epochs / 4)
        b_schedule = [step, 2 * step, 3 * step]

    for e in range(args.n_epochs):
        if args.bootstrap > 1 and (e in b_schedule):
            seed = 99
            n_labeled *= args.bootstrap
            name = sort_unlabeled(ema, n_labeled)
            print("Bootstrap at epoch ", e, " Name = ", name)
            dltrain_x, dltrain_u, dltrain_all = get_train_loader(
                args.batchsize, args.mu, args.mu_c, n_iters_per_epoch,
                L=10 * n_labeled, seed=seed, name=name)
            train_args = dict(
                model=model,
                criteria_x=criteria_x,
                criteria_u=criteria_u,
                optim=optim,
                lr_schdlr=lr_schdlr,
                ema=ema,
                dltrain_x=dltrain_x,
                dltrain_u=dltrain_u,
                dltrain_all=dltrain_all,
                lb_guessor=lb_guessor,
            )

        model.train()
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        if args.test == 0 or args.lam_clr < epsilon:
            top1 = evaluate(ema) * 100
        elif args.test == 1:
            memory_data = utils.CIFAR10Pair(root='dataset', train=True,
                                            transform=utils.test_transform, download=False)
            memory_data_loader = DataLoader(memory_data, batch_size=args.batchsize,
                                            shuffle=False, num_workers=16, pin_memory=True)
            test_data = utils.CIFAR10Pair(root='dataset', train=False,
                                          transform=utils.test_transform, download=False)
            test_data_loader = DataLoader(test_data, batch_size=args.batchsize,
                                          shuffle=False, num_workers=16, pin_memory=True)
            c = len(memory_data.classes)  # 10
            top1 = test(model, memory_data_loader, test_data_loader, c, e)

        best_acc = top1 if best_acc < top1 else best_acc
        results['top 1 acc'].append('{:.4f}'.format(top1))
        results['best_acc'].append('{:.4f}'.format(best_acc))
        data_frame = pd.DataFrame(data=results)
        data_frame.to_csv(result_dir + '/' + save_name_pre + '.accuracy.csv',
                          index_label='epoch')

        log_msg = [
            'epoch: {}'.format(e + 1),
            'top 1 acc: {:.4f}'.format(top1),
            'best_acc: {:.4f}'.format(best_acc)]
        print(', '.join(log_msg))
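# LabelGuessor(thresh=args.thr) is defined elsewhere in the repo. The core FixMatch
# pseudo-labelling step it implements looks roughly like the sketch below: predict on
# the weakly augmented unlabeled batch, keep only predictions whose maximum softmax
# probability exceeds the threshold, and return the hard labels plus a validity mask.
# The name guess_labels_sketch is hypothetical; it assumes the model returns class
# logits for a single batch, whereas the repo's class also manages train/eval mode and
# EMA-related details not shown here.
import torch


@torch.no_grad()
def guess_labels_sketch(model, ims_weak, thresh=0.95):
    was_training = model.training
    model.eval()
    probs = torch.softmax(model(ims_weak), dim=1)
    scores, lbs = torch.max(probs, dim=1)
    valid = scores >= thresh  # boolean mask of confident samples
    if was_training:
        model.train()
    return lbs[valid], valid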
def main():
    parser = argparse.ArgumentParser(description='FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='dataset name')
    # parser.add_argument('--n-classes', type=int, default=100,
    #                     help='number of classes in dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epochs')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--lam-s', type=float, default=0.2,
                        help='coefficient of the SimCLR loss on unlabeled data')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negative')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()
    # args.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # global settings
    # torch.multiprocessing.set_sharing_strategy('file_system')
    if args.seed > 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        # torch.backends.cudnn.deterministic = True

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
    n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * 1024

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Num Epochs = {args.n_epoches}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    # logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u, dltrain_f = get_train_loader_mix(args.dataset,
                                                           args.batchsize,
                                                           args.mu,
                                                           n_iters_per_epoch,
                                                           L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    # exclude batch-norm parameters (and the classifier bias) from weight decay
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name:
            non_wd_params.append(param)  # bn.weight, bn.bias and classifier.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim_fix = torch.optim.SGD(param_list,
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum,
                                nesterov=True)
    lr_schdlr_fix = WarmupCosineLrScheduler(optim_fix,
                                            max_iter=n_iters_all,
                                            warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim_fix,
                      lr_schdlr=lr_schdlr_fix,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      dltrain_f=dltrain_f,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      lambda_s=args.lam_s,
                      n_iters=n_iters_per_epoch,
                      logger=logger,
                      bt=args.batchsize,
                      mu=args.mu)

    # # TRAINING PARAMETERS FOR SIMCLR
    # param_list = [
    #     {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    #
    # optim_simclr = torch.optim.SGD(param_list, lr=0.5, weight_decay=args.weight_decay,
    #                                momentum=args.momentum, nesterov=False)
    #
    # lr_schdlr_simclr = WarmupCosineLrScheduler(
    #     optim_simclr, max_iter=n_iters_all, warmup_iter=0
    # )
    #
    # train_args_simclr = dict(
    #     model=model,
    #     criteria_z=criteria_z,
    #     optim=optim_simclr,
    #     lr_schdlr=lr_schdlr_simclr,
    #     ema=ema,
    #     dltrain_f=dltrain_f,
    #     lambda_s=args.lam_s,
    #     n_iters=n_iters_per_epoch,
    #     logger=logger,
    #     bt=args.batchsize,
    #     mu=args.mu
    # )

    # # TRAINING PARAMETERS FOR IIC
    # param_list = [
    #     {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    #
    # optim_iic = torch.optim.Adam(param_list, lr=1e-4, weight_decay=args.weight_decay)
    #
    # lr_schdlr_iic = WarmupCosineLrScheduler(
    #     optim_iic, max_iter=n_iters_all, warmup_iter=0
    # )
    #
    # train_args_iic = dict(
    #     model=model,
    #     optim=optim_iic,
    #     lr_schdlr=lr_schdlr_iic,
    #     ema=ema,
    #     dltrain_f=dltrain_f,
    #     n_iters=n_iters_per_epoch,
    #     logger=logger,
    #     bt=args.batchsize,
    #     mu=args.mu
    # )

    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    for epoch in range(args.n_epoches):
        # log the accuracy of the pretrained model up to the h space (backbone output)
        top1, top5, valid_loss = evaluate_linear_Clf(ema, dltrain_x, dlval, criteria_x)
        writer.add_scalars('test/1.test_linear_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)
        logger.info("Epoch {}. on h space Top1: {:.4f}. Top5: {:.4f}.".format(
            epoch, top1, top5))

        # epoch < -500 is never true, so the unsupervised pre-training phase below stays disabled
        if epoch < -500:
            # # UNSUPERVISED TRAINING PHASE
            # train the SimCLR feature representation
            # train_loss, loss_simclr, model_ = train_one_epoch_simclr(epoch, **train_args_simclr)
            # writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)

            # train IIC
            # train_loss, loss_iic, model_ = train_one_epoch_iic(epoch, **train_args_iic)
            # writer.add_scalar('train/4.train_loss_iic', loss_iic, epoch)

            # evaluate_Clf(model_, dltrain_f, dlval, criteria_x)
            top1, top5, valid_loss = evaluate_linear_Clf(
                ema, dltrain_x, dlval, criteria_x)

            # # SAVE THE MODEL TRAINED IN THE UNSUPERVISED PHASE
            # if epoch == 497:
            #     # save model
            #     name = 'simclr_trained_good_h2.pt'
            #     torch.save(model_.state_dict(), name)
            #     logger.info('model saved')
        else:
            # SEMI-SUPERVISED TRAINING
            train_loss, loss_x, loss_u, loss_u_real, mask_mean, loss_simclr = train_one_epoch(
                epoch, **train_args)

            top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

            writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
            writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
            writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
            writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
            writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
            writer.add_scalars('train/1.loss', {
                'train': train_loss,
                'test': valid_loss
            }, epoch)
            writer.add_scalars('test/1.test_acc', {
                'top1': top1,
                'top5': top5
            }, epoch)
            # writer.add_scalar('test/2.test_loss', loss, epoch)

        # best_acc = top1 if best_acc < top1 else best_acc
        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch {}"
            .format(epoch, top1, top5, best_acc, best_epoch))

    writer.close()
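# set_model(args) returns criteria_z, the contrastive criterion used for the SimCLR
# branch above. The function below is a minimal sketch of a standard NT-Xent
# (normalized temperature-scaled cross entropy) loss over two batches of projections;
# the name nt_xent_sketch is hypothetical and the repo's criterion may differ in
# masking, reduction, or how the temperature is applied.
import torch
import torch.nn.functional as F


def nt_xent_sketch(z1, z2, temperature=0.5):
    """NT-Xent loss for a batch of positive pairs (z1[i], z2[i])."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2n, d), unit-norm rows
    sim = torch.mm(z, z.t()) / temperature                # (2n, 2n) similarity logits
    sim.fill_diagonal_(float('-inf'))                     # drop self-similarity
    # the positive of row i (< n) is row i + n, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)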
class Trainer():
    def __init__(self, pwd):
        self.model = WideResnet(cfg.n_classes,
                                k=cfg.wresnet_k,
                                n=cfg.wresnet_n,
                                batchsize=cfg.batch_size)
        self.model = self.model.cuda()

        self.unlabeled_trainloader, self.val_loader = loading_data()
        self.labeled_trainloader = update_loading(epoch=0)

        wd_params, non_wd_params = [], []
        for param in self.model.parameters():
            if len(param.size()) == 1:
                non_wd_params.append(param)
            else:
                wd_params.append(param)
        param_list = [{'params': wd_params},
                      {'params': non_wd_params, 'weight_decay': 0}]
        self.optimizer = torch.optim.SGD(param_list,
                                         lr=cfg.lr,
                                         weight_decay=cfg.weight_decay,
                                         momentum=cfg.momentum,
                                         nesterov=True)
        self.n_iters_per_epoch = cfg.n_imgs_per_epoch // cfg.batch_size
        self.lr_schdlr = WarmupCosineLrScheduler(
            self.optimizer,
            max_iter=self.n_iters_per_epoch * cfg.n_epoches,
            warmup_iter=0)
        self.lb_guessor = LabelGuessor(args=cfg)

        self.train_record = {
            'best_acc1': 0,
            'best_model_name': '',
            'last_model_name': ''
        }
        self.cross_entropy = nn.CrossEntropyLoss().cuda()
        self.i_tb = 0
        self.epoch = 0
        self.exp_name = cfg.exp_name
        self.exp_path = cfg.exp_path

        if cfg.resume:
            print('Loaded resume weights for WideResnet')
            latest_state = torch.load(cfg.resume_model)
            self.model.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.lr_schdlr.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.ema = EMA(self.model, cfg.ema_alpha)
        self.writer, self.log_txt = logger(
            cfg.exp_path, cfg.exp_name, pwd,
            ['exp', 'dataset', 'pretrained', 'pre_trained'])

    def forward(self):
        print('start to train')
        for epoch in range(self.epoch, cfg.n_epoches):
            self.epoch = epoch
            print(('=' * 50 + 'epoch: {}' + '=' * 50).format(self.epoch + 1))
            self.train()
            torch.cuda.empty_cache()
            self.evaluate(self.ema)

    def train(self):
        if self.epoch > cfg.start_add_samples_epoch:
            indexs, pre_targets = self.lb_guessor.label_generator.new_data()
            self.labeled_trainloader = update_loading(
                copy.deepcopy(indexs), copy.deepcopy(pre_targets), self.epoch)
            self.lb_guessor.init_for_add_sample(self.epoch,
                                                cfg.start_add_samples_epoch)

        self.model.train()
        Loss, Loss_L, Loss_U, Loss_U_Real, Loss_MI = \
            AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
        Correct_Num, Valid_Num = AverageMeter(), AverageMeter()
        st = time.time()

        l_set, u_set = iter(self.labeled_trainloader), iter(self.unlabeled_trainloader)
        for it in range(self.n_iters_per_epoch):
            (img, img_l_weak, img_l_strong), lbs_l = next(l_set)
            (img_u, img_u_weak, img_u_strong), lbs_u_real, index_u = next(u_set)
            img_l_weak, img_l_strong, lbs_l = \
                img_l_weak.cuda(), img_l_strong.cuda(), lbs_l.cuda()
            img_u, img_u_weak, img_u_strong = \
                img_u.cuda(), img_u_weak.cuda(), img_u_strong.cuda()

            lbs_u, valid_u = self.lb_guessor(self.model, img_l_weak, img_u_weak,
                                             lbs_l, index_u)
            n_u = img_u_strong.size(0)
            img_cat = torch.cat([img_l_weak, img_u, img_u_weak, img_u_strong],
                                dim=0).detach()
            _, __, pred_l, pred_u = self.model(img_cat)
            pred_u_o, pred_u_w, pred_u_s = \
                pred_u[:n_u], pred_u[n_u:2 * n_u], pred_u[2 * n_u:]

            # ===================== cross-entropy loss for labeled data ==============
            loss_l = self.cross_entropy(pred_l, lbs_l)

            # ===================== T-MI loss for unlabeled data ==============
            if self.epoch >= 20:
                T_MI_loss = Triplet_MI_loss(pred_u_o, pred_u_w, pred_u_s)
            else:
                T_MI_loss = torch.tensor(0)

            # ===================== cross-entropy loss for unlabeled data ==============
            if lbs_u.size(0) > 0 and self.epoch >= 2:
                pred_u_s = pred_u_s[valid_u]
                loss_u = self.cross_entropy(pred_u_s, lbs_u)
                with torch.no_grad():
                    lbs_u_real = lbs_u_real[valid_u].cuda()
                    valid_num = lbs_u_real.size(0)
                    corr_lb = (lbs_u_real == lbs_u)
                    loss_u_real = F.cross_entropy(pred_u_s, lbs_u_real)
            else:
                loss_u = torch.tensor(0)
                loss_u_real = torch.tensor(0)
                corr_lb = torch.tensor(0)
                valid_num = 0

            loss = loss_l + cfg.lam_u * loss_u + 0.1 * T_MI_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.ema.update_params()
            self.lr_schdlr.step()

            Loss.update(loss.item())
            Loss_L.update(loss_l.item())
            Loss_U.update(loss_u.item())
            Loss_U_Real.update(loss_u_real.item())
            Loss_MI.update(T_MI_loss.item())
            Correct_Num.update(corr_lb.sum().item())
            Valid_Num.update(valid_num)

            if (it + 1) % 256 == 0:
                self.i_tb += 1
                self.writer.add_scalar('loss_u', Loss_U.avg, self.i_tb)
                self.writer.add_scalar('loss_MI', Loss_MI.avg, self.i_tb)
                ed = time.time()
                t = ed - st
                lr_log = [pg['lr'] for pg in self.optimizer.param_groups]
                lr_log = sum(lr_log) / len(lr_log)
                msg = ', '.join([
                    ' [iter: {}',
                    'loss: {:.3f}',
                    'loss_l: {:.4f}',
                    'loss_u: {:.4f}',
                    'loss_u_real: {:.4f}',
                    'loss_MI: {:.4f}',
                    'correct: {}/{}',
                    'lr: {:.4f}',
                    'time: {:.2f}]',
                ]).format(it + 1, Loss.avg, Loss_L.avg, Loss_U.avg,
                          Loss_U_Real.avg, Loss_MI.avg, int(Correct_Num.avg),
                          int(Valid_Num.avg), lr_log, t)
                st = ed
                print(msg)

        self.ema.update_buffer()
        self.writer.add_scalar('acc_overall',
                               Correct_Num.sum / (cfg.n_imgs_per_epoch * cfg.mu),
                               self.epoch + 1)
        self.writer.add_scalar('acc_in_labeled',
                               Correct_Num.sum / (Valid_Num.sum + 1e-10),
                               self.epoch + 1)

    def evaluate(self, ema):
        ema.apply_shadow()
        ema.model.eval()
        ema.model.cuda()

        matches = []
        for ims, lbs in self.val_loader:
            ims = ims.cuda()
            lbs = lbs.cuda()
            with torch.no_grad():
                __, preds = ema.model(ims, mode='val')
                scores = torch.softmax(preds, dim=1)
                _, preds = torch.max(scores, dim=1)
                match = lbs == preds
                matches.append(match)
        matches = torch.cat(matches, dim=0).float()
        acc = torch.mean(matches)

        self.writer.add_scalar('val_acc', acc, self.epoch)
        self.train_record = update_model(ema.model, self.optimizer,
                                         self.lr_schdlr, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name, acc,
                                         self.train_record)
        print_summary(cfg.exp_name, acc, self.train_record)
        ema.restore()
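# Trainer.train() accumulates its running statistics through AverageMeter objects and
# only relies on .update(value), .avg and .sum. Below is a minimal sketch of the
# conventional count-weighted helper; the name AverageMeterSketch is hypothetical and
# the repo's own class may carry extra bookkeeping (reset, last value, etc.).
class AverageMeterSketch:
    """Tracks the running sum, count and mean of a scalar statistic."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # accumulate n observations with value val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count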
def train():
    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize
    n_iters_all = n_iters_per_epoch * args.n_epochs

    model, criteria_x, criteria_u = set_model()

    dltrain_x, dltrain_u = get_train_loader(args.batchsize, args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled, seed=args.seed)
    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    wd_params, non_wd_params = [], []
    for param in model.parameters():
        if len(param.size()) == 1:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list, lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum, nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all,
                                        warmup_iter=0)

    train_args = dict(
        model=model,
        criteria_x=criteria_x,
        criteria_u=criteria_u,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lb_guessor=lb_guessor,
        lambda_u=args.lam_u,
        lambda_c=args.lam_c,
        n_iters=n_iters_per_epoch,
    )
    best_acc = -1
    print('start to train')
    for e in range(args.n_epochs):
        model.train()
        print('epoch: {}'.format(e + 1))
        train_one_epoch(**train_args)
        torch.cuda.empty_cache()

        acc = evaluate(ema)
        best_acc = acc if best_acc < acc else best_acc
        log_msg = [
            'epoch: {}'.format(e + 1),
            'acc: {:.4f}'.format(acc),
            'best_acc: {:.4f}'.format(best_acc)
        ]
        print(', '.join(log_msg))

    sort_unlabeled(ema)
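# Two conventions for excluding parameters from weight decay appear in this file:
# matching parameter names against 'bn' (named_parameters loop) and checking for 1-d
# tensors via len(param.size()) == 1. The helper below is a small sketch of the 1-d
# convention packaged as a reusable function; the name split_wd_params_sketch is
# hypothetical. Biases and norm weights/biases are all 1-d, so both conventions
# usually select a similar set of parameters.
def split_wd_params_sketch(model):
    wd_params, non_wd_params = [], []
    for param in model.parameters():
        # 1-d tensors (biases, bn weight/bias) are exempt from weight decay
        (non_wd_params if param.dim() == 1 else wd_params).append(param)
    return [{'params': wd_params},
            {'params': non_wd_params, 'weight_decay': 0}]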