import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.autograd import Variable
import tensorboardX
from tqdm import tqdm
# Project-local helpers (TrainOptions, Timer, CreateSrcDataLoader,
# CreateTrgDataLoader, CreateModel, CreateDiscriminator, CrossEntropy2d)
# are assumed to be importable from this repo's own modules.


def main():
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    l1_loss = torch.nn.L1Loss()
    cos_loss = torch.nn.CosineSimilarity(dim=0, eps=1e-06)

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()

    loss = ['loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake',
            'loss_D_src_real', 'loss_D_trg_real']
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter, args.num_steps_stop))
    for i in pbar:
        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()

        # --- Train the segmentation network; freeze the discriminator. ---
        for param in model_D.parameters():
            param.requires_grad = False

        # Source domain: supervised cross-entropy on both classifier heads.
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()
        src_seg_score, src_seg_score2 = model(src_img)
        loss_seg_src1 = CrossEntropy2d(src_seg_score, src_lbl)
        loss_seg_src2 = CrossEntropy2d(src_seg_score2, src_lbl)
        loss_seg_src = loss_seg_src1 + loss_seg_src2
        loss_seg_src.backward()

        # Target domain: adversarial loss (push D toward the "source" label 0)
        # plus an L1 agreement loss between the two classifier heads.
        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()
            # The model returns both heads' scores; the original unpacked only
            # one here, which would leave trg_seg_score2 undefined below.
            trg_seg_score, trg_seg_score2 = model(trg_img)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score, trg_seg_score2 = model(trg_img)
            loss_seg_trg = 0

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1))
        outD_trg2 = model_D(F.softmax(trg_seg_score2, dim=1))
        loss_D_trg_fake1 = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(0)).cuda())
        loss_D_trg_fake2 = bce_loss(outD_trg2, Variable(torch.FloatTensor(outD_trg2.data.size()).fill_(0)).cuda())
        loss_D_trg_fake = loss_D_trg_fake1 + loss_D_trg_fake2

        loss_agree = l1_loss(F.softmax(trg_seg_score, dim=1), F.softmax(trg_seg_score2, dim=1))
        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg + loss_agree
        loss_trg.backward()

        # --- Weight discrepancy loss: penalize cosine similarity between the
        # flattened weights of the two classifier heads (layer5 vs. layer6). ---
        W5 = None
        W6 = None
        loss_weight = torch.zeros(1).cuda()  # keep defined when model != 'DeepLab2'
        if args.model == 'DeepLab2':
            for (w5, w6) in zip(model.layer5.parameters(), model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)
            # +1 keeps the loss non-negative.
            loss_weight = args.weight_div * (cos_loss(W5, W6) + 1)
            loss_weight.backward()

        # --- Train the discriminator on detached segmentation scores:
        # source maps carry label 0, target maps label 1. ---
        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, trg_seg_score = src_seg_score.detach(), trg_seg_score.detach()
        src_seg_score2, trg_seg_score2 = src_seg_score2.detach(), trg_seg_score2.detach()

        outD_src = model_D(F.softmax(src_seg_score, dim=1))
        loss_D_src_real1 = bce_loss(outD_src, Variable(torch.FloatTensor(outD_src.data.size()).fill_(0)).cuda()) / 2
        outD_src2 = model_D(F.softmax(src_seg_score2, dim=1))
        loss_D_src_real2 = bce_loss(outD_src2, Variable(torch.FloatTensor(outD_src2.data.size()).fill_(0)).cuda()) / 2
        loss_D_src_real = loss_D_src_real1 + loss_D_src_real2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1))
        loss_D_trg_real1 = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(1)).cuda()) / 2
        outD_trg2 = model_D(F.softmax(trg_seg_score2, dim=1))
        loss_D_trg_real2 = bce_loss(outD_trg2, Variable(torch.FloatTensor(outD_trg2.data.size()).fill_(1)).cuda()) / 2
        loss_D_trg_real = loss_D_trg_real1 + loss_D_trg_real2
        loss_D_trg_real.backward()

        d_loss = loss_D_src_real.data + loss_D_trg_real.data

        optimizer.step()
        optimizer_D.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(model_D.state_dict(),
                       os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '_D.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][agree loss %.4f][div loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_D_trg_fake.data, d_loss, loss_agree.data,
                   loss_weight.data, optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
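# ---------------------------------------------------------------------------
# Both main() scripts here call CrossEntropy2d() as a function taking
# (N, C, H, W) logits and (N, H, W) long labels; it is imported from
# elsewhere in the repo. The sketch below is only a plausible reconstruction:
# the ignore_label=255 default is an assumption (the Cityscapes convention),
# not taken from this file.
def CrossEntropy2d(predict, target, ignore_label=255):
    """2-D cross-entropy that skips ignored pixels."""
    # Keep only the pixels with a valid class index.
    mask = (target >= 0) & (target != ignore_label)
    target = target[mask]
    # (N, C, H, W) -> (N, H, W, C), then gather the valid pixels: (K, C).
    predict = predict.permute(0, 2, 3, 1)[mask]
    return F.cross_entropy(predict, target)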
# Imports as in the previous training script, plus the project-local
# ConditionalEntropyLoss (sketched below after this function).


def main():
    # torch.manual_seed(1234)
    # torch.cuda.manual_seed(1234)
    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()
    cent_loss = ConditionalEntropyLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()

    loss = ['loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake',
            'loss_D_src_real', 'loss_D_trg_real']
    _t['iter time'].tic()

    pbar = tqdm(range(start_iter, args.num_steps_stop))
    for i in pbar:
        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()

        # --- Train the segmentation network; freeze the discriminator. ---
        for param in model_D.parameters():
            param.requires_grad = False

        # Source domain: supervised cross-entropy.
        src_img, src_lbl, _, _ = next(sourceloader_iter)
        src_img, src_lbl = Variable(src_img).cuda(), Variable(src_lbl.long()).cuda()
        src_seg_score = model(src_img)
        loss_seg_src = CrossEntropy2d(src_seg_score, src_lbl)
        loss_seg_src.backward()

        # Target domain: pseudo-label supervision if available, otherwise
        # conditional-entropy minimization on the unlabeled predictions.
        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = next(targetloader_iter)
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = next(targetloader_iter)
            trg_img = Variable(trg_img).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = cent_loss(trg_seg_score)

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1))
        loss_D_trg_fake = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(0)).cuda())
        loss_trg = args.lambda_adv_target * (loss_D_trg_fake + loss_seg_trg)
        loss_trg.backward()

        # --- Train the discriminator on detached scores:
        # source maps carry label 0, target maps label 1. ---
        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, trg_seg_score = src_seg_score.detach(), trg_seg_score.detach()

        outD_src = model_D(F.softmax(src_seg_score, dim=1))
        loss_D_src_real = bce_loss(outD_src, Variable(torch.FloatTensor(outD_src.data.size()).fill_(0)).cuda()) / 2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score, dim=1))
        loss_D_trg_real = bce_loss(outD_trg, Variable(torch.FloatTensor(outD_trg.data.size()).fill_(1)).cuda()) / 2
        loss_D_trg_real.backward()

        d_loss = loss_D_src_real.data + loss_D_trg_real.data

        optimizer.step()
        optimizer_D.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '.pth'))
            torch.save(model_D.state_dict(),
                       os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '_D.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][adv loss %.4f][d loss %.4f][lr %.4f][%.2fs]' %
                  (i + 1, loss_seg_src.data, loss_D_trg_fake.data, d_loss,
                   optimizer.param_groups[0]['lr'] * 10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
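# ---------------------------------------------------------------------------
# The main() above uses ConditionalEntropyLoss, imported from elsewhere in
# the repo. A minimal sketch under the usual definition (mean per-pixel
# entropy of the predicted class distribution) -- an assumption about the
# repo's implementation, not a copy of it:
class ConditionalEntropyLoss(torch.nn.Module):
    def forward(self, logits):
        # logits: (N, C, H, W); per-pixel entropy H(p) = -sum_c p_c log p_c.
        p = F.softmax(logits, dim=1)
        log_p = F.log_softmax(logits, dim=1)
        return -(p * log_p).sum(dim=1).mean()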
import os
import time

import numpy as np
import torch
import torch.nn as nn
from scipy import stats
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm
# Project-local: monoSimDataset, MobileNetV2_Lite, CrossEntropy2d.


def train(cfg):
    # configure train
    train_name = time.strftime("%m%d_%H%M%S", time.localtime()) + '_' + cfg.model + '_' + os.path.basename(cfg.dataset_path)
    cfg.name = train_name
    log_interval = int(np.ceil(cfg.max_epochs * 0.1))
    print(cfg)

    # cpu or gpu?
    if torch.cuda.is_available() and cfg.device is not None:
        device = torch.device(cfg.device)
    else:
        if not torch.cuda.is_available():
            print("CUDA is not available, falling back to CPU.")
        device = torch.device("cpu")

    # data
    print('Loading Data')
    train_data = monoSimDataset(path=cfg.dataset_path, mode='train', seed=cfg.seed, debug_data=cfg.debug)
    train_data_loader = DataLoader(train_data, cfg.batch_size, drop_last=True, shuffle=True, num_workers=cfg.num_workers)
    val_data = monoSimDataset(path=cfg.dataset_path, mode='val', seed=cfg.seed, debug_data=cfg.debug)
    val_data_loader = DataLoader(val_data, cfg.batch_size, shuffle=False, drop_last=True, num_workers=cfg.num_workers)

    # configure model
    print('Loading Model')
    model = MobileNetV2_Lite(True, cfg.mask_learn_rate)
    assert model is not None
    model.to(device)
    if cfg.cp_path:
        cp_data = torch.load(cfg.cp_path, map_location=device)
        try:
            model.load_state_dict(cp_data['model'])
        except Exception as e:
            # Fall back to a partial load when checkpoint and model differ.
            model.load_state_dict(cp_data['model'], strict=False)
            print(e)
        cp_data['cfg'] = '' if 'cfg' not in cp_data else cp_data['cfg']
        print(cp_data['cfg'])

    # criterion and optimizer
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        # momentum=cfg.momentum,
        weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, verbose=True)
    pred_criterion = nn.MSELoss()
    mask_criterion = CrossEntropy2d()

    # checkpoint
    if cfg.cp_num > 0:
        cp_dir_path = os.path.normcase(os.path.join('checkpoints', train_name))
        os.mkdir(cp_dir_path)
        best_cp = []
        history_dir_path = os.path.normcase(os.path.join(cp_dir_path, 'history'))
        os.mkdir(history_dir_path)
        with open(os.path.normcase(os.path.join(cp_dir_path, 'config.txt')), 'w') as f:
            info = (str(cfg) + '#' * 30 + '\npre_cfg:\n' + str(cp_data['cfg'])) if cfg.cp_path else str(cfg)
            f.write(info)

    # visible
    if cfg.visible:
        log_writer = SummaryWriter(os.path.join("log", train_name))
        log_writer.add_text('cur_cfg', cfg.__str__())
        if cfg.cp_path:
            log_writer.add_text('pre_cfg', cp_data['cfg'].__str__())

    # Start!
    print("Start training!\n")
    for epoch in range(1, cfg.max_epochs + 1):
        # Decay the mask-branch weight every tenth of the schedule.
        if epoch % max(1, cfg.max_epochs // 10) == 0 and cfg.mask_lr_decay < 1:
            cfg.mask_learn_rate *= cfg.mask_lr_decay
            print("[{}] Mask learn rate: {:.4e}".format(epoch, cfg.mask_learn_rate))

        # train
        model.train()
        epoch_loss = 0
        for img, mask, target in tqdm(
                train_data_loader,
                desc='[{}] mini_batch'.format(epoch),
                bar_format='{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
            img = img.to(device)
            mask = mask.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            pred, heatmap = model(img)
            # Blend the quality-prediction and mask-segmentation losses.
            if cfg.mask_learn_rate == 0:
                loss = pred_criterion(pred, target)
            elif cfg.mask_learn_rate == 1:  # original duplicated the `== 0` test here
                loss = mask_criterion(heatmap, mask)
            else:
                loss = (1 - cfg.mask_learn_rate) * pred_criterion(pred, target) \
                       + cfg.mask_learn_rate * mask_criterion(heatmap, mask)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_loss = epoch_loss / len(train_data_loader)
        scheduler.step(train_loss)
        print("[{}] Training - loss: {:.4e}".format(epoch, train_loss))
        if cfg.visible:
            log_writer.add_scalar('Train/Loss', train_loss, epoch)
            log_writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], epoch)

        # val
        if epoch % 5 == 0 or cfg.debug:
            # MobileNetV3 variants keep BN in train mode during validation.
            if cfg.model.split('_')[0] == 'MobileNetV3':
                model.train()
            else:
                model.eval()
            with torch.no_grad():
                val_pred_loss = 0
                scores = np.zeros((1))
                prediction = np.zeros((1))
                for img, mask, target in tqdm(
                        val_data_loader,
                        desc='[{}] val_batch'.format(epoch),
                        bar_format='{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
                    img = img.to(device)
                    mask = mask.to(device)
                    target = target.to(device)
                    pred, heatmap = model(img)
                    val_pred_loss += nn.functional.mse_loss(pred, target, reduction='sum').item()
                    scores = np.append(scores, target.cpu().numpy().reshape((-1)))
                    prediction = np.append(prediction, pred.cpu().numpy().reshape((-1)))

                val_pred_loss = val_pred_loss / len(val_data)
                prediction = np.nan_to_num(prediction)
                # Drop the leading zero placeholder before computing correlations.
                srocc = stats.spearmanr(prediction[1:], scores[1:])[0]
                lcc = stats.pearsonr(prediction[1:], scores[1:])[0]
                print("[{}] Val - MSE: {:.4e}".format(epoch, val_pred_loss))
                print("[{}] Val - LCC: {:.4f}, SROCC: {:.4f}".format(epoch, lcc, srocc))
                if cfg.visible:
                    idx = np.random.randint(0, mask.shape[0])
                    heatmap_s = torch.softmax(heatmap, 1)[idx, 1, :, :]
                    log_writer.add_scalar('Val/MSE', val_pred_loss, epoch)
                    log_writer.add_scalar('Val/LCC', lcc, epoch)
                    log_writer.add_scalar('Val/SROCC', srocc, epoch)
                    log_writer.add_image('Val/img', img[idx], epoch)
                    log_writer.add_image('Val/mask', torch.squeeze(mask[idx]), epoch, dataformats='HW')
                    log_writer.add_image('Val/heatmap', torch.squeeze(heatmap_s), epoch, dataformats='HW')

        # checkpoint: keep the cp_num best checkpoints by training loss,
        # plus periodic snapshots in the history directory.
        if cfg.cp_num > 0:
            cp_name = "{}_{:.4e}.pth".format(epoch, train_loss)
            cp_data = dict(cfg=str(cfg), model=model.state_dict())
            if epoch < cfg.cp_num + 1:
                # Fill the best-checkpoint list during the first cp_num epochs.
                best_cp.append([cp_name, train_loss])
                best_cp.sort(key=lambda x: x[1])
                torch.save(cp_data, os.path.normcase(os.path.join(cp_dir_path, cp_name)))
            elif train_loss < best_cp[-1][1]:
                # Replace the worst of the kept checkpoints.
                os.remove(os.path.normcase(os.path.join(cp_dir_path, best_cp[-1][0])))
                best_cp[-1] = [cp_name, train_loss]
                best_cp.sort(key=lambda x: x[1])
                torch.save(cp_data, os.path.normcase(os.path.join(cp_dir_path, cp_name)))
            if ((log_interval > 0) and (epoch % log_interval == 0 or epoch % 100 == 0)) or \
                    (epoch == cfg.max_epochs):
                torch.save(cp_data, os.path.normcase(os.path.join(history_dir_path, cp_name)))

    return model.cpu()
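# ---------------------------------------------------------------------------
# Hypothetical usage sketch: train(cfg) only needs an object exposing the
# fields read above. The concrete values below are illustrative placeholders,
# not defaults taken from this repo.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        model='MobileNetV2_Lite', dataset_path='data/monoSim',  # placeholder path
        device='cuda:0', seed=0, debug=False,
        batch_size=16, num_workers=4, lr=1e-4, weight_decay=1e-5,
        mask_learn_rate=0.5, mask_lr_decay=0.9,
        cp_path=None, cp_num=3, max_epochs=100, visible=False,
    )
    trained_model = train(cfg)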