def main(pretrain=True):
    """Search/pretrain entry point for the slimmable GAN supernet.

    pretrain:
        True  -> warm up supernet weights only (architecture updates disabled);
        str   -> checkpoint directory to resume weights from, with arch search on;
        False -> search from scratch with architecture updates enabled.

    NOTE(review): reconstructed from whitespace-collapsed source; nesting of the
    tail-of-epoch save blocks follows the obvious AGD-style layout — confirm
    against the original file.
    """
    config.save = 'ckpt/{}'.format(config.save)
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Mirror logging to stdout and to <save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert type(pretrain) == bool or type(pretrain) == str
    # Architecture parameters are frozen in pure pretrain mode.
    update_arch = True
    if pretrain == True:
        update_arch = False
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Model #######################################
    model = Network(config.layers, slimmable=config.slimmable,
                    width_mult_list=config.width_mult_list,
                    width_mult_list_sh=config.width_mult_list_sh,
                    loss_weight=config.loss_weight,
                    prun_modes=config.prun_modes,
                    quantize=config.quantize)
    model = torch.nn.DataParallel(model).cuda()
    # print(model)
    # teacher_model = Generator(3, 3)
    # teacher_model.load_state_dict(torch.load(config.generator_A2B))
    # teacher_model = torch.nn.DataParallel(teacher_model).cuda()
    # for param in teacher_model.parameters():
    #     param.require_grads = False
    if type(pretrain) == str:
        # Resume: copy over only tensors whose name AND shape still match,
        # so a partially-changed supernet can still reuse old weights.
        partial = torch.load(pretrain + "/weights.pt")
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    # else:
    #     features = [model.module.stem, model.module.cells, model.module.header]
    #     init_weight(features, nn.init.kaiming_normal_, nn.InstanceNorm2d, config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    # Only supernet weights (stem/cells/header) go to this optimizer;
    # architecture parameters are updated by `architect`.
    parameters = []
    parameters += list(model.module.stem.parameters())
    parameters += list(model.module.cells.parameters())
    parameters += list(model.module.header.parameters())
    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(parameters, lr=base_lr, betas=config.betas)
    elif config.opt == 'Sgd':
        optimizer = torch.optim.SGD(parameters, lr=base_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    else:
        logging.info("Wrong Optimizer Type.")
        sys.exit()

    # lr policy ##############################
    total_iteration = config.nepochs * config.niters_per_epoch
    if config.lr_schedule == 'linear':
        lr_policy = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=LambdaLR(config.nepochs, 0, config.decay_epoch).step)
    elif config.lr_schedule == 'exponential':
        lr_policy = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, config.lr_decay)
    elif config.lr_schedule == 'multistep':
        lr_policy = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        logging.info("Wrong Learning Rate Schedule Type.")
        sys.exit()

    # data loader ###########################
    transforms_ = [
        # transforms.Resize(int(config.image_height*1.12), Image.BICUBIC),
        # transforms.RandomCrop(config.image_height),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    # train_loader_model = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True, portion=config.train_portion),
    #                                 batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    # train_loader_arch = DataLoader(ImageDataset(config.dataset_path, transforms_=transforms_, unaligned=True, portion=config.train_portion-1),
    #                                batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
    # Disjoint data portions: `train_portion` for weight updates,
    # `train_portion - 1` (the complement, by convention of the dataset class —
    # TODO confirm) for architecture updates.
    train_loader_model = DataLoader(PairedImageDataset(
        config.dataset_path, config.target_path,
        transforms_=transforms_, portion=config.train_portion),
        batch_size=config.batch_size, shuffle=True,
        num_workers=config.num_workers)
    train_loader_arch = DataLoader(PairedImageDataset(
        config.dataset_path, config.target_path,
        transforms_=transforms_, portion=config.train_portion - 1),
        batch_size=config.batch_size, shuffle=True,
        num_workers=config.num_workers)

    transforms_ = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    test_loader = DataLoader(ImageDataset(config.dataset_path,
                                          transforms_=transforms_,
                                          mode='test'),
                             batch_size=1, shuffle=False,
                             num_workers=config.num_workers)

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_fid_history = []
    flops_history = []
    flops_supernet_history = []
    best_fid = 1000
    best_epoch = 0
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain, train_loader_model, train_loader_arch, model,
              architect, optimizer, lr_policy, logger, epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation: every `eval_epoch` epochs, skipping epoch 0
        if epoch and not (epoch + 1) % config.eval_epoch:
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            save(model, os.path.join(config.save, 'weights_%d.pt' % epoch))
            with torch.no_grad():
                if pretrain == True:
                    # Pretrain mode: probe FID at min / max / random widths
                    # to check the slimmable supernet across its width range.
                    model.module.prun_mode = "min"
                    valid_fid = infer(epoch, model, test_loader, logger)
                    logger.add_scalar('fid/val_min', valid_fid, epoch)
                    logging.info("Epoch %d: valid_fid_min %.3f" %
                                 (epoch, valid_fid))
                    if len(model.module._width_mult_list) > 1:
                        model.module.prun_mode = "max"
                        valid_fid = infer(epoch, model, test_loader, logger)
                        logger.add_scalar('fid/val_max', valid_fid, epoch)
                        logging.info("Epoch %d: valid_fid_max %.3f" %
                                     (epoch, valid_fid))
                        model.module.prun_mode = "random"
                        valid_fid = infer(epoch, model, test_loader, logger)
                        logger.add_scalar('fid/val_random', valid_fid, epoch)
                        logging.info("Epoch %d: valid_fid_random %.3f" %
                                     (epoch, valid_fid))
                else:
                    # Search mode: evaluate the finalized (derived) network and
                    # its FLOPs; lower FID is better.
                    model.module.prun_mode = None
                    valid_fid, flops = infer(epoch, model, test_loader,
                                             logger, finalize=True)
                    logger.add_scalar('fid/val', valid_fid, epoch)
                    logging.info("Epoch %d: valid_fid %.3f" %
                                 (epoch, valid_fid))
                    logger.add_scalar('flops/val', flops, epoch)
                    logging.info("Epoch %d: flops %.3f" % (epoch, flops))
                    valid_fid_history.append(valid_fid)
                    flops_history.append(flops)
                    if update_arch:
                        flops_supernet_history.append(architect.flops_supernet)
                    if valid_fid < best_fid:
                        best_fid = valid_fid
                        best_epoch = epoch
                    logging.info("Best fid:%.3f, Best epoch:%d" %
                                 (best_fid, best_epoch))
                    if update_arch:
                        # Snapshot the architecture parameters together with
                        # the metrics that produced them.
                        state = {}
                        state['alpha'] = getattr(model.module, 'alpha')
                        state['beta'] = getattr(model.module, 'beta')
                        state['ratio'] = getattr(model.module, 'ratio')
                        state['beta_sh'] = getattr(model.module, 'beta_sh')
                        state['ratio_sh'] = getattr(model.module, 'ratio_sh')
                        state["fid"] = valid_fid
                        state["flops"] = flops
                        torch.save(
                            state,
                            os.path.join(config.save,
                                         "arch_%d_%f.pt" % (epoch, flops)))
                        if config.flops_weight > 0:
                            # Bang-bang control: halve/double the FLOPs penalty
                            # to steer the search into [flops_min, flops_max].
                            if flops < config.flops_min:
                                architect.flops_weight /= 2
                            elif flops > config.flops_max:
                                architect.flops_weight *= 2
                            logger.add_scalar("arch/flops_weight",
                                              architect.flops_weight,
                                              epoch + 1)
                            logging.info("arch_flops_weight = " +
                                         str(architect.flops_weight))
        save(model, os.path.join(config.save, 'weights.pt'))
        if update_arch:
            # NOTE(review): `state` is only bound after the first search-mode
            # validation pass — confirm eval_epoch timing makes this safe.
            torch.save(state, os.path.join(config.save, "arch.pt"))
def main():
    """Train (or eval) the NAS-derived GAN/SR network found by the search.

    Loads the searched architecture parameters from
    ``<config.load_path>/arch.pt``, builds the derived network, distills it
    from a fixed RRDBNet teacher, and checkpoints weights each epoch.
    All behavior is driven by the module-level ``config`` object.
    """
    config.save = 'ckpt/{}'.format(config.save)
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Mirror logging to stdout and to <save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Searched architecture parameters produced by the search script.
    state = torch.load(os.path.join(config.load_path, 'arch.pt'))

    # Model #######################################
    model = NAS_GAN_Infer(state['alpha'], state['beta'], state['ratio'],
                          num_cell=config.num_cell,
                          op_per_cell=config.op_per_cell,
                          width_mult_list=config.width_mult_list,
                          loss_weight=config.loss_weight,
                          loss_func=config.loss_func,
                          before_act=config.before_act,
                          quantize=config.quantize)
    # `profile` gives params; FLOPs are recomputed by the model's own counter
    # (which accounts for the searched widths) and overwrite profile's value.
    flops, params = profile(model, inputs=(torch.randn(1, 3, 510, 350),),
                            custom_ops=custom_ops)
    flops = model.forward_flops(size=(3, 510, 350))
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = torch.nn.DataParallel(model).cuda()

    if type(config.pretrain) == str:
        state_dict = torch.load(config.pretrain)
        model.load_state_dict(state_dict)
    # else:
    #     features = [model.module.cells, model.module.conv_first, model.module.trunk_conv, model.module.upconv1,
    #                 model.module.upconv2, model.module.HRconv, model.module.conv_last]
    #     init_weight(features, nn.init.kaiming_normal_, nn.BatchNorm2d, config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')

    # Frozen ESRGAN-style teacher used only for distillation targets.
    teacher_model = RRDBNet(3, 3, 64, 23, gc=32)
    teacher_model.load_state_dict(torch.load(config.generator_A2B), strict=True)
    teacher_model = torch.nn.DataParallel(teacher_model).cuda()
    teacher_model.eval()
    for param in teacher_model.parameters():
        # BUGFIX: was `param.require_grads = False`, which merely set a
        # nonexistent attribute and left the teacher's gradients enabled.
        param.requires_grad = False

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.module.cells.parameters())
    parameters += list(model.module.conv_first.parameters())
    parameters += list(model.module.trunk_conv.parameters())
    parameters += list(model.module.upconv1.parameters())
    parameters += list(model.module.upconv2.parameters())
    parameters += list(model.module.HRconv.parameters())
    parameters += list(model.module.conv_last.parameters())
    if config.opt == 'Adam':
        optimizer = torch.optim.Adam(parameters, lr=base_lr, betas=config.betas)
    elif config.opt == 'Sgd':
        optimizer = torch.optim.SGD(parameters, lr=base_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    else:
        logging.info("Wrong Optimizer Type.")
        sys.exit()

    # lr policy ##############################
    total_iteration = config.nepochs * config.niters_per_epoch
    if config.lr_schedule == 'linear':
        lr_policy = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=LambdaLR(config.nepochs, 0, config.decay_epoch).step)
    elif config.lr_schedule == 'exponential':
        lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, config.lr_decay)
    elif config.lr_schedule == 'multistep':
        lr_policy = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config.milestones, gamma=config.gamma)
    else:
        logging.info("Wrong Learning Rate Schedule Type.")
        sys.exit()

    # data loader ############################
    transforms_ = [
        transforms.RandomCrop(config.image_height),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()]
    train_loader_model = DataLoader(ImageDataset(config.dataset_path,
                                                 transforms_=transforms_,
                                                 unaligned=True),
                                    batch_size=config.batch_size, shuffle=True,
                                    num_workers=config.num_workers)

    transforms_ = [transforms.ToTensor()]
    test_loader = DataLoader(ImageDataset(config.dataset_path,
                                          transforms_=transforms_, mode='val'),
                             batch_size=1, shuffle=False,
                             num_workers=config.num_workers)

    if config.eval_only:
        logging.info('Eval: psnr = %f', infer(0, model, test_loader, logger))
        sys.exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(train_loader_model, model, teacher_model, optimizer, lr_policy,
              logger, epoch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation: every `eval_epoch` epochs, skipping epoch 0
        if epoch and not (epoch + 1) % config.eval_epoch:
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                model.prun_mode = None
                valid_psnr = infer(epoch, model, test_loader, logger)
                logger.add_scalar('psnr/val', valid_psnr, epoch)
                logging.info("Epoch %d: valid_psnr %.3f" % (epoch, valid_psnr))
                # FLOPs are fixed for the derived network; logged for parity
                # with the search-phase dashboards.
                logger.add_scalar('flops/val', flops, epoch)
                logging.info("Epoch %d: flops %.3f" % (epoch, flops))
            save(model, os.path.join(config.save, 'weights_%d.pt' % epoch))
        save(model, os.path.join(config.save, 'weights.pt'))
def main():
    """Train/evaluate/test the derived segmentation network(s) on Cityscapes.

    Builds one network per entry in ``config.arch_idx`` (0 = teacher,
    1 = student) from saved architecture checkpoints, then either runs a
    one-shot eval (``config.is_eval``) or the full training loop with
    optional distillation and periodic test-set submission runs.

    NOTE(review): reconstructed from whitespace-collapsed source — confirm
    loop nesting against the original file.
    """
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Mirror logging to stdout and to <save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    # OHEM keeps at least `min_kept` hard pixels per batch (scaled by the
    # label down-sampling factor).
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    # In test mode, train on train+val ("train_eval_source") so the test-set
    # submission model sees all labeled data.
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        # Load the searched architecture parameters for this network.
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))
        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None,
            state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        # Pick the better of the two recorded output branchings
        # (1/8+1/32 vs 1/16+1/32) by the accuracy-latency objective.
        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))

        # Visualize the discretized ops/paths/widths for inspection.
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))

        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum, mode='fan_in',
                    nonlinearity='relu')

        # NOTE: `state` is re-bound here from the arch checkpoint dict to the
        # model's state_dict — it is not the same object as above.
        if arch_idx == 0 and len(config.arch_idx) > 1:
            # Teacher in a teacher+student run: load fixed teacher weights.
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes, config.image_mean,
                                 config.image_std, model,
                                 config.eval_scale_array, config.eval_flip,
                                 0, out_idx=0, config=config, verbose=False,
                                 save_path=None, show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes, config.image_mean,
                           config.image_std, model,
                           config.eval_scale_array, config.eval_flip,
                           0, out_idx=0, config=config, verbose=False,
                           save_path=None, show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            # Only one optimizer is created: the teacher stays frozen when a
            # student is present.
            optimizer = torch.optim.SGD(model.parameters(), lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx], epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx], epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        # Manual exponential LR decay (factor 0.992 per epoch).
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation: epoch 0 and every 10 epochs, unless in test mode
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher",
                                          valid_mIoUs[idx], epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student",
                                          valid_mIoUs[idx], epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))

        # test: from epoch 250 on, every 10 epochs
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))
def main(pretrain=True):
    """Search/pretrain entry point for the segmentation supernet (EGTEA).

    pretrain:
        True  -> warm up supernet weights only (architecture frozen);
        str   -> checkpoint dir to resume from, with arch search enabled;
        False -> search from scratch with architecture updates enabled.

    NOTE(review): reconstructed from whitespace-collapsed source; nesting of
    the end-of-epoch checkpoint blocks follows the obvious FasterSeg-style
    layout — confirm against the original file.
    """
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    # Mirror logging to stdout and to <save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert type(pretrain) == bool or type(pretrain) == str
    # Architecture parameters are frozen in pure pretrain mode.
    update_arch = True
    if pretrain == True:
        update_arch = False
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    # OHEM keeps at least `min_kept` hard pixels per batch (scaled by the
    # label down-sampling factor).
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes, config.layers, ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()
    if type(pretrain) == str:
        # Resume: copy over only tensors whose name AND shape still match.
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum, mode='fan_in',
                    nonlinearity='relu')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    # Only supernet weights go to SGD; architecture parameters are handled
    # by `architect`.
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
    optimizer = torch.optim.SGD(parameters, lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    # Disjoint data portions: `train_portion` for weight updates,
    # `train_portion - 1` (the complement, by the loader's convention —
    # TODO confirm) for architecture updates.
    train_loader_model = get_train_loader(config, EGTEA,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config, EGTEA,
                                         portion=config.train_portion - 1)

    evaluator = SegEvaluator(EGTEA(data_setting, 'val', None),
                             config.num_classes, config.image_mean,
                             config.image_std, model,
                             config.eval_scale_array, config.eval_flip, 0,
                             config=config, verbose=False, save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    # Five validation outputs per network: three single-scale heads plus the
    # two fused branchings.
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain, train_loader_model, train_loader_arch, model,
              architect, ohem_criterion, optimizer, lr_policy, logger, epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain == True:
                # Pretrain mode: probe mIoU at min / max / random widths.
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i in range(5):
                    logger.add_scalar('mIoU/val_min_%s' % valid_names[i],
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_names[i], valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch, model, evaluator, logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_max_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch, model, evaluator, logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar(
                            'mIoU/val_random_%s' % valid_names[i],
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
            else:
                # Search mode: evaluate each architecture (teacher/student)
                # and measure FPS for the latency objective.
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i in range(5):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' %
                            (arch_names[idx], valid_names[i]),
                            valid_mIoUs[i], epoch)
                        logging.info(
                            "Epoch %d: valid_mIoU_%s_%s %.3f" %
                            (epoch, arch_names[idx], valid_names[i],
                             valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        # Combined accuracy-latency objective for both fused
                        # branchings (latency in ms = 1000 / FPS).
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[3], 1000. / fps0),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_8s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[3], 1000. / fps0)))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[4], 1000. / fps1),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_16s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[4], 1000. / fps1)))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))
        if type(pretrain) == str:
            # contains arch_param names: {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for name in arch_name['alphas']:
                    state[name] = getattr(model, name)
                for name in arch_name['betas']:
                    state[name] = getattr(model, name)
                for name in arch_name['ratios']:
                    state[name] = getattr(model, name)
                # NOTE(review): these reuse the *last* arch's validation
                # results from the loop above for every saved entry — looks
                # intentional for single-arch runs; confirm for two-arch runs.
                state["mIoU02"] = valid_mIoUs[3]
                state["mIoU12"] = valid_mIoUs[4]
                if pretrain is not True:
                    state["latency02"] = 1000. / fps0
                    state["latency12"] = 1000. / fps1
                torch.save(
                    state,
                    os.path.join(config.save,
                                 "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))
        if update_arch:
            # Bang-bang control: halve/double each latency penalty to steer
            # the searched FPS into [FPS_min, FPS_max].
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
                    if (int(FPSs[idx][0] >= config.FPS_max[idx]) +
                            int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1:
                        architect.latency_weight[idx] /= 2
                    elif (int(FPSs[idx][0] <= config.FPS_min[idx]) +
                          int(FPSs[idx][1] <= config.FPS_min[idx])) > 0:
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))