def main():
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Model #######################################
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(os.path.join(config.load_path,
                                            "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network(
            [state["alpha_%d_0" % arch_idx].detach(),
             state["alpha_%d_1" % arch_idx].detach(),
             state["alpha_%d_2" % arch_idx].detach()],
            [None,
             state["beta_%d_1" % arch_idx].detach(),
             state["beta_%d_2" % arch_idx].detach()],
            [state["ratio_%d_0" % arch_idx].detach(),
             state["ratio_%d_1" % arch_idx].detach(),
             state["ratio_%d_2" % arch_idx].detach()],
            num_classes=config.num_classes, layers=config.layers, Fch=config.Fch,
            width_mult_list=config.width_mult_list, stem_head_width=config.stem_head_width[idx],
            ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]; latency02 = state["latency02"]; obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]; latency12 = state["latency12"]; obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))
        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),), verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        model = model.cuda()
        #####################################################
        print(config.save)
        latency = compute_latency(model, (1, 3, config.image_height, config.image_width))
        logging.info("FPS:" + str(1000. / latency))
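# The `compute_latency` helper used above is defined elsewhere in the repo; a minimal
# sketch of the usual approach (an assumption, not the repo's exact implementation) is to
# run warm-up passes and then time synchronized forward passes with torch.cuda.Event,
# returning the mean latency in milliseconds so that FPS = 1000 / latency.
def compute_latency_sketch(model, input_size, iterations=100, warmup=10):
    model.eval()
    inputs = torch.randn(*input_size).cuda()
    with torch.no_grad():
        for _ in range(warmup):            # warm up cuDNN autotuner / allocator
            model(inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iterations):
            model(inputs)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iterations   # milliseconds per forward pass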
def main():
    width_mult_list = [4. / 12, 6. / 12, 8. / 12, 10. / 12, 1.]
    json_files = glob.glob(os.path.join(args.arch_loc, "*.json"))
    for json_file in json_files:
        with open(json_file, 'r') as f:
            model_dict = json.loads(f.read())
        last = model_dict["lasts"]
        # drop the ".json" suffix to name the output directory
        # (str.strip('.json') would remove individual characters, not the suffix)
        save_dir = os.path.join(args.save_dir, os.path.splitext(os.path.basename(json_file))[0])
        os.makedirs(save_dir, exist_ok=True)
        try:
            for b in range(len(last)):
                if len(width_mult_list) > 1:
                    plot_op(model_dict["ops"][b], model_dict["paths"][b],
                            width=model_dict["widths"][b],
                            head_width=args.stem_head_width,
                            F_base=args.Fch).savefig(
                        os.path.join(save_dir, "ops_%d_%d.png" % (0, b)), bbox_inches="tight")
                else:
                    plot_op(model_dict["ops"][b], model_dict["paths"][b],
                            F_base=args.Fch).savefig(
                        os.path.join(save_dir, "ops_%d_%d.png" % (0, b)), bbox_inches="tight")
            plot_path_width(model_dict["lasts"], model_dict["paths"], model_dict["widths"]).savefig(
                os.path.join(save_dir, "path_width%d.png" % 0))
        except Exception:
            print("Arch: {} is invalid".format(json_file))
            shutil.rmtree(save_dir)
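# The loop above assumes each JSON file stores the derived architecture with at least the
# keys "lasts", "ops", "paths", and "widths", each holding one list per output branch.
# A hypothetical minimal example, for illustration only (the real files are produced by
# the search code, and the exact value ranges may differ):
example_arch = {
    "lasts": [2, 0],                          # branch indices, as passed to plot_path_width
    "ops": [[0, 1, 1, 2], [1, 1, 0, 2]],      # per-cell operator choice for each branch
    "paths": [[0, 0, 1, 2], [0, 1, 2, 2]],    # per-cell downsample level for each branch
    "widths": [[4. / 12, 6. / 12, 8. / 12, 1.],
               [6. / 12, 8. / 12, 10. / 12, 1.]],  # per-cell width multiplier
}
# json.dumps(example_arch) would round-trip through the loader above.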
def arch_logging(model, args, logger, epoch):
    input_size = (1, 3, 1024, 2048)
    net = Network_Multi_Path_Infer(
        [getattr(model, model._arch_names[model.arch_idx]["alphas"][0]).clone().detach(),
         getattr(model, model._arch_names[model.arch_idx]["alphas"][1]).clone().detach(),
         getattr(model, model._arch_names[model.arch_idx]["alphas"][2]).clone().detach()],
        [None,
         getattr(model, model._arch_names[model.arch_idx]["betas"][0]).clone().detach(),
         getattr(model, model._arch_names[model.arch_idx]["betas"][1]).clone().detach()],
        [getattr(model, model._arch_names[model.arch_idx]["ratios"][0]).clone().detach(),
         getattr(model, model._arch_names[model.arch_idx]["ratios"][1]).clone().detach(),
         getattr(model, model._arch_names[model.arch_idx]["ratios"][2]).clone().detach()],
        num_classes=model._num_classes, layers=model._layers, Fch=model._Fch,
        width_mult_list=model._width_mult_list, stem_head_width=model._stem_head_width[model.arch_idx])

    plot_op(net.ops0, net.path0, F_base=args.Fch).savefig("table.png", bbox_inches="tight")
    logger.add_image("arch/ops0_arch%d" % model.arch_idx,
                     np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch)
    plot_op(net.ops1, net.path1, F_base=args.Fch).savefig("table.png", bbox_inches="tight")
    logger.add_image("arch/ops1_arch%d" % model.arch_idx,
                     np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch)
    plot_op(net.ops2, net.path2, F_base=args.Fch).savefig("table.png", bbox_inches="tight")
    logger.add_image("arch/ops2_arch%d" % model.arch_idx,
                     np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch)

    net.build_structure([2, 0])
    net = net.cuda()
    net.eval()
    latency0, _ = net.forward_latency(input_size[1:])
    logger.add_scalar("arch/fps0_arch%d" % model.arch_idx, 1000. / latency0, epoch)
    logger.add_figure("arch/path_width_arch%d_02" % model.arch_idx,
                      plot_path_width([2, 0], [net.path2, net.path0], [net.widths2, net.widths0]), epoch)

    net.build_structure([2, 1])
    net = net.cuda()
    net.eval()
    latency1, _ = net.forward_latency(input_size[1:])
    logger.add_scalar("arch/fps1_arch%d" % model.arch_idx, 1000. / latency1, epoch)
    logger.add_figure("arch/path_width_arch%d_12" % model.arch_idx,
                      plot_path_width([2, 1], [net.path2, net.path1], [net.widths2, net.widths1]), epoch)

    return 1000. / latency0, 1000. / latency1
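# The double np.swapaxes above converts the HWC image read back from "table.png" into the
# CHW layout that SummaryWriter.add_image expects. A sketch of an equivalent helper that
# skips the temporary file by rendering the matplotlib figure into an in-memory buffer
# (an alternative convenience, not the repo's code):
import io

def figure_to_chw(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    img = plt.imread(buf)[..., :3]          # HWC floats in [0, 1]; drop alpha if present
    return np.transpose(img, (2, 0, 1))     # -> CHW, as add_image expects

# e.g. logger.add_image("arch/ops0_arch%d" % model.arch_idx,
#                       figure_to_chw(plot_op(net.ops0, net.path0, F_base=args.Fch)), epoch)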
def main():
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height * config.image_width //
                   (16 * config.gt_down_sampling ** 2))

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }

    # Model #######################################
    models = []
    evaluators = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(os.path.join(config.load_path,
                                            "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network(
            [state["alpha_%d_0" % arch_idx].detach(),
             state["alpha_%d_1" % arch_idx].detach(),
             state["alpha_%d_2" % arch_idx].detach()],
            [None,
             state["beta_%d_1" % arch_idx].detach(),
             state["beta_%d_2" % arch_idx].detach()],
            [state["ratio_%d_0" % arch_idx].detach(),
             state["ratio_%d_1" % arch_idx].detach(),
             state["ratio_%d_2" % arch_idx].detach()],
            num_classes=config.num_classes, layers=config.layers, Fch=config.Fch,
            width_mult_list=config.width_mult_list, stem_head_width=config.stem_head_width[idx],
            ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        # logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))
        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),), verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')
        partial = torch.load(os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
        state = model.state_dict()
        pretrained_dict = {k: v for k, v in partial.items() if k in state}
        state.update(pretrained_dict)
        model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), config.num_classes,
                                 config.image_mean, config.image_std, model,
                                 config.eval_scale_array, config.eval_flip, 0, out_idx=0,
                                 config=config, verbose=False,
                                 save_path=os.path.join(config.save, 'predictions'),
                                 show_image=True, show_prediction=True)
        evaluators.append(evaluator)
        models.append(model)

    # Cityscapes ###########################################
    logging.info(config.load_path)
    logging.info(config.eval_path)
    logging.info(config.save)
    with torch.no_grad():
        # validation
        print("[validation...]")
        valid_mIoUs = infer(models, evaluators, logger=None)
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
            else:
                logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
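# The checkpoint restore above keeps only the keys that also exist in the freshly built
# model, so a derived architecture that dropped some modules still loads cleanly. A small
# helper capturing that pattern (a sketch; the scripts here inline it instead, and the
# extra shape check is an addition the inline version does not perform):
def load_partial_state(model, checkpoint_path):
    partial = torch.load(checkpoint_path, map_location="cpu")
    state = model.state_dict()
    matched = {k: v for k, v in partial.items()
               if k in state and v.shape == state[k].shape}
    state.update(matched)
    model.load_state_dict(state)
    return len(matched), len(state)   # how many tensors were actually restored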
def main():
    create_exp_dir(config.save, scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))

    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height * config.image_width //
                   (16 * config.gt_down_sampling ** 2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7,
                                            min_kept=min_kept, use_weight=False)
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(os.path.join(config.load_path,
                                            "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network(
            [state["alpha_%d_0" % arch_idx].detach(),
             state["alpha_%d_1" % arch_idx].detach(),
             state["alpha_%d_2" % arch_idx].detach()],
            [None,
             state["beta_%d_1" % arch_idx].detach(),
             state["beta_%d_2" % arch_idx].detach()],
            [state["ratio_%d_0" % arch_idx].detach(),
             state["ratio_%d_1" % arch_idx].detach(),
             state["ratio_%d_2" % arch_idx].detach()],
            num_classes=config.num_classes, layers=config.layers, Fch=config.Fch,
            width_mult_list=config.width_mult_list, stem_head_width=config.stem_head_width[idx],
            ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b), getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(
                    os.path.join(config.save, "ops_%d_%d.png" % (arch_idx, b)), bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
            os.path.join(config.save, "path_width_all%d.png" % arch_idx))
        flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))

        model = model.cuda()
        init_weight(model, nn.init.kaiming_normal_, torch.nn.BatchNorm2d,
                    config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), config.num_classes,
                                 config.image_mean, config.image_std, model,
                                 config.eval_scale_array, config.eval_flip, 0, out_idx=0,
                                 config=config, verbose=False, save_path=None, show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None), config.num_classes,
                           config.image_mean, config.image_std, model,
                           config.eval_scale_array, config.eval_flip, 0, out_idx=0,
                           config=config, verbose=False, save_path=None, show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(), lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
        exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" % (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion, distill_criterion,
                            optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx], epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx], epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1, config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" % (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], epoch)
                        logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], epoch)
                        logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
                    save(models[idx], os.path.join(config.save, "weights%d.pt" % arch_idx))

        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" % (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

    for idx, arch_idx in enumerate(config.arch_idx):
        save(models[idx], os.path.join(config.save, "weights%d.pt" % arch_idx))
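# `adjust_learning_rate` is defined elsewhere in the repo; given the call
# adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1, config.nepochs) in the
# training loop above, it presumably applies an exponential decay of the form
# lr = base_lr * power ** epoch. A minimal sketch under that assumption:
def adjust_learning_rate_sketch(base_lr, power, optimizer, epoch, total_epochs):
    lr = base_lr * power ** epoch          # e.g. 0.992 ** 300 ~= 0.09 of the initial lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr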