def train(loader, model, optimizer, epoch, args):
    """Run one training epoch over `loader`, with gradient accumulation.

    Optimizer steps are taken every `args.accum_grad` batches; running loss
    meters are logged every `args.print_freq` batches and then reset.
    Returns nothing; progress goes to `logger`.
    """
    timer = Timer()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    # Learning rate is decayed per-epoch; cur_lr is only used for logging below.
    cur_lr = adjust_learning_rate(args.lr_decay_rate, optimizer, epoch)
    model.train()
    optimizer.zero_grad()
    ce_loss_criterion = nn.CrossEntropyLoss()
    for i, (input, meta) in tqdm(enumerate(loader), desc="Train Epoch"):
        # Short-circuit for quick debug runs. NOTE(review): debug_short_train_num
        # is a module-level name not visible in this chunk — confirm it exists.
        if args.debug and i >= debug_short_train_num:
            break
        data_time.update(timer.thetime() - timer.end)
        _batch_size = len(meta)
        # Flatten the per-sample label lists into one 1-D target vector.
        target = []
        for _ in range(_batch_size):
            target.extend(meta[_]["labels"])
        target = torch.from_numpy(np.array(target))
        # Collapse the first two dims into one: presumably 3 clips/views per
        # sample, matching the flattened labels above — TODO confirm with the
        # dataloader's output shape.
        input = input.view(
            _batch_size * 3,
            input.shape[2],
            input.shape[3],
            input.shape[4],
            input.shape[5],
        )
        metric_feat, output = model(input)
        ce_loss = ce_loss_criterion(output.cuda(), target.long().cuda())
        loss = ce_loss
        # Gradients accumulate across iterations; step/zero only every
        # accum_grad-th batch (below).
        loss.backward()
        loss_meter.update(loss.item())
        ce_loss_meter.update(ce_loss.item())
        if i % args.accum_grad == args.accum_grad - 1:
            optimizer.step()
            optimizer.zero_grad()
        if i % args.print_freq == 0 and i > 0:
            logger.info("[{0}][{1}/{2}]\t"
                        "Dataload_Time={data_time.avg:.3f}\t"
                        "Loss={loss.avg:.4f}\t"
                        "CELoss={ce_loss.avg:.4f}\t"
                        "LR={cur_lr:.7f}\t"
                        "bestAP={ap:.3f}".format(
                            epoch,
                            i,
                            len(loader),
                            data_time=data_time,
                            loss=loss_meter,
                            ce_loss=ce_loss_meter,
                            ap=args.best_score,
                            cur_lr=cur_lr,
                        ))
            # Meters report averages since the last print, not since epoch start.
            loss_meter.reset()
            ce_loss_meter.reset()
def optimizer_summary(optim_list):
    """Log a table summarizing every parameter of the given optimizer(s).

    Args:
        optim_list: a ``torch.optim.Optimizer`` or a list of them.

    Raises:
        ValueError: if any element is not a ``torch.optim.Optimizer``.

    For each optimizer, logs one row per parameter with its group index,
    index within the group, shape, lr, weight_decay, and requires_grad flag,
    plus the total parameter count.
    """
    if not isinstance(optim_list, list):
        optim_list = [optim_list]
    from operator import mul
    for optim in optim_list:
        # Was `assert isinstance(...), ValueError(...)`: the exception instance
        # only served as the assert message, and asserts vanish under `-O`.
        if not isinstance(optim, torch.optim.Optimizer):
            raise ValueError("must be an Optimizer instance")
        data = []
        param_num = 0
        for group_id, param_group in enumerate(optim.param_groups):
            lr = param_group["lr"]
            weight_decay = param_group["weight_decay"]
            # `param_id` (not `id`) to avoid shadowing the builtin.
            for param_id, param in enumerate(param_group["params"]):
                requires_grad = param.requires_grad
                shape = list(param.data.size())
                param_num += reduce(mul, shape, 1)
                data.append(
                    [group_id, param_id, shape, lr, weight_decay, requires_grad])
        table = tabulate(
            data,
            headers=[
                "group",
                "id",
                "shape",
                "lr",
                "weight_decay",
                "requires_grad",
            ],
        )
        # "Optimizer" typo fixed in the log message.
        logger.info(
            colored(
                "Optimizer Summary, Optimizer Parameters: #param={} \n".format(
                    param_num),
                "cyan",
            ) + table)
def model_summary(model_list):
    """Log a table of every tensor in the state dict of the given model(s).

    Args:
        model_list: an ``nn.Module`` or a list of them.

    For each model, logs one row per state-dict entry (name and shape), the
    total parameter count, and finally the model's own repr.
    """
    if not isinstance(model_list, list):
        model_list = [model_list]
    from operator import mul
    for model in model_list:
        # state_dict() already returns a fresh mapping; the old `.copy()` was
        # redundant, and the unused `filter(...)` over parameters was dead code.
        state_dict = model.state_dict()
        data = []
        param_num = 0
        for key, value in state_dict.items():
            data.append([key, list(value.size())])
            param_num += reduce(mul, list(value.size()), 1)
        table = tabulate(data, headers=["name", "shape"])
        logger.info(
            colored(
                "Model Summary, Arg Parameters: #param={} \n".format(
                    param_num),
                "cyan",
            ) + table)
        logger.info(model)
def main():
    """Entry point: parse args, build data/model, then train or evaluate.

    In ``--evaluate`` mode, loads ``args.test_load`` and runs a single eval.
    Otherwise trains for ``args.epochs`` epochs, snapshotting the best model
    (by AP) to ``best.pth.tar``, then reloads that snapshot and runs a final
    evaluation on the testing split.
    """
    args = parse()
    set_gpu(args.gpu)
    args.best_score = 0
    args.best_result_dict = {}
    from dataloader_baseline import get_my_dataset
    train_loader = get_my_dataset(args)
    args.semantic_mem = train_loader.dataset.semantic_mem
    seed(args.manual_seed)
    model = get_model(args)

    if args.evaluate:
        # Evaluation-only path: requires a checkpoint to load.
        logger.info(vars(args))
        assert args.test_load is not None
        saved_dict = torch.load(args.test_load)
        logger.warning("loading weight {}".format(args.test_load))
        model.load_state_dict(saved_dict["state_dict"], strict=True)
        args.read_cache_feat = True
        do_eval(args=args, model=model)
        return

    logger.warning("using {}".format(args.optimizer))
    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.wd,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.wd)
    else:
        assert False, "invalid optimizer"

    model_summary(model)
    optimizer_summary(optimizer)
    logger.info(vars(args))

    for epoch in range(args.epochs):
        if args.method == "baseline":
            train(train_loader, model, optimizer, epoch, args)
        elif args.method == "va":
            train_va(train_loader, model, optimizer, epoch, args)
        elif args.method == "vasa":
            train_vasa(train_loader, model, optimizer, epoch, args)
        elif args.method == "ranking":
            train_ranking(train_loader, model, optimizer, epoch, args)
        else:
            # Was a bare `raise` with no active exception, which itself raises
            # "RuntimeError: No active exception to reraise".
            raise ValueError("invalid method: {}".format(args.method))
        # NOTE(review): eval_per_epoch is a module-level name not visible in
        # this chunk — confirm it exists.
        if epoch % eval_per_epoch == 0 or epoch == args.epochs - 1:
            score_dict = do_eval(args=args, model=model)
            score = score_dict["ap"]
            is_best = score > args.best_score
            if is_best:
                # args.best_result_dict = score_dict
                args.best_score = max(score, args.best_score)
                logger.warning("saving best snapshot..")
                torch.save(
                    {
                        "epoch": epoch,
                        "state_dict": model.state_dict(),
                        "score": args.best_score,
                        "optimizer": optimizer.state_dict(),
                    },
                    os.path.join(logger.get_logger_dir(), "best.pth.tar"),
                )

    # Reload the best snapshot and run the final test-split evaluation.
    weight_path = os.path.join(logger.get_logger_dir(), "best.pth.tar")
    saved_dict = torch.load(weight_path)
    logger.warning("loading weight {}, best validation result={}".format(
        weight_path, saved_dict["score"]))
    model.load_state_dict(saved_dict["state_dict"], strict=True)
    args.eval_split = "testing"
    args.eval_all = True
    logger.info(vars(args))
    do_eval(args=args, model=model)
    logger.info("training finish. snapshot weight in {}".format(
        logger.get_logger_dir()))