def test(model, ema, args, data):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    loss = 0
    answers = dict()
    model.eval()

    # Back up the current weights and temporarily swap in the EMA (shadow) weights for evaluation
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    with torch.no_grad():
        for batch in iter(data.dev_iter):
            p1, p2 = model(batch)
            batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
            loss += batch_loss.item()

            # Joint start/end score matrix of shape (batch, c_len, c_len);
            # the -inf lower-triangular mask forbids spans with end < start.
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device) \
                .tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

            for i in range(batch_size):
                id = batch.id[i]
                answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                answer = ' '.join([data.WORD.vocab.itos[idx] for idx in answer])
                answers[id] = answer

    # Restore the original (non-EMA) weights before returning to training
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))

    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)

    results = evaluate.main(args)
    return loss, results['exact_match'], results['f1']
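# --- Illustrative sketch (not part of the original sources) -----------------
# A tiny, self-contained check of the span decoding used in test() above: the
# -inf lower-triangular mask restricts score[b, i, j] = log_softmax(p1)[b, i]
# + log_softmax(p2)[b, j] to spans with j >= i, so two successive max() calls
# recover the best (s_idx, e_idx) pair. Shapes (batch_size=1, c_len=4) are
# arbitrary example values.
import torch

p1 = torch.randn(1, 4)
p2 = torch.randn(1, 4)
ls = torch.nn.LogSoftmax(dim=1)
mask = (torch.ones(4, 4) * float('-inf')).tril(-1).unsqueeze(0)
score = ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1) + mask
score, s_idx = score.max(dim=1)
score, e_idx = score.max(dim=1)
s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
assert int(s_idx) <= int(e_idx)  # the mask guarantees start <= end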
def get_vis(args, data):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    print("load pretrained model")
    load_path = torch.load(args.load_path)
    model.load_state_dict(load_path)

    ema = EMA(args.exp_decay_rate)
    criterion = nn.CrossEntropyLoss()
    model.save_mode = True
    model.eval()

    def save_vis_data(train=True):
        save_data = []
        iterator = data.train_iter if train else data.dev_iter
        mode = 'trainData' if train else 'testData'
        print('Mode: {}'.format(mode))
        save_count = 0
        with torch.no_grad():
            count = 0
            for i, batch in enumerate(iterator):
                present_epoch = int(iterator.epoch)
                if present_epoch == 1:
                    break

                tmp = {}
                p1, p2 = model(batch)
                p1 = p1.unsqueeze(0)
                p2 = p2.unsqueeze(0)
                c_words = [data.WORD.vocab.itos[i.item()] for i in batch.c_word[0][0]]
                q_words = [data.WORD.vocab.itos[i.item()] for i in batch.q_word[0][0]]
                batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)

                batch_size, c_len = p1.size()
                ls = nn.LogSoftmax(dim=1)
                mask = (torch.ones(c_len, c_len) * float('-inf')).to(device) \
                    .tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
                score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
                scores = score
                score, s_idx = score.max(dim=1)
                score, e_idx = score.max(dim=1)
                s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

                tmp['context'] = c_words
                tmp['question'] = q_words
                tmp['gt_s_idx'] = batch.s_idx.cpu().numpy()
                tmp['gt_e_idx'] = batch.e_idx.cpu().numpy()
                tmp['save_data'] = model.save_data.copy()
                tmp['loss'] = batch_loss.cpu().numpy()
                tmp['prediction_s_idx'] = s_idx.cpu().numpy()
                tmp['prediction_e_idx'] = e_idx.cpu().numpy()
                tmp['prediction_scores'] = scores.cpu().numpy()
                save_data.append(tmp)

                if len(save_data) % 2000 == 0:
                    np.save('{}_{}'.format(count, mode), save_data)
                    save_data = []
                    count += 1

        np.save('{}_{}'.format(count, mode), save_data)

    save_vis_data()
    save_vis_data(False)
def main():
    parser = argparse.ArgumentParser(description='FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='dataset to train on')
    # parser.add_argument('--n-classes', type=int, default=100,
    #                     help='number of classes in dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negative')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()
    # args.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # global settings
    # torch.multiprocessing.set_sharing_strategy('file_system')
    if args.seed > 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        # torch.backends.cudnn.deterministic = True

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
    n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * 1024

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Iters per epoch = {n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    # logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u = get_train_loader(args.dataset,
                                            args.batchsize,
                                            args.mu,
                                            n_iters_per_epoch,
                                            L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name:
            non_wd_params.append(param)  # bn.weight, bn.bias and classifier.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list,
                            lr=args.lr,
                            weight_decay=args.weight_decay,
                            momentum=args.momentum,
                            nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all, warmup_iter=0)
    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim,
                      lr_schdlr=lr_schdlr,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      n_iters=n_iters_per_epoch,
                      logger=logger)
    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    for epoch in range(args.n_epoches):
        train_loss, loss_x, loss_u, loss_u_real, loss_simclr, mask_mean = train_one_epoch(
            epoch, **train_args)
        # torch.cuda.empty_cache()

        top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

        writer.add_scalars('train/1.loss', {
            'train': train_loss,
            'test': valid_loss
        }, epoch)
        writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
        writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
        writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
        writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
        writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
        writer.add_scalars('test/1.test_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)
        # writer.add_scalar('test/2.test_loss', loss, epoch)

        # best_acc = top1 if best_acc < top1 else best_acc
        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch {}".format(
                epoch, top1, top5, best_acc, best_epoch))

    writer.close()
def train(args, data):
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    # Register an exponential moving average of every trainable parameter
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)
        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        # Update the EMA shadow weights after every optimizer step
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)

            loss = 0
            model.train()

    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    return best_model
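# --- Illustrative sketch (not part of the original sources) -----------------
# train() and test() above rely on a parameter-level EMA helper defined
# elsewhere in the repo, called as EMA(decay), ema.register(name, tensor),
# ema.update(name, tensor) and ema.get(name). A minimal stand-in with that
# interface could look like the class below; the name ParamEMA and the exact
# blending formula are assumptions, not the repo's actual implementation.
class ParamEMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        # Store an initial copy of the parameter tensor
        self.shadow[name] = val.clone()

    def update(self, name, val):
        # Blend the new value into the running average:
        # shadow = decay * shadow + (1 - decay) * val
        assert name in self.shadow
        new_avg = self.decay * self.shadow[name] + (1.0 - self.decay) * val
        self.shadow[name] = new_avg.clone()

    def get(self, name):
        # Return the averaged tensor so it can be copied into the model
        return self.shadow[name]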
def main():
    parser = argparse.ArgumentParser(description='FixMatch Training')
    parser.add_argument('--wresnet-k', default=2, type=int,
                        help='width factor of wide resnet')
    parser.add_argument('--wresnet-n', default=28, type=int,
                        help='depth of wide resnet')
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        help='dataset to train on')
    # parser.add_argument('--n-classes', type=int, default=100,
    #                     help='number of classes in dataset')
    parser.add_argument('--n-labeled', type=int, default=40,
                        help='number of labeled samples for training')
    parser.add_argument('--n-epoches', type=int, default=1024,
                        help='number of training epoches')
    parser.add_argument('--batchsize', type=int, default=40,
                        help='train batch size of labeled samples')
    parser.add_argument('--mu', type=int, default=7,
                        help='factor of train batch size of unlabeled samples')
    parser.add_argument('--thr', type=float, default=0.95,
                        help='pseudo label threshold')
    parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
                        help='number of training images for each epoch')
    parser.add_argument('--lam-u', type=float, default=1.,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--lam-s', type=float, default=0.2,
                        help='coefficient of unlabeled loss SimCLR')
    parser.add_argument('--ema-alpha', type=float, default=0.999,
                        help='decay rate for ema module')
    parser.add_argument('--lr', type=float, default=0.03,
                        help='learning rate for training')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for optimizer')
    parser.add_argument('--seed', type=int, default=-1,
                        help='seed for random behaviors, no seed if negative')
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='temperature for loss function')
    args = parser.parse_args()
    # args.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    logger, writer = setup_default_logging(args)
    logger.info(dict(args._get_kwargs()))

    # global settings
    # torch.multiprocessing.set_sharing_strategy('file_system')
    if args.seed > 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        # torch.backends.cudnn.deterministic = True

    n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize  # 1024
    n_iters_all = n_iters_per_epoch * args.n_epoches  # 1024 * 1024

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.n_labeled}")
    logger.info(f"  Iters per epoch = {n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    # logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f"  Total optimization steps = {n_iters_all}")

    model, criteria_x, criteria_u, criteria_z = set_model(args)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    dltrain_x, dltrain_u, dltrain_f = get_train_loader_mix(args.dataset,
                                                           args.batchsize,
                                                           args.mu,
                                                           n_iters_per_epoch,
                                                           L=args.n_labeled)
    dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)

    lb_guessor = LabelGuessor(thresh=args.thr)
    ema = EMA(model, args.ema_alpha)

    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # if len(param.size()) == 1:
        if 'bn' in name:
            non_wd_params.append(param)  # bn.weight, bn.bias and classifier.bias
            # print(name)
        else:
            wd_params.append(param)
    param_list = [{'params': wd_params},
                  {'params': non_wd_params, 'weight_decay': 0}]
    optim_fix = torch.optim.SGD(param_list,
                                lr=args.lr,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum,
                                nesterov=True)
    lr_schdlr_fix = WarmupCosineLrScheduler(optim_fix,
                                            max_iter=n_iters_all,
                                            warmup_iter=0)

    train_args = dict(model=model,
                      criteria_x=criteria_x,
                      criteria_u=criteria_u,
                      criteria_z=criteria_z,
                      optim=optim_fix,
                      lr_schdlr=lr_schdlr_fix,
                      ema=ema,
                      dltrain_x=dltrain_x,
                      dltrain_u=dltrain_u,
                      dltrain_f=dltrain_f,
                      lb_guessor=lb_guessor,
                      lambda_u=args.lam_u,
                      lambda_s=args.lam_s,
                      n_iters=n_iters_per_epoch,
                      logger=logger,
                      bt=args.batchsize,
                      mu=args.mu)

    # # TRAINING PARAMETERS FOR SIMCLR
    # param_list = [
    #     {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    #
    # optim_simclr = torch.optim.SGD(param_list, lr=0.5, weight_decay=args.weight_decay,
    #                                momentum=args.momentum, nesterov=False)
    #
    # lr_schdlr_simclr = WarmupCosineLrScheduler(
    #     optim_simclr, max_iter=n_iters_all, warmup_iter=0
    # )
    #
    # train_args_simclr = dict(
    #     model=model,
    #     criteria_z=criteria_z,
    #     optim=optim_simclr,
    #     lr_schdlr=lr_schdlr_simclr,
    #     ema=ema,
    #     dltrain_f=dltrain_f,
    #     lambda_s=args.lam_s,
    #     n_iters=n_iters_per_epoch,
    #     logger=logger,
    #     bt=args.batchsize,
    #     mu=args.mu
    # )

    # # TRAINING PARAMETERS FOR IIC
    # param_list = [
    #     {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    #
    # optim_iic = torch.optim.Adam(param_list, lr=1e-4, weight_decay=args.weight_decay)
    #
    # lr_schdlr_iic = WarmupCosineLrScheduler(
    #     optim_iic, max_iter=n_iters_all, warmup_iter=0
    # )
    #
    # train_args_iic = dict(
    #     model=model,
    #     optim=optim_iic,
    #     lr_schdlr=lr_schdlr_iic,
    #     ema=ema,
    #     dltrain_f=dltrain_f,
    #     n_iters=n_iters_per_epoch,
    #     logger=logger,
    #     bt=args.batchsize,
    #     mu=args.mu
    # )

    best_acc = -1
    best_epoch = 0
    logger.info('-----------start training--------------')
    for epoch in range(args.n_epoches):
        # Record the accuracy of the pretrained model up to the h space (backbone output)
        top1, top5, valid_loss = evaluate_linear_Clf(ema, dltrain_x, dlval, criteria_x)
        writer.add_scalars('test/1.test_linear_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)
        logger.info("Epoch {}. on h space Top1: {:.4f}. Top5: {:.4f}.".format(
            epoch, top1, top5))

        if epoch < -500:
            # # UNSUPERVISED TRAINING PHASE
            # Train the SimCLR feature representation
            # train_loss, loss_simclr, model_ = train_one_epoch_simclr(epoch, **train_args_simclr)
            # writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)

            # Train IIC
            # train_loss, loss_iic, model_ = train_one_epoch_iic(epoch, **train_args_iic)
            # writer.add_scalar('train/4.train_loss_iic', loss_iic, epoch)

            # evaluate_Clf(model_, dltrain_f, dlval, criteria_x)
            top1, top5, valid_loss = evaluate_linear_Clf(ema, dltrain_x, dlval, criteria_x)

            # # SAVE THE MODEL TRAINED WITHOUT SUPERVISION
            # if epoch == 497:
            #     # save model
            #     name = 'simclr_trained_good_h2.pt'
            #     torch.save(model_.state_dict(), name)
            #     logger.info('model saved')
        else:
            # SEMI-SUPERVISED TRAINING
            train_loss, loss_x, loss_u, loss_u_real, mask_mean, loss_simclr = train_one_epoch(
                epoch, **train_args)
            top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)

            writer.add_scalar('train/4.train_loss_simclr', loss_simclr, epoch)
            writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
            writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
            writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
            writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
            writer.add_scalars('train/1.loss', {
                'train': train_loss,
                'test': valid_loss
            }, epoch)

        writer.add_scalars('test/1.test_acc', {
            'top1': top1,
            'top5': top5
        }, epoch)
        # writer.add_scalar('test/2.test_loss', loss, epoch)

        # best_acc = top1 if best_acc < top1 else best_acc
        if best_acc < top1:
            best_acc = top1
            best_epoch = epoch

        logger.info(
            "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch {}".format(
                epoch, top1, top5, best_acc, best_epoch))

    writer.close()
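# --- Illustrative sketch (not part of the original sources) -----------------
# The FixMatch scripts above construct EMA(model, args.ema_alpha), and the
# Trainer further below calls ema.update_params(), ema.update_buffer(),
# ema.apply_shadow(), ema.restore(), and reads ema.model / ema.shadow. A
# minimal model-level EMA wrapper with that interface might look like the
# class below; the name ModelEMA and every implementation detail are
# assumptions inferred from those call sites.
import copy

import torch.nn as nn


class ModelEMA:
    def __init__(self, model: nn.Module, alpha: float):
        self.model = model
        self.alpha = alpha
        # Shadow copy of all parameters and buffers
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items()}
        self.backup = {}
        self.param_keys = [k for k, _ in model.named_parameters()]
        self.buffer_keys = [k for k, _ in model.named_buffers()]

    def update_params(self):
        # Exponential moving average of the trainable parameters
        state = self.model.state_dict()
        for k in self.param_keys:
            self.shadow[k].mul_(self.alpha).add_(state[k], alpha=1 - self.alpha)

    def update_buffer(self):
        # Copy non-trainable buffers (e.g. BatchNorm running stats) verbatim
        state = self.model.state_dict()
        for k in self.buffer_keys:
            self.shadow[k].copy_(state[k])

    def apply_shadow(self):
        # Swap the averaged weights in for evaluation
        self.backup = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(self.shadow)

    def restore(self):
        # Swap the original training weights back
        self.model.load_state_dict(self.backup)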
model = SE_GCN(args,
               W2V,
               MAX_LEN,
               embed_size,
               nfeat=v_texts_w2v_idxs_l_list[0].shape[1],
               nfeat_v=v_features_list[0].shape[1],
               nfeat_g=len(g_features[0]),
               nhid_vfeat=args.hidden_vfeat,
               nhid_siamese=args.hidden_siamese,
               dropout_vfeat=args.dropout_vfeat,
               dropout_siamese=args.dropout_siamese,
               nhid_final=args.hidden_final)
summarize_model(model)
if args.use_ema:
    ema = EMA(args.ema_decay)
    ema.register(model)

# optimizer and scheduler
parameters = filter(lambda p: p.requires_grad, model.parameters())
# optimizer = optim.SGD(parameters, lr=args.lr, momentum=0.9)
optimizer = optim.Adam(params=parameters,
                       lr=args.lr,
                       betas=(args.beta1, args.beta2),
                       eps=1e-8,
                       weight_decay=3e-7)
cr = 1.0 / math.log(args.lr_warm_up_num)
scheduler = None
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda ee: cr * math.log(ee + 1))
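# --- Illustrative check (not part of the original sources) ------------------
# A quick sanity check of the logarithmic warm-up used in the LambdaLR above:
# with cr = 1 / log(lr_warm_up_num), the multiplier cr * log(ee + 1) starts at
# 0 for step 0 and reaches 1.0 once ee + 1 == lr_warm_up_num, after which the
# base learning rate is exceeded only slowly. The value 1000 below is an
# assumed stand-in for args.lr_warm_up_num.
import math

lr_warm_up_num = 1000  # assumed value for illustration
cr = 1.0 / math.log(lr_warm_up_num)
for step in (0, 10, 100, lr_warm_up_num - 1):
    print(step, cr * math.log(step + 1))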
class Trainer():
    def __init__(self, cfg, logger, writer):

        # Args
        self.cfg = cfg
        self.device = torch.device('cuda')
        self.logger = logger
        self.writer = writer

        # Counters
        self.epoch = 0
        self.iter = 0
        self.current_MIoU = 0
        self.best_MIou = 0
        self.best_source_MIou = 0

        # Metrics
        self.evaluator = Eval(self.cfg.data.num_classes)

        # Loss
        self.ignore_index = -1
        self.loss = nn.CrossEntropyLoss(ignore_index=self.ignore_index)

        # Model
        self.model, params = get_model(self.cfg)
        # self.model = nn.DataParallel(self.model, device_ids=[0])  # TODO: test multi-gpu
        self.model.to(self.device)

        # EMA
        self.ema = EMA(self.model, self.cfg.ema_decay)

        # Optimizer
        if self.cfg.opt.kind == "SGD":
            self.optimizer = torch.optim.SGD(
                params,
                momentum=self.cfg.opt.momentum,
                weight_decay=self.cfg.opt.weight_decay)
        elif self.cfg.opt.kind == "Adam":
            self.optimizer = torch.optim.Adam(
                params,
                betas=(0.9, 0.99),
                weight_decay=self.cfg.opt.weight_decay)
        else:
            raise NotImplementedError()
        self.lr_factor = 10

        # Source
        if self.cfg.data.source.dataset == 'synthia':
            source_train_dataset = SYNTHIA_Dataset(split='train',
                                                   **self.cfg.data.source.kwargs)
            source_val_dataset = SYNTHIA_Dataset(split='val',
                                                 **self.cfg.data.source.kwargs)
        elif self.cfg.data.source.dataset == 'gta5':
            source_train_dataset = GTA5_Dataset(split='train',
                                                **self.cfg.data.source.kwargs)
            source_val_dataset = GTA5_Dataset(split='val',
                                              **self.cfg.data.source.kwargs)
        else:
            raise NotImplementedError()
        self.source_dataloader = DataLoader(source_train_dataset,
                                            shuffle=True,
                                            drop_last=True,
                                            **self.cfg.data.loader.kwargs)
        self.source_val_dataloader = DataLoader(source_val_dataset,
                                                shuffle=False,
                                                drop_last=False,
                                                **self.cfg.data.loader.kwargs)

        # Target
        if self.cfg.data.target.dataset == 'cityscapes':
            target_train_dataset = City_Dataset(split='train',
                                                **self.cfg.data.target.kwargs)
            target_val_dataset = City_Dataset(split='val',
                                              **self.cfg.data.target.kwargs)
        else:
            raise NotImplementedError()
        self.target_dataloader = DataLoader(target_train_dataset,
                                            shuffle=True,
                                            drop_last=True,
                                            **self.cfg.data.loader.kwargs)
        self.target_val_dataloader = DataLoader(target_val_dataset,
                                                shuffle=False,
                                                drop_last=False,
                                                **self.cfg.data.loader.kwargs)

        # Perturbations
        if self.cfg.lam_aug > 0:
            self.aug = get_augmentation()

    def train(self):
        # Loop over epochs
        self.continue_training = True
        while self.continue_training:

            # Train for a single epoch
            self.train_one_epoch()

            # Use EMA params to evaluate performance
            self.ema.apply_shadow()
            self.ema.model.eval()
            self.ema.model.cuda()

            # Validate on source (if possible) and target
            if self.cfg.data.source_val_iterations > 0:
                self.validate(mode='source')
            PA, MPA, MIoU, FWIoU = self.validate()

            # Restore current (non-EMA) params for training
            self.ema.restore()

            # Log val results
            self.writer.add_scalar('PA', PA, self.epoch)
            self.writer.add_scalar('MPA', MPA, self.epoch)
            self.writer.add_scalar('MIoU', MIoU, self.epoch)
            self.writer.add_scalar('FWIoU', FWIoU, self.epoch)

            # Save checkpoint if new best model
            self.current_MIoU = MIoU
            is_best = MIoU > self.best_MIou
            if is_best:
                self.best_MIou = MIoU
                self.best_iter = self.iter
                self.logger.info("=> Saving a new best checkpoint...")
                self.logger.info(
                    "=> The best val MIoU is now {:.3f} from iter {}".format(
                        self.best_MIou, self.best_iter))
                self.save_checkpoint('best.pth')
            else:
                self.logger.info("=> The MIoU of val did not improve.")
                self.logger.info(
                    "=> The best val MIoU is still {:.3f} from iter {}".format(
                        self.best_MIou, self.best_iter))

            self.epoch += 1

        # Save final checkpoint
        self.logger.info("=> The best MIoU was {:.3f} at iter {}".format(
            self.best_MIou, self.best_iter))
{}".format( self.best_MIou, self.best_iter)) self.logger.info( "=> Saving the final checkpoint to {}".format('final.pth')) self.save_checkpoint('final.pth') def train_one_epoch(self): # Load and reset self.model.train() self.evaluator.reset() # Helper def unpack(x): return (x[0], x[1]) if isinstance(x, tuple) else (x, None) # Training loop total = min(len(self.source_dataloader), len(self.target_dataloader)) for batch_idx, (batch_s, batch_t) in enumerate( tqdm(zip(self.source_dataloader, self.target_dataloader), total=total, desc=f"Epoch {self.epoch + 1}")): # Learning rate self.poly_lr_scheduler(optimizer=self.optimizer) self.writer.add_scalar('train/lr', self.optimizer.param_groups[0]["lr"], self.iter) # Losses losses = {} ########################## # Source supervised loss # ########################## x, y, _ = batch_s if True: # For VS Code collapsing # Data x = x.to(self.device) y = y.squeeze(dim=1).to(device=self.device, dtype=torch.long, non_blocking=True) # Fourier mix: source --> target if self.cfg.source_fourier: x = fourier_mix(src_images=x, tgt_images=batch_t[0].to(self.device), L=self.cfg.fourier_beta) # Forward pred = self.model(x) pred_1, pred_2 = unpack(pred) # Loss (source) loss_source_1 = self.loss(pred_1, y) if self.cfg.aux: loss_source_2 = self.loss(pred_2, y) * self.cfg.lam_aux loss_source = loss_source_1 + loss_source_2 else: loss_source = loss_source_1 # Backward loss_source.backward() # Clean up losses['source_main'] = loss_source_1.cpu().item() if self.cfg.aux: losses['source_aux'] = loss_source_2.cpu().item() del x, y, loss_source, loss_source_1, loss_source_2 ###################### # Target Pseudolabel # ###################### x, _, _ = batch_t x = x.to(self.device) # First step: run non-augmented image though model to get predictions with torch.no_grad(): # Substep 1: forward pass pred = self.model(x.to(self.device)) pred_1, pred_2 = unpack(pred) # Substep 2: convert soft predictions to hard predictions pred_P_1 = F.softmax(pred_1, dim=1) label_1 = torch.argmax(pred_P_1.detach(), dim=1) maxpred_1, argpred_1 = torch.max(pred_P_1.detach(), dim=1) T = self.cfg.pseudolabel_threshold mask_1 = (maxpred_1 > T) ignore_tensor = torch.ones(1).to( self.device, dtype=torch.long) * self.ignore_index label_1 = torch.where(mask_1, label_1, ignore_tensor) if self.cfg.aux: pred_P_2 = F.softmax(pred_2, dim=1) maxpred_2, argpred_2 = torch.max(pred_P_2.detach(), dim=1) pred_c = (pred_P_1 + pred_P_2) / 2 maxpred_c, argpred_c = torch.max(pred_c, dim=1) mask = (maxpred_1 > T) | (maxpred_2 > T) label_2 = torch.where(mask, argpred_c, ignore_tensor) ############ # Aug loss # ############ if self.cfg.lam_aug > 0: # Second step: augment image and label x_aug, y_aug_1 = augment(images=x.cpu(), labels=label_1.detach().cpu(), aug=self.aug) y_aug_1 = y_aug_1.to(device=self.device, non_blocking=True) if self.cfg.aux: _, y_aug_2 = augment(images=x.cpu(), labels=label_2.detach().cpu(), aug=self.aug) y_aug_2 = y_aug_2.to(device=self.device, non_blocking=True) # Third step: run augmented image through model to get predictions pred_aug = self.model(x_aug.to(self.device)) pred_aug_1, pred_aug_2 = unpack(pred_aug) # Fourth step: calculate loss loss_aug_1 = self.loss(pred_aug_1, y_aug_1) * \ self.cfg.lam_aug if self.cfg.aux: loss_aug_2 = self.loss(pred_aug_2, y_aug_2) * \ self.cfg.lam_aug * self.cfg.lam_aux loss_aug = loss_aug_1 + loss_aug_2 else: loss_aug = loss_aug_1 # Backward loss_aug.backward() # Clean up losses['aug_main'] = loss_aug_1.cpu().item() if self.cfg.aux: losses['aug_aux'] = 
                del pred_aug, pred_aug_1, pred_aug_2, loss_aug, loss_aug_1, loss_aug_2

            ################
            # Fourier Loss #
            ################
            if self.cfg.lam_fourier > 0:

                # Second step: fourier mix
                x_fourier = fourier_mix(src_images=x.to(self.device),
                                        tgt_images=batch_s[0].to(self.device),
                                        L=self.cfg.fourier_beta)

                # Third step: run mixed image through model to get predictions
                pred_fourier = self.model(x_fourier.to(self.device))
                pred_fourier_1, pred_fourier_2 = unpack(pred_fourier)

                # Fourth step: calculate loss
                loss_fourier_1 = self.loss(pred_fourier_1, label_1) * \
                    self.cfg.lam_fourier
                if self.cfg.aux:
                    loss_fourier_2 = self.loss(pred_fourier_2, label_2) * \
                        self.cfg.lam_fourier * self.cfg.lam_aux
                    loss_fourier = loss_fourier_1 + loss_fourier_2
                else:
                    loss_fourier = loss_fourier_1

                # Backward
                loss_fourier.backward()

                # Clean up
                losses['fourier_main'] = loss_fourier_1.cpu().item()
                if self.cfg.aux:
                    losses['fourier_aux'] = loss_fourier_2.cpu().item()
                del pred_fourier, pred_fourier_1, pred_fourier_2, loss_fourier, loss_fourier_1, loss_fourier_2

            ###############
            # CutMix Loss #
            ###############
            if self.cfg.lam_cutmix > 0:

                # Second step: CutMix
                x_cutmix, y_cutmix = cutmix_combine(
                    images_1=x,
                    labels_1=label_1.unsqueeze(dim=1),
                    images_2=batch_s[0].to(self.device),
                    labels_2=batch_s[1].unsqueeze(dim=1).to(self.device,
                                                            dtype=torch.long))
                y_cutmix = y_cutmix.squeeze(dim=1)

                # Third step: run mixed image through model to get predictions
                pred_cutmix = self.model(x_cutmix)
                pred_cutmix_1, pred_cutmix_2 = unpack(pred_cutmix)

                # Fourth step: calculate loss
                loss_cutmix_1 = self.loss(pred_cutmix_1, y_cutmix) * \
                    self.cfg.lam_cutmix
                if self.cfg.aux:
                    loss_cutmix_2 = self.loss(pred_cutmix_2, y_cutmix) * \
                        self.cfg.lam_cutmix * self.cfg.lam_aux
                    loss_cutmix = loss_cutmix_1 + loss_cutmix_2
                else:
                    loss_cutmix = loss_cutmix_1

                # Backward
                loss_cutmix.backward()

                # Clean up
                losses['cutmix_main'] = loss_cutmix_1.cpu().item()
                if self.cfg.aux:
                    losses['cutmix_aux'] = loss_cutmix_2.cpu().item()
                del pred_cutmix, pred_cutmix_1, pred_cutmix_2, loss_cutmix, loss_cutmix_1, loss_cutmix_2

            # Step optimizer if accumulated enough gradients
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Update model EMA parameters each step
            self.ema.update_params()

            # Calculate total loss
            total_loss = sum(losses.values())

            # Log main losses
            for name, loss in losses.items():
                self.writer.add_scalar(f'train/{name}', loss, self.iter)

            # Log
            if batch_idx % 100 == 0:
                log_string = f"[Epoch {self.epoch}]\t"
                log_string += '\t'.join(
                    [f'{n}: {l:.3f}' for n, l in losses.items()])
                self.logger.info(log_string)

            # Increment global iteration counter
            self.iter += 1

            # End training after finishing iterations
            if self.iter > self.cfg.opt.iterations:
                self.continue_training = False
                return

        # After each epoch, update model EMA buffers (i.e. batch norm stats)
        self.ema.update_buffer()

    @torch.no_grad()
    def validate(self, mode='target'):
        """Validate on target"""
        self.logger.info('Validating')
        self.evaluator.reset()
        self.model.eval()

        # Select dataloader
        if mode == 'target':
            val_loader = self.target_val_dataloader
        elif mode == 'source':
            val_loader = self.source_val_dataloader
        else:
            raise NotImplementedError()

        # Loop
        for val_idx, (x, y, id) in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {self.epoch + 1}")):
            if mode == 'source' and val_idx >= self.cfg.data.source_val_iterations:
                break

            # Forward
            x = x.to(self.device)
            y = y.to(device=self.device, dtype=torch.long)
            pred = self.model(x)
            if isinstance(pred, tuple):
                pred = pred[0]

            # Convert to numpy
            label = y.squeeze(dim=1).cpu().numpy()
            argpred = np.argmax(pred.data.cpu().numpy(), axis=1)

            # Add to evaluator
            self.evaluator.add_batch(label, argpred)

            # Tensorboard images
            vis_imgs = 2
            images_inv = inv_preprocess(x.clone().cpu(), vis_imgs,
                                        numpy_transform=True)
            labels_colors = decode_labels(label, vis_imgs)
            preds_colors = decode_labels(argpred, vis_imgs)
            for index, (img, lab, predc) in enumerate(
                    zip(images_inv, labels_colors, preds_colors)):
                self.writer.add_image(str(index) + '/images', img, self.epoch)
                self.writer.add_image(str(index) + '/labels', lab, self.epoch)
                self.writer.add_image(str(index) + '/preds', predc, self.epoch)

        # Calculate and log
        if self.cfg.data.source.kwargs.class_16:
            PA = self.evaluator.Pixel_Accuracy()
            MPA_16, MPA_13 = self.evaluator.Mean_Pixel_Accuracy()
            MIoU_16, MIoU_13 = self.evaluator.Mean_Intersection_over_Union()
            FWIoU_16, FWIoU_13 = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            PC_16, PC_13 = self.evaluator.Mean_Precision()
            self.logger.info(
                'Epoch:{:.3f}, PA:{:.3f}, MPA_16:{:.3f}, MIoU_16:{:.3f}, FWIoU_16:{:.3f}, PC_16:{:.3f}'
                .format(self.epoch, PA, MPA_16, MIoU_16, FWIoU_16, PC_16))
            self.logger.info(
                'Epoch:{:.3f}, PA:{:.3f}, MPA_13:{:.3f}, MIoU_13:{:.3f}, FWIoU_13:{:.3f}, PC_13:{:.3f}'
                .format(self.epoch, PA, MPA_13, MIoU_13, FWIoU_13, PC_13))
            self.writer.add_scalar('PA', PA, self.epoch)
            self.writer.add_scalar('MPA_16', MPA_16, self.epoch)
            self.writer.add_scalar('MIoU_16', MIoU_16, self.epoch)
            self.writer.add_scalar('FWIoU_16', FWIoU_16, self.epoch)
            self.writer.add_scalar('MPA_13', MPA_13, self.epoch)
            self.writer.add_scalar('MIoU_13', MIoU_13, self.epoch)
            self.writer.add_scalar('FWIoU_13', FWIoU_13, self.epoch)
            PA, MPA, MIoU, FWIoU = PA, MPA_13, MIoU_13, FWIoU_13
        else:
            PA = self.evaluator.Pixel_Accuracy()
            MPA = self.evaluator.Mean_Pixel_Accuracy()
            MIoU = self.evaluator.Mean_Intersection_over_Union()
            FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
            PC = self.evaluator.Mean_Precision()
            self.logger.info(
                'Epoch:{:.3f}, PA1:{:.3f}, MPA1:{:.3f}, MIoU1:{:.3f}, FWIoU1:{:.3f}, PC:{:.3f}'
                .format(self.epoch, PA, MPA, MIoU, FWIoU, PC))
            self.writer.add_scalar('PA', PA, self.epoch)
            self.writer.add_scalar('MPA', MPA, self.epoch)
            self.writer.add_scalar('MIoU', MIoU, self.epoch)
            self.writer.add_scalar('FWIoU', FWIoU, self.epoch)

        return PA, MPA, MIoU, FWIoU

    def save_checkpoint(self, filename='checkpoint.pth'):
        torch.save(
            {
                'epoch': self.epoch + 1,
                'iter': self.iter,
                'state_dict': self.ema.model.state_dict(),
                'shadow': self.ema.shadow,
                'optimizer': self.optimizer.state_dict(),
                'best_MIou': self.best_MIou
            }, filename)

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename, map_location='cpu')

        # Get model state dict
        if not self.cfg.train and 'shadow' in checkpoint:
            state_dict = checkpoint['shadow']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Remove DP/DDP prefix if it exists
        state_dict = {
            k.replace('module.', ''): v
            for k, v in state_dict.items()
        }

        # Load state dict
        if hasattr(self.model, 'module'):
            self.model.module.load_state_dict(state_dict)
        else:
            self.model.load_state_dict(state_dict)
        self.logger.info(f"Model loaded successfully from {filename}")

        # Load optimizer and epoch
        if self.cfg.train and self.cfg.model.resume_from_checkpoint:
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.logger.info(f"Optimizer loaded successfully from {filename}")
            if 'epoch' in checkpoint and 'iter' in checkpoint:
                self.epoch = checkpoint['epoch']
                self.iter = checkpoint['iter'] if 'iter' in checkpoint else checkpoint['iteration']
                self.logger.info(
                    f"Resuming training from epoch {self.epoch} iter {self.iter}")
        else:
            self.logger.info("Did not resume optimizer")

    def poly_lr_scheduler(self,
                          optimizer,
                          init_lr=None,
                          iter=None,
                          max_iter=None,
                          power=None):
        init_lr = self.cfg.opt.lr if init_lr is None else init_lr
        iter = self.iter if iter is None else iter
        max_iter = self.cfg.opt.iterations if max_iter is None else max_iter
        power = self.cfg.opt.poly_power if power is None else power
        new_lr = init_lr * (1 - float(iter) / max_iter) ** power
        optimizer.param_groups[0]["lr"] = new_lr
        if len(optimizer.param_groups) == 2:
            optimizer.param_groups[1]["lr"] = 10 * new_lr
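# --- Illustrative check (not part of the original sources) ------------------
# A small standalone look at the polynomial decay that poly_lr_scheduler
# applies each iteration: new_lr = init_lr * (1 - iter / max_iter) ** power,
# with the second parameter group (typically the classifier head) scaled by
# lr_factor = 10. The init_lr, max_iter, and power values below are assumed
# example values, not the config's actual settings.
init_lr, max_iter, power = 2.5e-4, 90000, 0.9  # assumed values for illustration
for it in (0, 45000, 89999):
    lr = init_lr * (1 - it / max_iter) ** power
    print(it, lr, 10 * lr)  # backbone lr, head lr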