Example 1
        optim.optimizer.zero_grad()
        losses.backward()

    with timer.counter('update'):
        optim.optimizer.step()

    time_this = time.time()
    if i > start_iter:
        batch_time = time_this - time_last
        timer.add_batch_time(batch_time)
    time_last = time_this

    if i > start_iter and i % 20 == 0 and main_gpu:
        cur_lr = optim.optimizer.param_groups[0]['lr']
        time_name = ['batch', 'data', 'for+loss', 'backward', 'update']
        t_t, t_d, t_fl, t_b, t_u = timer.get_times(time_name)
        seconds = (max_iter - i) * t_t
        eta = str(datetime.timedelta(seconds=seconds)).split('.')[0]

        print(f'step: {i} | lr: {cur_lr:.2e} | l_class: {l_c:.3f} | l_box: {l_b:.3f} | l_iou: {l_iou:.3f} | '
              f't_t: {t_t:.3f} | t_d: {t_d:.3f} | t_fl: {t_fl:.3f} | t_b: {t_b:.3f} | t_u: {t_u:.3f} | ETA: {eta}')

    if main_gpu and (i > start_iter and i % cfg.val_interval == 0 or i == max_iter):  # precedence: (i > start_iter and i % cfg.val_interval == 0) or (i == max_iter)
        checkpointer.save(cur_iter=i)
        inference(model.module, cfg, during_training=True)
        model.train()
        timer.reset()  # training and validation share the same timer object, so reset it to avoid mixing their measurements

    if main_gpu and i != 1 and i % cfg.val_interval == 1:
        timer.start()  # restart timing so the first iteration after validation is not counted
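The `timer` used above is project-specific. The sketch below shows one way such a counter-style helper could look; `counter()`, `add_batch_time()`, `get_times()`, `reset()` and `start()` are inferred from the calls in Example 1 and are not from a known library.

import time
from collections import defaultdict
from contextlib import contextmanager

class PhaseTimer:
    """Hypothetical helper: accumulates per-phase durations and reports running means."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sums = defaultdict(float)   # total seconds per named phase
        self.counts = defaultdict(int)   # number of measurements per phase

    def start(self):
        # drop accumulated values, e.g. so the first iteration after validation is not counted
        self.reset()

    @contextmanager
    def counter(self, name):
        # time a named phase such as 'data', 'for+loss', 'backward' or 'update'
        t0 = time.time()
        yield
        self.sums[name] += time.time() - t0
        self.counts[name] += 1

    def add_batch_time(self, t):
        self.sums['batch'] += t
        self.counts['batch'] += 1

    def get_times(self, names):
        # mean duration of each requested phase, in the given order
        return [self.sums[n] / max(self.counts[n], 1) for n in names]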
Example 2
def train(cfg):
    device = torch.device(cfg.DEVICE)
    arguments = {}
    arguments["epoch"] = 0
    if not cfg.DATALOADER.BENCHMARK:
        model = Modelbuilder(cfg)
        print(model)
        model.to(device)
        model.float()
        optimizer, scheduler = make_optimizer(cfg, model)
        checkpointer = Checkpointer(model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    save_dir=cfg.OUTPUT_DIR)
        extra_checkpoint_data = checkpointer.load(
            cfg.WEIGHTS,
            prefix=cfg.WEIGHTS_PREFIX,
            prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE,
            loadoptimizer=cfg.WEIGHTS_LOAD_OPT)
        arguments.update(extra_checkpoint_data)
        model.train()

    logger = setup_logger("trainer", cfg.FOLDER_NAME)
    if cfg.TENSORBOARD.USE:
        writer = SummaryWriter(cfg.FOLDER_NAME)
    else:
        writer = None
    meters = MetricLogger(writer=writer)
    start_training_time = time.time()
    end = time.time()
    start_epoch = arguments["epoch"]
    max_epoch = cfg.SOLVER.MAX_EPOCHS

    if start_epoch == max_epoch:
        logger.info("Final model exists! No need to train!")
        test(cfg, model)
        return

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )
    size_epoch = len(data_loader)
    max_iter = size_epoch * max_epoch
    logger.info("Start training {} batches/epoch".format(size_epoch))

    for epoch in range(start_epoch, max_epoch):
        arguments["epoch"] = epoch
        #batchcnt = 0
        for iteration, batchdata in enumerate(data_loader):
            cur_iter = size_epoch * epoch + iteration
            data_time = time.time() - end

            batchdata = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batchdata.items()
            }

            if not cfg.DATALOADER.BENCHMARK:
                loss_dict, metric_dict = model(batchdata)
                # print(loss_dict, metric_dict)
                optimizer.zero_grad()
                loss_dict['loss'].backward()
                optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            meters.update(time=batch_time, data=data_time, iteration=cur_iter)

            if cfg.DATALOADER.BENCHMARK:
                logger.info(
                    meters.delimiter.join([
                        "iter: {iter}",
                        "{meters}",
                    ]).format(
                        iter=iteration,
                        meters=str(meters),
                    ))
                continue

            eta_seconds = meters.time.global_avg * (max_iter - cur_iter)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % cfg.LOG_FREQ == 0:
                meters.update(iteration=cur_iter, **loss_dict)
                meters.update(iteration=cur_iter, **metric_dict)
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {epoch}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        # "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        epoch=epoch,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        # memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))
        # UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
        # In PyTorch 1.1.0 and later, you should call them in the opposite order:
        # `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result
        # in PyTorch skipping the first value of the learning rate schedule.
        # See https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
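        # note: here scheduler.step() runs once per epoch, after the optimizer.step()
        # calls of the inner loop, so the order required by PyTorch >= 1.1 is respected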
        scheduler.step()

        if (epoch + 1) % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            arguments["epoch"] += 1
            checkpointer.save("model_{:03d}".format(epoch), **arguments)
        if epoch == max_epoch - 1:
            arguments['epoch'] = max_epoch
            checkpointer.save("model_final", **arguments)

            total_training_time = time.time() - start_training_time
            total_time_str = str(
                datetime.timedelta(seconds=total_training_time))
            logger.info("Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str,
                total_training_time / (max_epoch - start_epoch)))
        if epoch == max_epoch - 1 or ((epoch + 1) % cfg.EVAL_FREQ == 0):
            results = test(cfg, model)
            meters.update(is_train=False, iteration=cur_iter, **results)
Example 3
            accs.append(acc)

        # remember best prec@1 and save checkpoint
        is_best = accs[0] > checkpointer.best_acc
        if is_best:
            checkpointer.best_acc = accs[0]
        elif cfg.OPTIM.VAL and cfg.OPTIM.OPT in ['sgd', 'qhm', 'salsa']:
            logging.info("DROPPING LEARNING RATE")
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            for group in optimizer.param_groups:
                group['lr'] = group['lr'] * 1.0 / cfg.OPTIM.DROP_FACTOR
            if cfg.OPTIM.OPT in ['salsa']:
                optimizer.state['switched'] = True
                logging.info("Switch due to overfiting!")
        checkpointer.epoch = epoch + 1
        checkpointer.save(is_best)

    # finally, evaluate the best checkpoint
    # wait for all processes to complete before calculating the score
    synchronize()
    best_model_path = os.path.join(checkpointer.save_dir, "model_best.pth")
    if os.path.isfile(best_model_path):
        logging.info(
            "Evaluating the best checkpoint: {}".format(best_model_path))
        cfg.defrost()
        cfg.EVALUATE = True
        checkpointer.is_test = True
        cfg.freeze()
        extra_checkpoint_data = checkpointer.load(best_model_path)
        for task_name, testloader, test_meter in zip(task_names, testloaders,
                                                     test_meters):
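The manual learning-rate drop in Example 3 (dividing each param group's lr by cfg.OPTIM.DROP_FACTOR when validation accuracy stops improving) can also be written with PyTorch's built-in ReduceLROnPlateau. A minimal sketch, where train_one_epoch() and validate() are assumed placeholders rather than functions from the example:

import torch

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=1.0 / cfg.OPTIM.DROP_FACTOR, patience=0)

for epoch in range(max_epoch):
    train_one_epoch(model, optimizer)  # placeholder training step
    val_acc = validate(model)          # placeholder: validation accuracy (higher is better)
    scheduler.step(val_acc)            # drops the lr when val_acc has not improved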
Example 4
def train(args):
    try:
        model = nets[args.net](args.margin, args.omega, args.use_hardtriplet)
        model.to(args.device)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    if args.device != torch.device("cpu"):
        amp_opt_level = 'O1' if args.use_amp else 'O0'
        model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    arguments = {}
    arguments.update(vars(args))
    arguments["itr"] = 0
    checkpointer = Checkpointer(model, 
                                optimizer=optimizer, 
                                scheduler=scheduler,
                                save_dir=args.out_dir, 
                                save_to_disk=True)
    ## load model weights from pretrained_weights, or resume from the checkpoint of an interrupted run.
    extra_checkpoint_data = checkpointer.load(args.pretrained_weights)
    arguments.update(extra_checkpoint_data)
    
    batch_size = args.batch_size
    fashion = FashionDataset(item_num=args.iteration_num*batch_size)
    dataloader = DataLoader(dataset=fashion, shuffle=True, num_workers=8, batch_size=batch_size)

    model.train()
    meters = MetricLogger(delimiter=", ")
    max_itr = len(dataloader)
    start_itr = arguments["itr"] + 1
    itr_start_time = time.time()
    training_start_time = time.time()
    for itr, batch_data in enumerate(dataloader, start_itr):
        batch_data = (bd.to(args.device) for bd in batch_data)
        loss_dict = model.loss(*batch_data)
        optimizer.zero_grad()
        if args.device != torch.device("cpu"):
            with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_losses:
                scaled_losses.backward()
        else:
            loss_dict["loss"].backward()
        optimizer.step()
        scheduler.step()

        arguments["itr"] = itr
        meters.update(**loss_dict)
        itr_time = time.time() - itr_start_time
        itr_start_time = time.time()
        meters.update(itr_time=itr_time)
        if itr % 50 == 0:
            eta_seconds = meters.itr_time.global_avg * (max_itr - itr)
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join(
                    [
                        "itr: {itr}/{max_itr}",
                        "lr: {lr:.7f}",
                        "{meters}",
                        "eta: {eta}\n",
                    ]
                ).format(
                    itr=itr,
                    lr=optimizer.param_groups[0]["lr"],
                    max_itr=max_itr,
                    meters=str(meters),
                    eta=eta,
                )
            )

        ## save model
        if itr % args.checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(itr), **arguments)
        if itr == max_itr:
            checkpointer.save("model_final", **arguments)
            break

    training_time = time.time() - training_start_time
    training_time = str(datetime.timedelta(seconds=int(training_time)))
    logger.info("total training time: {}".format(training_time))
Example 5
def train(cfg, args):
    train_set = DatasetCatalog.get(cfg.DATASETS.TRAIN, args)
    val_set = DatasetCatalog.get(cfg.DATASETS.VAL, args)
    train_loader = DataLoader(train_set,
                              cfg.SOLVER.IMS_PER_BATCH,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            cfg.SOLVER.IMS_PER_BATCH,
                            num_workers=cfg.DATALOADER.NUM_WORKERS,
                            shuffle=True)

    gpu_ids = list(range(torch.cuda.device_count()))
    model = build_model(cfg)
    model.to("cuda")
    model = torch.nn.parallel.DataParallel(
        model, gpu_ids) if not args.debug else model

    logger = logging.getLogger("train_logger")
    logger.info("Start training")
    train_metrics = MetricLogger(delimiter="  ")
    max_iter = cfg.SOLVER.MAX_ITER
    output_dir = cfg.OUTPUT_DIR

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, output_dir,
                                logger)
    start_iteration = checkpointer.load() if not args.debug else 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    validation_period = cfg.SOLVER.VALIDATION_PERIOD
    summary_writer = SummaryWriter(log_dir=os.path.join(output_dir, "summary"))
    visualizer = train_set.visualizer(cfg.VISUALIZATION)(summary_writer)

    model.train()
    start_training_time = time.time()
    last_batch_time = time.time()

    for iteration, inputs in enumerate(cycle(train_loader), start_iteration):
        data_time = time.time() - last_batch_time
        iteration = iteration + 1
        scheduler.step()

        inputs = to_cuda(inputs)
        outputs = model(inputs)

        loss_dict = gather_loss_dict(outputs)
        loss = loss_dict["loss"]
        train_metrics.update(**loss_dict)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - last_batch_time
        last_batch_time = time.time()
        train_metrics.update(time=batch_time, data=data_time)

        eta_seconds = train_metrics.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                train_metrics.delimiter.join([
                    "eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                    "max mem: {memory:.0f}"
                ]).format(eta=eta_string,
                          iter=iteration,
                          meters=str(train_metrics),
                          lr=optimizer.param_groups[0]["lr"],
                          memory=torch.cuda.max_memory_allocated() / 1024.0 /
                          1024.0))
            summary_writer.add_scalars("train", train_metrics.mean, iteration)

        if iteration % 100 == 0:
            visualizer.visualize(inputs, outputs, iteration)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration))

        if iteration % validation_period == 0:
            with torch.no_grad():
                val_metrics = MetricLogger(delimiter="  ")
                for i, inputs in enumerate(val_loader):
                    data_time = time.time() - last_batch_time

                    inputs = to_cuda(inputs)
                    outputs = model(inputs)

                    loss_dict = gather_loss_dict(outputs)
                    val_metrics.update(**loss_dict)

                    batch_time = time.time() - last_batch_time
                    last_batch_time = time.time()
                    val_metrics.update(time=batch_time, data=data_time)

                    if i % 20 == 0 or i == cfg.SOLVER.VALIDATION_LIMIT:
                        logger.info(
                            val_metrics.delimiter.join([
                                "VALIDATION", "eta: {eta}", "iter: {iter}",
                                "{meters}"
                            ]).format(eta=eta_string,
                                      iter=iteration,
                                      meters=str(val_metrics)))

                    if i == cfg.SOLVER.VALIDATION_LIMIT:
                        summary_writer.add_scalars("val", val_metrics.mean,
                                                   iteration)
                        break
        if iteration == max_iter:
            break

    checkpointer.save("model_{:07d}".format(max_iter))
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
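The cycle(train_loader) call above assumes a helper that re-iterates the DataLoader indefinitely. If it is a project-level generator, it likely resembles the sketch below; note that itertools.cycle would instead cache the batches from the first pass and replay them, bypassing reshuffling:

def cycle(iterable):
    # yield items forever, restarting (and re-shuffling) the loader at each pass
    while True:
        for item in iterable:
            yield item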
Example 6
def train(cfg, output_dir=""):
    # logger = logging.getLogger("ModelZoo.trainer")

    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, metric_fn = build_model(cfg)
    logger.info("Build model:\n{}".format(str(model)))
    model = nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                logger=logger)

    checkpoint_data = checkpointer.load(cfg.GLOBAL.TRAIN.WEIGHT,
                                        resume=cfg.AUTO_RESUME)
    ckpt_period = cfg.GLOBAL.TRAIN.CHECKPOINT_PERIOD

    # build data loader
    train_data_loader = build_data_loader(cfg,
                                          cfg.GLOBAL.DATASET,
                                          mode="train")
    val_period = cfg.GLOBAL.VAL.VAL_PERIOD
    # val_data_loader = build_data_loader(cfg, mode="val") if val_period > 0 else None

    # build tensorboard logger (disable by commenting this out)
    tensorboard_logger = TensorboardLogger(output_dir)

    # train
    max_epoch = cfg.GLOBAL.MAX_EPOCH
    start_epoch = checkpoint_data.get("epoch", 0)
    # best_metric_name = "best_{}".format(cfg.TRAIN.VAL_METRIC)
    # best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info("Start training from epoch {}".format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler.step()
        start_time = time.time()
        train_meters = train_model(
            model,
            loss_fn,
            metric_fn,
            data_loader=train_data_loader,
            optimizer=optimizer,
            curr_epoch=epoch,
            tensorboard_logger=tensorboard_logger,
            log_period=cfg.GLOBAL.TRAIN.LOG_PERIOD,
            output_dir=output_dir,
        )
        epoch_time = time.time() - start_time
        logger.info("Epoch[{}]-Train {}  total_time: {:.2f}s".format(
            cur_epoch, train_meters.summary_str, epoch_time))

        # checkpoint
        if cur_epoch % ckpt_period == 0 or cur_epoch == max_epoch:
            checkpoint_data["epoch"] = cur_epoch
            # checkpoint_data[best_metric_name] = best_metric
            checkpointer.save("model_{:03d}".format(cur_epoch),
                              **checkpoint_data)
        '''
        # validate
        if val_period < 1:
            continue
        if cur_epoch % val_period == 0 or cur_epoch == max_epoch:
            val_meters = validate_model(model,
                                        loss_fn,
                                        metric_fn,
                                        image_scales=cfg.MODEL.VAL.IMG_SCALES,
                                        inter_scales=cfg.MODEL.VAL.INTER_SCALES,
                                        isFlow=(cur_epoch > cfg.SCHEDULER.INIT_EPOCH),
                                        data_loader=val_data_loader,
                                        curr_epoch=epoch,
                                        tensorboard_logger=tensorboard_logger,
                                        log_period=cfg.TEST.LOG_PERIOD,
                                        output_dir=output_dir,
                                        )
            logger.info("Epoch[{}]-Val {}".format(cur_epoch, val_meters.summary_str))

            # best validation
            cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg
            if best_metric is None or cur_metric > best_metric:
                best_metric = cur_metric
                checkpoint_data["epoch"] = cur_epoch
                checkpoint_data[best_metric_name] = best_metric
                checkpointer.save("model_best", **checkpoint_data)
        '''

    logger.info("Train Finish!")
    # logger.info("Best val-{} = {}".format(cfg.TRAIN.VAL_METRIC, best_metric))

    return model
Example 7
def test(cfg, model=None):
    torch.cuda.empty_cache()  # TODO check if it helps
    cpu_device = torch.device("cpu")
    if cfg.VIS.FLOPS:
        # device = cpu_device
        device = torch.device("cuda:0")
    else:
        device = torch.device(cfg.DEVICE)
    if model is None:
        # load model from outputs
        model = Modelbuilder(cfg)
        model.to(device)
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)
        _ = checkpointer.load(cfg.WEIGHTS)
    data_loaders = make_data_loader(cfg, is_train=False)
    if cfg.VIS.FLOPS:
        model.eval()
        from thop import profile
        for idx, batchdata in enumerate(data_loaders[0]):
            with torch.no_grad():
                flops, params = profile(
                    model,
                    inputs=({
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    }, False))
                print('flops', flops, 'params', params)
                exit()
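    # optionally re-estimate BatchNorm running statistics by forwarding training data
    # in train() mode with gradients disabled, then save the updated model below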
    if cfg.TEST.RECOMPUTE_BN:
        tmp_data_loader = make_data_loader(cfg,
                                           is_train=True,
                                           dataset_list=cfg.DATASETS.TEST)
        model.train()
        for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
            with torch.no_grad():
                model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=True)
        #cnt = 0
        #while cnt < 1000:
        #    for idx, batchdata in enumerate(tqdm(tmp_data_loader)):
        #        with torch.no_grad():
        #            model({k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batchdata.items()}, is_train=True)
        #        cnt += 1
        checkpointer.save("model_bn")
        model.eval()
    elif cfg.TEST.TRAIN_BN:
        model.train()
    else:
        model.eval()
    dataset_names = cfg.DATASETS.TEST
    meters = MetricLogger()

    #if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
    #    all_preds = np.zeros((len(data_loaders), cfg.KEYPOINT.NUM_PTS, 3), dtype=np.float32)
    def cpu(x):
        return x.to(cpu_device).numpy() if isinstance(x, torch.Tensor) else x

    logger = setup_logger("tester", cfg.OUTPUT_DIR)
    for data_loader, dataset_name in zip(data_loaders, dataset_names):
        print('Loading ', dataset_name)
        dataset = data_loader.dataset

        logger.info("Start evaluation on {} dataset({} images).".format(
            dataset_name, len(dataset)))
        total_timer = Timer()
        total_timer.tic()

        predictions = []
        #if 'h36m' in cfg.OUTPUT_DIR:
        #    err_joints = 0
        #else:
        err_joints = np.zeros((cfg.TEST.IMS_PER_BATCH, int(cfg.TEST.MAX_TH)))
        total_joints = 0

        for idx, batchdata in enumerate(tqdm(data_loader)):
            if cfg.VIS.VIDEO and not 'h36m' in cfg.OUTPUT_DIR:
                for k, v in batchdata.items():
                    try:
                        #good 1 2 3 4 5 6 7 8 12 16 30
                        # 4 17.4 vs 16.5
                        # 30 41.83200 vs 40.17562
                        #bad 0 22
                        #0 43.78544 vs 45.24059
                        #22 43.01385 vs 43.88636
                        vis_idx = 16
                        batchdata[k] = v[:, vis_idx, None]
                    except:
                        pass
            if cfg.VIS.VIDEO_GT:
                for k, v in batchdata.items():
                    try:
                        vis_idx = 30
                        batchdata[k] = v[:, vis_idx:vis_idx + 2]
                    except:
                        pass
                joints = cpu(batchdata['points-2d'].squeeze())[0]
                orig_img = de_transform(
                    cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())[0]
                    plot_two_hand_2d(joints, ax, visibility)
                    # plot_two_hand_2d(joints, ax)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join("outs", "video_gt", dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                continue
            #print('batchdatapoints-3d', batchdata['points-3d'])
            batch_size = cfg.TEST.IMS_PER_BATCH
            with torch.no_grad():
                loss_dict, metric_dict, output = model(
                    {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in batchdata.items()
                    },
                    is_train=False)
            meters.update(**prefix_dict(loss_dict, dataset_name))
            meters.update(**prefix_dict(metric_dict, dataset_name))
            # update err_joints
            if cfg.VIS.VIDEO:
                joints = cpu(output['batch_locs'].squeeze())
                if joints.shape[0] == 1:
                    joints = joints[0]
                try:
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                except:
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])[0])
                # fig = plt.figure()
                # ax = fig.add_subplot(111)
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())
                    if visibility.shape[0] == 1:
                        visibility = visibility[0]
                    plot_two_hand_2d(joints, ax, visibility)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join(cfg.OUTPUT_DIR, "video",
                                             dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                # plt.show()

            if cfg.TEST.PCK and cfg.DOTEST:
                #if 'h36m' in cfg.OUTPUT_DIR:
                #    err_joints += metric_dict['accuracy'] * output['total_joints']
                #    total_joints += output['total_joints']
                #    # all_preds
                #else:
                for i in range(batch_size):
                    err_joints = np.add(err_joints, output['err_joints'])
                    total_joints += sum(output['total_joints'])

            if idx % cfg.VIS.SAVE_PRED_FREQ == 0 and (
                    cfg.VIS.SAVE_PRED_LIMIT == -1
                    or idx < cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ):
                # print(meters)
                for i in range(batch_size):
                    predictions.append((
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in batchdata.items()
                        },
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in output.items()
                        },
                    ))
            if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx > cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ:
                break

            # if not cfg.DOTRAIN and cfg.SAVE_PRED:
            #     if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx < cfg.VIS.SAVE_PRED_LIMIT:
            #         for i in range(batch_size):
            #             predictions.append(
            #                     (
            #                         {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in batchdata.items()},
            #                         {k: (cpu(v[i]) if not isinstance(v, int) else v) for k, v in output.items()},
            #                     )
            #             )
            #     if idx == cfg.VIS.SAVE_PRED_LIMIT:
            #         break
        #if cfg.TEST.PCK and cfg.DOTEST and 'h36m' in cfg.OUTPUT_DIR:
        #    logger.info('accuracy0.5: {}'.format(err_joints/total_joints))
        # dataset.evaluate(all_preds)
        # name_value, perf_indicator = dataset.evaluate(all_preds)
        # names = name_value.keys()
        # values = name_value.values()
        # num_values = len(name_value)
        # logger.info(' '.join(['| {}'.format(name) for name in names]) + ' |')
        # logger.info('|---' * (num_values) + '|')
        # logger.info(' '.join(['| {:.3f}'.format(value) for value in values]) + ' |')

        total_time = total_timer.toc()
        total_time_str = get_time_str(total_time)
        logger.info("Total run time: {} ".format(total_time_str))

        if cfg.OUTPUT_DIR:  #and cfg.VIS.SAVE_PRED:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            torch.save(predictions,
                       os.path.join(output_folder, cfg.VIS.SAVE_PRED_NAME))
            if cfg.DOTEST and cfg.TEST.PCK:
                print(err_joints.shape)
                torch.save(err_joints * 1.0 / total_joints,
                           os.path.join(output_folder, "pck.pth"))

    logger.info("{}".format(str(meters)))

    model.train()
    return meters.get_all_avg()