optim.optimizer.zero_grad()
losses.backward()
with timer.counter('update'):
    optim.optimizer.step()

time_this = time.time()
if i > start_iter:
    batch_time = time_this - time_last
    timer.add_batch_time(batch_time)
time_last = time_this

if i > start_iter and i % 20 == 0 and main_gpu:
    cur_lr = optim.optimizer.param_groups[0]['lr']
    time_name = ['batch', 'data', 'for+loss', 'backward', 'update']
    t_t, t_d, t_fl, t_b, t_u = timer.get_times(time_name)
    seconds = (max_iter - i) * t_t
    eta = str(datetime.timedelta(seconds=seconds)).split('.')[0]
    print(f'step: {i} | lr: {cur_lr:.2e} | l_class: {l_c:.3f} | l_box: {l_b:.3f} | l_iou: {l_iou:.3f} | '
          f't_t: {t_t:.3f} | t_d: {t_d:.3f} | t_fl: {t_fl:.3f} | t_b: {t_b:.3f} | t_u: {t_u:.3f} | ETA: {eta}')

# Save a checkpoint and run validation every cfg.val_interval iterations, and at the final iteration.
if main_gpu and (i > start_iter and i % cfg.val_interval == 0 or i == max_iter):
    checkpointer.save(cur_iter=i)
    inference(model.module, cfg, during_training=True)
    model.train()
    timer.reset()  # training and validation share the same timer object, so reset it to avoid mixing their measurements
if main_gpu and i != 1 and i % cfg.val_interval == 1:
    timer.start()  # the first iteration after validation should not be included in the timing statistics
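
# --- Hedged sketch (not from the original source): the fragment above assumes a
# `timer` helper that exposes a named-counter context manager plus
# `add_batch_time`, `get_times`, `reset`, and `start`. The class below is an
# illustrative stand-in for that assumed interface, not the project's actual code.
import time
from collections import defaultdict
from contextlib import contextmanager


class IterTimer:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)
        self._start = None

    def start(self):
        self._start = time.time()

    @contextmanager
    def counter(self, name):
        # time the enclosed block and accumulate it under `name`
        t0 = time.time()
        yield
        self.sums[name] += time.time() - t0
        self.counts[name] += 1

    def add_batch_time(self, t):
        self.sums['batch'] += t
        self.counts['batch'] += 1

    def get_times(self, names):
        # average seconds per recorded event for each requested counter
        return [self.sums[n] / max(self.counts[n], 1) for n in names]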
def train(cfg):
    device = torch.device(cfg.DEVICE)
    arguments = {}
    arguments["epoch"] = 0
    if not cfg.DATALOADER.BENCHMARK:
        model = Modelbuilder(cfg)
        print(model)
        model.to(device)
        model.float()
        optimizer, scheduler = make_optimizer(cfg, model)
        checkpointer = Checkpointer(model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    save_dir=cfg.OUTPUT_DIR)
        extra_checkpoint_data = checkpointer.load(
            cfg.WEIGHTS,
            prefix=cfg.WEIGHTS_PREFIX,
            prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE,
            loadoptimizer=cfg.WEIGHTS_LOAD_OPT)
        arguments.update(extra_checkpoint_data)
        model.train()
    logger = setup_logger("trainer", cfg.FOLDER_NAME)
    if cfg.TENSORBOARD.USE:
        writer = SummaryWriter(cfg.FOLDER_NAME)
    else:
        writer = None
    meters = MetricLogger(writer=writer)

    start_training_time = time.time()
    end = time.time()
    start_epoch = arguments["epoch"]
    max_epoch = cfg.SOLVER.MAX_EPOCHS
    if start_epoch == max_epoch:
        logger.info("Final model exists! No need to train!")
        test(cfg, model)
        return

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )
    size_epoch = len(data_loader)
    max_iter = size_epoch * max_epoch
    logger.info("Start training {} batches/epoch".format(size_epoch))
    for epoch in range(start_epoch, max_epoch):
        arguments["epoch"] = epoch
        # batchcnt = 0
        for iteration, batchdata in enumerate(data_loader):
            cur_iter = size_epoch * epoch + iteration
            data_time = time.time() - end
            batchdata = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batchdata.items()
            }
            if not cfg.DATALOADER.BENCHMARK:
                loss_dict, metric_dict = model(batchdata)
                # print(loss_dict, metric_dict)
                optimizer.zero_grad()
                loss_dict['loss'].backward()
                optimizer.step()
            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time, iteration=cur_iter)
            if cfg.DATALOADER.BENCHMARK:
                logger.info(
                    meters.delimiter.join([
                        "iter: {iter}",
                        "{meters}",
                    ]).format(
                        iter=iteration,
                        meters=str(meters),
                    ))
                continue
            eta_seconds = meters.time.global_avg * (max_iter - cur_iter)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            if iteration % cfg.LOG_FREQ == 0:
                meters.update(iteration=cur_iter, **loss_dict)
                meters.update(iteration=cur_iter, **metric_dict)
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {epoch}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        # "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        epoch=epoch,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        # memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))
        # Since PyTorch 1.1.0, scheduler.step() must be called after optimizer.step();
        # calling it earlier makes PyTorch skip the first value of the learning rate schedule.
        # See https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step()

        if (epoch + 1) % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            arguments["epoch"] += 1
            checkpointer.save("model_{:03d}".format(epoch), **arguments)
        if epoch == max_epoch - 1:
            arguments['epoch'] = max_epoch
            checkpointer.save("model_final", **arguments)
            total_training_time = time.time() - start_training_time
            total_time_str = str(datetime.timedelta(seconds=total_training_time))
            logger.info("Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str, total_training_time / (max_epoch - start_epoch)))
        if epoch == max_epoch - 1 or ((epoch + 1) % cfg.EVAL_FREQ == 0):
            results = test(cfg, model)
            meters.update(is_train=False, iteration=cur_iter, **results)
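
# --- Hedged sketch (assumption, not the project's MetricLogger): the ETA above is
# computed from a meter that tracks a running global average of per-batch time.
# A minimal stand-in for that piece of the interface:
import datetime


class AverageMeter:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def global_avg(self):
        return self.total / max(self.count, 1)


def compute_eta(batch_time_meter, max_iter, cur_iter):
    # remaining iterations times the average seconds per iteration so far
    eta_seconds = batch_time_meter.global_avg * (max_iter - cur_iter)
    return str(datetime.timedelta(seconds=int(eta_seconds)))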
    accs.append(acc)

    # remember best prec@1 and save checkpoint
    is_best = accs[0] > checkpointer.best_acc
    if is_best:
        checkpointer.best_acc = accs[0]
    elif cfg.OPTIM.VAL and cfg.OPTIM.OPT in ['sgd', 'qhm', 'salsa']:
        logging.info("DROPPING LEARNING RATE")
        # Anneal the learning rate if no improvement has been seen on the validation dataset.
        for group in optimizer.param_groups:
            group['lr'] = group['lr'] * 1.0 / cfg.OPTIM.DROP_FACTOR
        if cfg.OPTIM.OPT in ['salsa']:
            optimizer.state['switched'] = True
            logging.info("Switch due to overfitting!")

    checkpointer.epoch = epoch + 1
    checkpointer.save(is_best)

# evaluate exactly the best checkpoint;
# wait for all processes to complete before calculating the score
synchronize()
best_model_path = os.path.join(checkpointer.save_dir, "model_best.pth")
if os.path.isfile(best_model_path):
    logging.info("Evaluating the best checkpoint: {}".format(best_model_path))
    cfg.defrost()
    cfg.EVALUATE = True
    checkpointer.is_test = True
    cfg.freeze()
    extra_checkpoint_data = checkpointer.load(best_model_path)
    for task_name, testloader, test_meter in zip(task_names, testloaders, test_meters):
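
# --- Hedged sketch: the manual learning-rate drop above mirrors what
# torch.optim.lr_scheduler.ReduceLROnPlateau does when the monitored accuracy
# stops improving. This is an illustrative equivalent (assuming DROP_FACTOR > 1),
# not the snippet's actual optimizer configuration:
import torch


def make_plateau_scheduler(optimizer, drop_factor):
    # divide the LR by drop_factor whenever the monitored metric fails to improve
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=1.0 / drop_factor, patience=0)

# usage after each validation pass: scheduler.step(accs[0])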
def train(args):
    try:
        model = nets[args.net](args.margin, args.omega, args.use_hardtriplet)
        model.to(args.device)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)
    if args.device != torch.device("cpu"):
        amp_opt_level = 'O1' if args.use_amp else 'O0'
        model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    arguments = {}
    arguments.update(vars(args))
    arguments["itr"] = 0
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=args.out_dir,
                                save_to_disk=True)
    ## load model from pretrained weights or a training checkpoint
    extra_checkpoint_data = checkpointer.load(args.pretrained_weights)
    arguments.update(extra_checkpoint_data)

    batch_size = args.batch_size
    fashion = FashionDataset(item_num=args.iteration_num * batch_size)
    dataloader = DataLoader(dataset=fashion,
                            shuffle=True,
                            num_workers=8,
                            batch_size=batch_size)

    model.train()
    meters = MetricLogger(delimiter=", ")
    max_itr = len(dataloader)
    start_itr = arguments["itr"] + 1
    itr_start_time = time.time()
    training_start_time = time.time()
    for itr, batch_data in enumerate(dataloader, start_itr):
        batch_data = (bd.to(args.device) for bd in batch_data)
        loss_dict = model.loss(*batch_data)

        optimizer.zero_grad()
        if args.device != torch.device("cpu"):
            with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_losses:
                scaled_losses.backward()
        else:
            loss_dict["loss"].backward()
        optimizer.step()
        scheduler.step()

        arguments["itr"] = itr
        meters.update(**loss_dict)
        itr_time = time.time() - itr_start_time
        itr_start_time = time.time()
        meters.update(itr_time=itr_time)
        if itr % 50 == 0:
            eta_seconds = meters.itr_time.global_avg * (max_itr - itr)
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join(
                    [
                        "itr: {itr}/{max_itr}",
                        "lr: {lr:.7f}",
                        "{meters}",
                        "eta: {eta}\n",
                    ]
                ).format(
                    itr=itr,
                    lr=optimizer.param_groups[0]["lr"],
                    max_itr=max_itr,
                    meters=str(meters),
                    eta=eta,
                )
            )
        ## save model
        if itr % args.checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(itr), **arguments)
        if itr == max_itr:
            checkpointer.save("model_final", **arguments)
            break

    training_time = time.time() - training_start_time
    training_time = str(datetime.timedelta(seconds=int(training_time)))
    logger.info("total training time: {}".format(training_time))
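
# --- Hedged sketch: the loop above uses NVIDIA apex amp (amp.initialize /
# amp.scale_loss). On PyTorch >= 1.6 a roughly equivalent mixed-precision step
# can be written with torch.cuda.amp; this is an illustrative alternative under
# that assumption, not the code the snippet actually runs.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


def amp_step(model, optimizer, scheduler, batch_data, device):
    optimizer.zero_grad()
    # run the forward pass in mixed precision
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss_dict = model.loss(*[bd.to(device) for bd in batch_data])
    # scale the loss, backprop, then unscale and step the optimizer
    scaler.scale(loss_dict["loss"]).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    return loss_dict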
def train(cfg, args):
    train_set = DatasetCatalog.get(cfg.DATASETS.TRAIN, args)
    val_set = DatasetCatalog.get(cfg.DATASETS.VAL, args)
    train_loader = DataLoader(train_set,
                              cfg.SOLVER.IMS_PER_BATCH,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            cfg.SOLVER.IMS_PER_BATCH,
                            num_workers=cfg.DATALOADER.NUM_WORKERS,
                            shuffle=True)

    gpu_ids = [_ for _ in range(torch.cuda.device_count())]
    model = build_model(cfg)
    model.to("cuda")
    model = torch.nn.parallel.DataParallel(model, gpu_ids) if not args.debug else model

    logger = logging.getLogger("train_logger")
    logger.info("Start training")
    train_metrics = MetricLogger(delimiter=" ")
    max_iter = cfg.SOLVER.MAX_ITER
    output_dir = cfg.OUTPUT_DIR

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, output_dir, logger)
    start_iteration = checkpointer.load() if not args.debug else 0
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    validation_period = cfg.SOLVER.VALIDATION_PERIOD

    summary_writer = SummaryWriter(log_dir=os.path.join(output_dir, "summary"))
    visualizer = train_set.visualizer(cfg.VISUALIZATION)(summary_writer)

    model.train()
    start_training_time = time.time()
    last_batch_time = time.time()
    for iteration, inputs in enumerate(cycle(train_loader), start_iteration):
        data_time = time.time() - last_batch_time
        iteration = iteration + 1
        scheduler.step()

        inputs = to_cuda(inputs)
        outputs = model(inputs)
        loss_dict = gather_loss_dict(outputs)
        loss = loss_dict["loss"]
        train_metrics.update(**loss_dict)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - last_batch_time
        last_batch_time = time.time()
        train_metrics.update(time=batch_time, data=data_time)

        eta_seconds = train_metrics.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                train_metrics.delimiter.join([
                    "eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                    "max mem: {memory:.0f}"
                ]).format(eta=eta_string,
                          iter=iteration,
                          meters=str(train_metrics),
                          lr=optimizer.param_groups[0]["lr"],
                          memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0))
            summary_writer.add_scalars("train", train_metrics.mean, iteration)

        if iteration % 100 == 0:
            visualizer.visualize(inputs, outputs, iteration)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration))

        if iteration % validation_period == 0:
            with torch.no_grad():
                val_metrics = MetricLogger(delimiter=" ")
                for i, inputs in enumerate(val_loader):
                    data_time = time.time() - last_batch_time
                    inputs = to_cuda(inputs)
                    outputs = model(inputs)
                    loss_dict = gather_loss_dict(outputs)
                    val_metrics.update(**loss_dict)
                    batch_time = time.time() - last_batch_time
                    last_batch_time = time.time()
                    val_metrics.update(time=batch_time, data=data_time)
                    if i % 20 == 0 or i == cfg.SOLVER.VALIDATION_LIMIT:
                        logger.info(
                            val_metrics.delimiter.join([
                                "VALIDATION", "eta: {eta}", "iter: {iter}",
                                "{meters}"
                            ]).format(eta=eta_string,
                                      iter=iteration,
                                      meters=str(val_metrics)))
                    if i == cfg.SOLVER.VALIDATION_LIMIT:
                        summary_writer.add_scalars("val", val_metrics.mean, iteration)
                        break

        if iteration == max_iter:
            break

    checkpointer.save("model_{:07d}".format(max_iter))
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
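
# --- Hedged sketch: `cycle` above is assumed to re-iterate the DataLoader
# indefinitely so the loop can be driven by a global iteration count instead of
# epochs. Note that itertools.cycle would cache and replay the first pass, so the
# usual pattern is a generator like this (an assumption, not necessarily the
# project's helper):
def cycle(iterable):
    while True:
        # creating a fresh iterator each pass lets a shuffled DataLoader reshuffle
        for item in iterable:
            yield item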
def train(cfg, output_dir=""):
    # logger = logging.getLogger("ModelZoo.trainer")
    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, metric_fn = build_model(cfg)
    logger.info("Build model:\n{}".format(str(model)))
    model = nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                logger=logger)
    checkpoint_data = checkpointer.load(cfg.GLOBAL.TRAIN.WEIGHT, resume=cfg.AUTO_RESUME)
    ckpt_period = cfg.GLOBAL.TRAIN.CHECKPOINT_PERIOD

    # build data loader
    train_data_loader = build_data_loader(cfg, cfg.GLOBAL.DATASET, mode="train")
    val_period = cfg.GLOBAL.VAL.VAL_PERIOD
    # val_data_loader = build_data_loader(cfg, mode="val") if val_period > 0 else None

    # build tensorboard logger (optionally by comment)
    tensorboard_logger = TensorboardLogger(output_dir)

    # train
    max_epoch = cfg.GLOBAL.MAX_EPOCH
    start_epoch = checkpoint_data.get("epoch", 0)
    # best_metric_name = "best_{}".format(cfg.TRAIN.VAL_METRIC)
    # best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info("Start training from epoch {}".format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler.step()
        start_time = time.time()
        train_meters = train_model(
            model,
            loss_fn,
            metric_fn,
            data_loader=train_data_loader,
            optimizer=optimizer,
            curr_epoch=epoch,
            tensorboard_logger=tensorboard_logger,
            log_period=cfg.GLOBAL.TRAIN.LOG_PERIOD,
            output_dir=output_dir,
        )
        epoch_time = time.time() - start_time
        logger.info("Epoch[{}]-Train {} total_time: {:.2f}s".format(
            cur_epoch, train_meters.summary_str, epoch_time))

        # checkpoint
        if cur_epoch % ckpt_period == 0 or cur_epoch == max_epoch:
            checkpoint_data["epoch"] = cur_epoch
            # checkpoint_data[best_metric_name] = best_metric
            checkpointer.save("model_{:03d}".format(cur_epoch), **checkpoint_data)

        '''
        # validate
        if val_period < 1:
            continue
        if cur_epoch % val_period == 0 or cur_epoch == max_epoch:
            val_meters = validate_model(model, loss_fn, metric_fn,
                                        image_scales=cfg.MODEL.VAL.IMG_SCALES,
                                        inter_scales=cfg.MODEL.VAL.INTER_SCALES,
                                        isFlow=(cur_epoch > cfg.SCHEDULER.INIT_EPOCH),
                                        data_loader=val_data_loader,
                                        curr_epoch=epoch,
                                        tensorboard_logger=tensorboard_logger,
                                        log_period=cfg.TEST.LOG_PERIOD,
                                        output_dir=output_dir,
                                        )
            logger.info("Epoch[{}]-Val {}".format(cur_epoch, val_meters.summary_str))

            # best validation
            cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg
            if best_metric is None or cur_metric > best_metric:
                best_metric = cur_metric
                checkpoint_data["epoch"] = cur_epoch
                checkpoint_data[best_metric_name] = best_metric
                checkpointer.save("model_best", **checkpoint_data)
        '''

    logger.info("Train Finish!")
    # logger.info("Best val-{} = {}".format(cfg.TRAIN.VAL_METRIC, best_metric))
    return model
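
# --- Hedged sketch (assumption): every snippet above relies on a Checkpointer
# that bundles model/optimizer/scheduler state plus arbitrary metadata and returns
# the leftover metadata on load. A minimal version of that save/load contract
# could look like this; it is an illustration, not any of the projects' classes.
import os
import torch


class MinimalCheckpointer:
    def __init__(self, model, optimizer=None, scheduler=None, save_dir=""):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.save_dir = save_dir

    def save(self, name, **extra):
        data = {"model": self.model.state_dict(), **extra}
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        torch.save(data, os.path.join(self.save_dir, "{}.pth".format(name)))

    def load(self, path):
        if not path or not os.path.isfile(path):
            return {}
        data = torch.load(path, map_location="cpu")
        self.model.load_state_dict(data.pop("model"))
        if self.optimizer is not None and "optimizer" in data:
            self.optimizer.load_state_dict(data.pop("optimizer"))
        if self.scheduler is not None and "scheduler" in data:
            self.scheduler.load_state_dict(data.pop("scheduler"))
        return data  # leftover entries are the extra checkpoint data (e.g. "epoch")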
def test(cfg, model=None):
    torch.cuda.empty_cache()  # TODO check if it helps
    cpu_device = torch.device("cpu")
    if cfg.VIS.FLOPS:
        # device = cpu_device
        device = torch.device("cuda:0")
    else:
        device = torch.device(cfg.DEVICE)
    if model is None:
        # load model from outputs
        model = Modelbuilder(cfg)
        model.to(device)
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)
        _ = checkpointer.load(cfg.WEIGHTS)
    data_loaders = make_data_loader(cfg, is_train=False)

    if cfg.VIS.FLOPS:
        model.eval()
        from thop import profile
        for idx, batchdata in enumerate(data_loaders[0]):
            with torch.no_grad():
                flops, params = profile(
                    model,
                    inputs=({
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    }, False))
            print('flops', flops, 'params', params)
            exit()

    if cfg.TEST.RECOMPUTE_BN:
        tmp_data_loader = make_data_loader(cfg, is_train=True, dataset_list=cfg.DATASETS.TEST)
        model.train()
        for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
            with torch.no_grad():
                model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=True)
        # cnt = 0
        # while cnt < 1000:
        #     for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
        #         with torch.no_grad():
        #             model({k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batchdata.items()}, is_train=True)
        #     cnt += 1
        checkpointer.save("model_bn")
        model.eval()
    elif cfg.TEST.TRAIN_BN:
        model.train()
    else:
        model.eval()

    dataset_names = cfg.DATASETS.TEST
    meters = MetricLogger()
    # if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
    #     all_preds = np.zeros((len(data_loaders), cfg.KEYPOINT.NUM_PTS, 3), dtype=np.float32)
    cpu = lambda x: x.to(cpu_device).numpy() if isinstance(x, torch.Tensor) else x
    logger = setup_logger("tester", cfg.OUTPUT_DIR)

    for data_loader, dataset_name in zip(data_loaders, dataset_names):
        print('Loading ', dataset_name)
        dataset = data_loader.dataset
        logger.info("Start evaluation on {} dataset({} images).".format(
            dataset_name, len(dataset)))
        total_timer = Timer()
        total_timer.tic()
        predictions = []
        # if 'h36m' in cfg.OUTPUT_DIR:
        #     err_joints = 0
        # else:
        err_joints = np.zeros((cfg.TEST.IMS_PER_BATCH, int(cfg.TEST.MAX_TH)))
        total_joints = 0
        for idx, batchdata in enumerate(tqdm(data_loader)):
            if cfg.VIS.VIDEO and not 'h36m' in cfg.OUTPUT_DIR:
                for k, v in batchdata.items():
                    try:
                        # good: 1 2 3 4 5 6 7 8 12 16 30
                        #   4: 17.4 vs 16.5
                        #   30: 41.83200 vs 40.17562
                        # bad: 0 22
                        #   0: 43.78544 vs 45.24059
                        #   22: 43.01385 vs 43.88636
                        vis_idx = 16
                        batchdata[k] = v[:, vis_idx, None]
                    except:
                        pass
            if cfg.VIS.VIDEO_GT:
                for k, v in batchdata.items():
                    try:
                        vis_idx = 30
                        batchdata[k] = v[:, vis_idx:vis_idx + 2]
                    except:
                        pass
                joints = cpu(batchdata['points-2d'].squeeze())[0]
                orig_img = de_transform(cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1], orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())[0]
                    plot_two_hand_2d(joints, ax, visibility)
                    # plot_two_hand_2d(joints, ax)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join("outs", "video_gt", dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                continue

            # print('batchdatapoints-3d', batchdata['points-3d'])
            batch_size = cfg.TEST.IMS_PER_BATCH
            with torch.no_grad():
                loss_dict, metric_dict, output = model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=False)
                meters.update(**prefix_dict(loss_dict, dataset_name))
                meters.update(**prefix_dict(metric_dict, dataset_name))

            # update err_joints
            if cfg.VIS.VIDEO:
                joints = cpu(output['batch_locs'].squeeze())
                if joints.shape[0] == 1:
                    joints = joints[0]
                try:
                    orig_img = de_transform(cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                except:
                    orig_img = de_transform(cpu(batchdata['img'].squeeze()[None, ...])[0])
                # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1], orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())
                    if visibility.shape[0] == 1:
                        visibility = visibility[0]
                    plot_two_hand_2d(joints, ax, visibility)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join(cfg.OUTPUT_DIR, "video", dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                # plt.show()

            if cfg.TEST.PCK and cfg.DOTEST:
                # if 'h36m' in cfg.OUTPUT_DIR:
                #     err_joints += metric_dict['accuracy'] * output['total_joints']
                #     total_joints += output['total_joints']
                #     # all_preds
                # else:
                for i in range(batch_size):
                    err_joints = np.add(err_joints, output['err_joints'])
                    total_joints += sum(output['total_joints'])

            if idx % cfg.VIS.SAVE_PRED_FREQ == 0 and (
                    cfg.VIS.SAVE_PRED_LIMIT == -1
                    or idx < cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ):
                # print(meters)
                for i in range(batch_size):
                    predictions.append((
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in batchdata.items()
                        },
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in output.items()
                        },
                    ))
            if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx > cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ:
                break

        # if not cfg.DOTRAIN and cfg.SAVE_PRED:
        #     if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx < cfg.VIS.SAVE_PRED_LIMIT:
        #         for i in range(batch_size):
        #             predictions.append(
        #                 (
        #                     {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in batchdata.items()},
        #                     {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in output.items()},
        #                 )
        #             )
        #         if idx == cfg.VIS.SAVE_PRED_LIMIT:
        #             break
        # if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
        #     logger.info('accuracy0.5: {}'.format(err_joints/total_joints))
        #     name_value, perf_indicator = dataset.evaluate(all_preds)
        #     names = name_value.keys()
        #     values = name_value.values()
        #     num_values = len(name_value)
        #     logger.info(' '.join(['| {}'.format(name) for name in names]) + ' |')
        #     logger.info('|---' * (num_values) + '|')
        #     logger.info(' '.join(['| {:.3f}'.format(value) for value in values]) + ' |')

        total_time = total_timer.toc()
        total_time_str = get_time_str(total_time)
        logger.info("Total run time: {} ".format(total_time_str))
        if cfg.OUTPUT_DIR:  # and cfg.VIS.SAVE_PRED:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            torch.save(predictions, os.path.join(output_folder, cfg.VIS.SAVE_PRED_NAME))
            if cfg.DOTEST and cfg.TEST.PCK:
                print(err_joints.shape)
                torch.save(err_joints * 1.0 / total_joints,
                           os.path.join(output_folder, "pck.pth"))

    logger.info("{}".format(str(meters)))
    model.train()
    return meters.get_all_avg()
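
# --- Hedged sketch: `prefix_dict` above is assumed to namespace loss/metric keys
# by dataset name before they are accumulated in the shared MetricLogger. A
# minimal version under that assumption (the separator is illustrative):
def prefix_dict(d, prefix, sep="/"):
    return {"{}{}{}".format(prefix, sep, k): v for k, v in d.items()}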