Exemplo n.º 1
0
def main():
    """Fine-tune the 2D heatmap network jointly with the 3D auto-encoder.

    Parses the command-line arguments, builds the train and test loaders,
    the pose ResNet and the auto-encoder, optionally restores checkpoints,
    and then trains for ``args.epochs`` epochs. After every epoch a
    checkpoint is written to ``<logdir>/checkpoints`` and ``validate`` is run
    on the test loader; the checkpoint with the lowest validation loss is
    copied to ``model_best.tar``.

    Raises:
        ValueError: if a checkpoint path given on the command line does not
            point to an existing file.
    """
    args = arguments.parse_args()
    LOGGER = ConsoleLogger('Finetune', 'train')
    logdir = LOGGER.getLogFolder()
    LOGGER.info(args)
    LOGGER.info(config)

    # The flag is spelled ``benchmark``; assigning to a misspelled attribute
    # is silently accepted by Python and leaves cuDNN unconfigured.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),  # normalize
        trsf.Joints3DTrsf(),  # centerize
        trsf.ToTensor()
    ])  # to tensor

    train_data = Mocap(config.dataset.train,
                       SetType.TRAIN,
                       transform=data_transform)
    train_data_loader = DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=config.data_loader.shuffle,
                                   num_workers=8)

    # The test split doubles as the validation set for model selection.
    test_data = Mocap(config.dataset.test,
                      SetType.TEST,
                      transform=data_transform)
    test_data_loader = DataLoader(test_data,
                                  batch_size=2,
                                  shuffle=config.data_loader.shuffle,
                                  num_workers=8)

    # ------------------- Model -------------------
    with open('model/model.yaml') as fin:
        model_cfg = edict(yaml.safe_load(fin))
    resnet = pose_resnet.get_pose_net(model_cfg, True)
    Loss2D = HeatmapLoss()  # same as MSELoss()
    autoencoder = encoder_decoder.AutoEncoder(args.batch_norm,
                                              args.denis_activation)
    LossHeatmapRecon = HeatmapLossSquare()
    Loss3D = PoseLoss()
    LossLimb = LimbLoss()

    # ``device`` is used unconditionally below (checkpoint loading, batch
    # transfer, validation), so it must exist even without CUDA.
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")
    if use_cuda:
        resnet = resnet.cuda(device)
        Loss2D = Loss2D.cuda(device)
        autoencoder = autoencoder.cuda(device)
        LossHeatmapRecon.cuda(device)
        Loss3D.cuda(device)
        LossLimb.cuda(device)

    # ------------------- optimizer -------------------
    if args.freeze_2d_model:
        # Only the auto-encoder weights are handed to the optimizer.
        optimizer = optim.Adam(autoencoder.parameters(), lr=args.learning_rate)
    else:
        optimizer = optim.Adam(itertools.chain(resnet.parameters(),
                                               autoencoder.parameters()),
                               lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=0.1)

    # ------------------- load model -------------------
    if args.load_model:
        if not os.path.isfile(args.load_model):
            raise ValueError(f"No checkpoint found at {args.load_model}")
        # map_location keeps this consistent with the partial loads below and
        # lets a checkpoint saved on another device be restored here.
        checkpoint = torch.load(args.load_model, map_location=device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        resnet.load_state_dict(checkpoint['resnet_state_dict'])
        autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    if args.load_2d_model:
        if not os.path.isfile(args.load_2d_model):
            raise ValueError(f"No checkpoint found at {args.load_2d_model}")
        checkpoint = torch.load(args.load_2d_model, map_location=device)
        resnet.load_state_dict(checkpoint['resnet_state_dict'])

    if args.load_3d_model:
        if not os.path.isfile(args.load_3d_model):
            raise ValueError(f"No checkpoint found at {args.load_3d_model}")
        checkpoint = torch.load(args.load_3d_model, map_location=device)
        autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    # ------------------- tensorboard -------------------
    writer_dict = {
        'writer': SummaryWriter(log_dir=logdir),
        'train_global_steps': 0
    }

    checkpoint_dir = os.path.join(logdir, 'checkpoints')
    best_perf = float('inf')
    # ------------------- run the model -------------------
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            LOGGER.info(f'====Training epoch {epoch}====')
            losses = AverageMeter()
            batch_time = AverageMeter()

            # ------------------- Evaluation -------------------
            # Fresh accumulators every epoch for the training-set metrics.
            eval_body = evaluate.EvalBody()
            eval_upper = evaluate.EvalUpperBody()
            eval_lower = evaluate.EvalLowerBody()

            resnet.train()
            autoencoder.train()

            end = time.time()
            for it, (img, p2d, p3d, heatmap,
                     action) in enumerate(train_data_loader):

                img = img.to(device)
                p3d = p3d.to(device)
                heatmap = heatmap.to(device)

                heatmap2d_hat = resnet(img)  # e.g. torch.Size([16, 15, 48, 48])
                p3d_hat, heatmap2d_recon = autoencoder(heatmap2d_hat)

                loss2d = Loss2D(heatmap2d_hat, heatmap).mean()
                loss_recon = LossHeatmapRecon(heatmap2d_recon,
                                              heatmap2d_hat).mean()
                loss_3d = Loss3D(p3d_hat, p3d).mean()
                loss_cos, loss_len = LossLimb(p3d_hat, p3d)
                loss_cos = loss_cos.mean()
                loss_len = loss_len.mean()

                # NOTE(review): the cosine term is subtracted, presumably
                # because LossLimb returns a similarity to be maximised --
                # confirm against LimbLoss.
                loss = (args.lambda_2d * loss2d
                        + args.lambda_recon * loss_recon
                        + args.lambda_3d * loss_3d
                        - args.lambda_cos * loss_cos
                        + args.lambda_len * loss_len)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                losses.update(loss.item(), img.size(0))

                if it % config.train.PRINT_FREQ == 0:
                    # logging messages
                    msg = 'Epoch: [{0}][{1}/{2}]\t' \
                          'Batch Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                          'Speed {speed:.1f} samples/s\t' \
                          'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                        epoch, it, len(train_data_loader), batch_time=batch_time,
                        speed=img.size(0) / batch_time.val,  # averaged within batch
                        loss=losses)
                    LOGGER.info(msg)

                    writer = writer_dict['writer']
                    global_steps = writer_dict['train_global_steps']
                    # add_scalar expects a single number, not the list of
                    # per-group rates; the optimizer has one parameter group.
                    lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', lr, global_steps)
                    writer.add_scalar('train_loss', losses.val, global_steps)
                    writer.add_scalar('batch_time', batch_time.val,
                                      global_steps)
                    writer.add_scalar('losses/loss_2d', loss2d, global_steps)
                    writer.add_scalar('losses/loss_recon', loss_recon,
                                      global_steps)
                    writer.add_scalar('losses/loss_3d', loss_3d, global_steps)
                    writer.add_scalar('losses/loss_cos', loss_cos,
                                      global_steps)
                    writer.add_scalar('losses/loss_len', loss_len,
                                      global_steps)
                    image_grid = draw2Dpred_and_gt(img, heatmap2d_hat,
                                                   (368, 368))
                    writer.add_image('predicted_heatmaps', image_grid,
                                     global_steps)
                    image_grid_recon = draw2Dpred_and_gt(
                        img, heatmap2d_recon, (368, 368))
                    writer.add_image('reconstructed_heatmaps',
                                     image_grid_recon, global_steps)
                    writer_dict['train_global_steps'] = global_steps + 1

                    # ------------------- evaluation on training data -------------------
                    # Only the batches that hit the print frequency are
                    # scored, so these metrics are a sample of the epoch.
                    y_output = p3d_hat.data.cpu().numpy()
                    y_target = p3d.data.cpu().numpy()

                    eval_body.eval(y_output, y_target, action)
                    eval_upper.eval(y_output, y_target, action)
                    eval_lower.eval(y_output, y_target, action)

                end = time.time()

            # ------------------- Save results -------------------
            os.makedirs(checkpoint_dir, exist_ok=True)
            LOGGER.info('=> saving checkpoint to {}'.format(checkpoint_dir))
            states = {
                'resnet_state_dict': resnet.state_dict(),
                'autoencoder_state_dict': autoencoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            checkpoint_path = os.path.join(checkpoint_dir,
                                           f'checkpoint_{epoch}.tar')
            torch.save(states, checkpoint_path)

            res = {
                'FullBody': eval_body.get_results(),
                'UpperBody': eval_upper.get_results(),
                'LowerBody': eval_lower.get_results()
            }

            LOGGER.info('===========Evaluation on Train data==========')
            LOGGER.info(pprint.pformat(res))

            # ------------------- validation -------------------
            resnet.eval()
            autoencoder.eval()
            val_loss = validate(LOGGER, test_data_loader, resnet, autoencoder,
                                device, epoch)
            if val_loss < best_perf:
                # New best: keep a copy of this epoch's checkpoint.
                best_perf = val_loss
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(checkpoint_dir, 'model_best.tar'))

            scheduler.step()
    LOGGER.info('Done.')
Exemplo n.º 2
0
def main():
    """Run a trained 2D pose network over a dataset split and dump heatmaps.

    Loads the ResNet weights from ``args.load_model``, evaluates the heatmap
    MSE on every batch of the split selected by ``args.data``, writes
    ground-truth and predicted heatmap grids for the first 32 batches into
    the logger's output directory, and finally logs the average loss.

    Raises:
        ValueError: if no checkpoint was given or the path is not a file.
    """
    args = parse_args()
    LOGGER.info('Starting demo...')
    device = torch.device(f"cuda:{args.gpu}")
    LOGGER.info(args)

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),  # normalize
        trsf.Joints3DTrsf(),  # centerize
        trsf.ToTensor()
    ])  # to tensor

    # The split is chosen by args.data; it is always read with SetType.VAL.
    data = Mocap(config.dataset[args.data],
                 SetType.VAL,
                 transform=data_transform)
    data_loader = DataLoader(data,
                             batch_size=2,
                             shuffle=config.data_loader.shuffle,
                             num_workers=8)

    # ------------------- Model -------------------
    with open('model/model.yaml') as fin:
        model_cfg = edict(yaml.safe_load(fin))
    resnet = pose_resnet.get_pose_net(model_cfg, False)
    resnet.cuda(device)

    # A demo without trained weights is meaningless, so both cases are fatal.
    if not args.load_model:
        raise ValueError("No checkpoint!")
    if not os.path.isfile(args.load_model):
        raise ValueError(f"No checkpoint found at {args.load_model}")
    checkpoint = torch.load(args.load_model, map_location=device)
    resnet.load_state_dict(checkpoint['resnet_state_dict'])

    resnet.eval()
    Loss2D = nn.MSELoss()

    # ------------------- Read dataset frames -------------------
    max_vis_batches = 32  # only the first batches are written as images
    losses = AverageMeter()
    with torch.no_grad():
        for it, (img, p2d, p3d, heatmap, action) in enumerate(data_loader):

            print('Iteration: {}'.format(it))
            print('Images: {}'.format(img.shape))
            print('p2ds: {}'.format(p2d.shape))
            print('p3ds: {}'.format(p3d.shape))
            print('Actions: {}'.format(action))

            heatmap = heatmap.to(device)
            img = img.to(device)

            heatmap_hat = resnet(img)
            loss = Loss2D(heatmap_hat, heatmap)
            losses.update(loss.item(), img.size(0))

            # ------------------- visualization -------------------
            if it < max_vis_batches:
                # The grid is transposed from (C, H, W) to (H, W, C) before
                # being handed to OpenCV.
                img_grid = draw2Dpred_and_gt(img, heatmap,
                                             (368, 368))  # tensor
                img_grid = img_grid.numpy().transpose(1, 2, 0)
                cv2.imwrite(os.path.join(LOGGER.logfile_dir, f'gt_{it}.jpg'),
                            img_grid)

                img_grid_hat = draw2Dpred_and_gt(img, heatmap_hat, (368, 368),
                                                 p2d.clone())  # tensor
                img_grid_hat = img_grid_hat.numpy().transpose(1, 2, 0)
                cv2.imwrite(os.path.join(LOGGER.logfile_dir, f'pred_{it}.jpg'),
                            img_grid_hat)

    # ------------------- Save results -------------------

    LOGGER.info('Saving evaluation results...')
    # The running loss used to be accumulated and then thrown away; report it
    # so the run leaves a quantitative result in the log.
    LOGGER.info(f'Average 2D heatmap loss: {losses.avg:.5f}')
Exemplo n.º 3
0
def main():
    """Train the 2D heatmap ResNet on its own.

    Parses the command-line arguments, builds the train and test loaders and
    the pose ResNet, optionally resumes from ``args.load_model``, and trains
    for ``args.epochs`` epochs. After every epoch a checkpoint is written to
    ``<logdir>/checkpoints`` and ``validate`` is run on the test loader; the
    checkpoint with the lowest validation loss is copied to ``model_best.tar``.

    Raises:
        FileNotFoundError: if ``args.load_model`` is set but is not a file.
    """
    args = arguments.parse_args()
    LOGGER = ConsoleLogger('Train2d', 'train')
    logdir = LOGGER.getLogFolder()
    LOGGER.info(args)
    LOGGER.info(config)

    # The flag is spelled ``benchmark``; assigning to a misspelled attribute
    # is silently accepted by Python and leaves cuDNN unconfigured.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),  # normalize
        trsf.Joints3DTrsf(),  # centerize
        trsf.ToTensor()
    ])  # to tensor

    train_data = Mocap(config.dataset.train,
                       SetType.TRAIN,
                       transform=data_transform)
    train_data_loader = DataLoader(train_data,
                                   batch_size=args.batch_size,
                                   shuffle=config.data_loader.shuffle,
                                   num_workers=8)

    # The test split doubles as the validation set for model selection.
    test_data = Mocap(config.dataset.test,
                      SetType.TEST,
                      transform=data_transform)
    test_data_loader = DataLoader(test_data,
                                  batch_size=2,
                                  shuffle=config.data_loader.shuffle,
                                  num_workers=8)

    # ------------------- Model -------------------
    with open('model/model.yaml') as fin:
        model_cfg = edict(yaml.safe_load(fin))
    resnet = pose_resnet.get_pose_net(model_cfg, True)
    Loss2D = HeatmapLoss()  # same as MSELoss()

    # ``device`` is used unconditionally below, so it must exist even when
    # CUDA is not available.
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")
    if use_cuda:
        resnet = resnet.cuda(device)
        Loss2D = Loss2D.cuda(device)

    # ------------------- optimizer -------------------
    optimizer = optim.Adam(resnet.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=0.1)

    # ------------------- load model -------------------
    if args.load_model:
        if not os.path.isfile(args.load_model):
            raise FileNotFoundError(
                f"No checkpoint found at {args.load_model}")
        # map_location lets a checkpoint saved on another device be restored.
        checkpoint = torch.load(args.load_model, map_location=device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        resnet.load_state_dict(checkpoint['resnet_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # ------------------- tensorboard -------------------
    writer_dict = {
        'writer': SummaryWriter(log_dir=logdir),
        'train_global_steps': 0
    }

    checkpoint_dir = os.path.join(logdir, 'checkpoints')
    best_perf = float('inf')
    # ------------------- run the model -------------------
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            LOGGER.info(f'====Training epoch {epoch}====')
            losses = AverageMeter()
            batch_time = AverageMeter()

            resnet.train()

            end = time.time()
            # p2d, p3d and action are part of every batch but are not needed
            # for 2D training, so they are not moved to the device.
            for it, (img, p2d, p3d, heatmap,
                     action) in enumerate(train_data_loader):

                img = img.to(device)
                heatmap = heatmap.to(device)

                heatmap2d_hat = resnet(img)  # e.g. torch.Size([16, 15, 48, 48])

                loss2d = Loss2D(heatmap2d_hat, heatmap).mean()
                loss = loss2d * args.lambda_2d

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                # Track the unweighted loss directly instead of dividing the
                # weighted one by lambda_2d, which fails when lambda_2d == 0.
                losses.update(loss2d.item(), img.size(0))

                if it % config.train.PRINT_FREQ == 0:
                    # logging messages
                    msg = 'Epoch: [{0}][{1}/{2}]\t' \
                          'Batch Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                          'Speed {speed:.1f} samples/s\t' \
                          'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                        epoch, it, len(train_data_loader), batch_time=batch_time,
                        speed=img.size(0) / batch_time.val,  # averaged within batch
                        loss=losses)
                    LOGGER.info(msg)

                    writer = writer_dict['writer']
                    global_steps = writer_dict['train_global_steps']
                    # add_scalar expects a single number, not the list of
                    # per-group rates; the optimizer has one parameter group.
                    lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', lr, global_steps)
                    writer.add_scalar('train_loss', losses.val, global_steps)
                    writer.add_scalar('batch_time', batch_time.val,
                                      global_steps)
                    image_grid = draw2Dpred_and_gt(img, heatmap2d_hat)
                    writer.add_image('predicted_heatmaps', image_grid,
                                     global_steps)
                    writer_dict['train_global_steps'] = global_steps + 1

                end = time.time()
            scheduler.step()

            # ------------------- Save results -------------------
            os.makedirs(checkpoint_dir, exist_ok=True)
            LOGGER.info('=> saving checkpoint to {}'.format(checkpoint_dir))
            states = {
                'resnet_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            checkpoint_path = os.path.join(checkpoint_dir,
                                           f'checkpoint_{epoch}.tar')
            torch.save(states, checkpoint_path)

            # ------------------- validation -------------------
            resnet.eval()
            val_loss = validate(LOGGER, test_data_loader, resnet, device,
                                epoch)
            if val_loss < best_perf:
                # New best: keep a copy of this epoch's checkpoint.
                best_perf = val_loss
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(checkpoint_dir, 'model_best.tar'))

    LOGGER.info('Done.')
Exemplo n.º 4
0
def main():
    """Train the pose networks according to ``args.training_type``.

    ``'Train2d'`` trains only the heatmap ResNet, ``'Train3d'`` only the
    auto-encoder, and any other value (``'Finetune'``) both. Only the 2D
    objective is implemented; the other modes build their models and
    optimizer but stop with ``NotImplementedError`` at the first loss
    computation. A checkpoint is written to ``<logdir>/checkpoints`` after
    every epoch.

    Raises:
        ValueError: if ``args.load_model`` is set but is not a file.
        NotImplementedError: when the selected training type needs a 3D loss.
    """
    args = arguments.parse_args()
    LOGGER = ConsoleLogger(args.training_type, 'train')
    logdir = LOGGER.getLogFolder()
    LOGGER.info(args)
    LOGGER.info(config)

    # The flag is spelled ``benchmark``; assigning to a misspelled attribute
    # is silently accepted by Python and leaves cuDNN unconfigured.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # Which sub-networks exist is decided once here, so every later use of
    # resnet / autoencoder can be guarded by the same flags.
    use_resnet = args.training_type != 'Train3d'
    use_autoencoder = args.training_type != 'Train2d'

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),  # normalize
        trsf.Joints3DTrsf(),  # centerize
        trsf.ToTensor()])  # to tensor

    # training data
    train_data = Mocap(
        config.dataset.train,
        SetType.TRAIN,
        transform=data_transform)
    train_data_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=config.data_loader.shuffle,
        num_workers=8)

    # NOTE(review): the validation loader is built but not consumed anywhere
    # in this function yet.
    val_data = Mocap(
        config.dataset.val,
        SetType.VAL,
        transform=data_transform)
    val_data_loader = DataLoader(
        val_data,
        batch_size=2,
        shuffle=config.data_loader.shuffle,
        num_workers=8)

    # ------------------- Model -------------------
    if use_resnet:
        with open('model/model.yaml') as fin:
            model_cfg = edict(yaml.safe_load(fin))
        resnet = pose_resnet.get_pose_net(model_cfg, True)
        Loss2D = HeatmapLoss()  # same as MSELoss()
    if use_autoencoder:
        autoencoder = encoder_decoder.AutoEncoder()

    # ``device`` is used unconditionally below, so it must exist even when
    # CUDA is not available.
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")
    if use_cuda:
        if use_resnet:
            resnet = resnet.cuda(device)
            Loss2D = Loss2D.cuda(device)
        if use_autoencoder:
            autoencoder = autoencoder.cuda(device)

    # ------------------- optimizer -------------------
    # Exactly one branch must run: the previous independent ``if`` chain
    # tested ``!= 'Finetune'`` and therefore touched a model that does not
    # exist for the single-network modes.
    if args.training_type == 'Train2d':
        optimizer = optim.Adam(resnet.parameters(), lr=args.learning_rate)
    elif args.training_type == 'Train3d':
        optimizer = optim.Adam(autoencoder.parameters(),
                               lr=config.train.learning_rate)
    else:
        optimizer = optim.Adam(
            itertools.chain(resnet.parameters(), autoencoder.parameters()),
            lr=config.train.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=0.1)

    # ------------------- load model -------------------
    if args.load_model:
        if not os.path.isfile(args.load_model):
            raise ValueError(f"No checkpoint found at {args.load_model}")
        checkpoint = torch.load(args.load_model, map_location=device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if use_resnet:
            resnet.load_state_dict(checkpoint['resnet_state_dict'])
        if use_autoencoder:
            autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])
        # Checkpoints written before the scheduler state was saved lack
        # this key; resume with a fresh schedule in that case.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])

    # ------------------- tensorboard -------------------
    writer_dict = {
        'writer': SummaryWriter(log_dir=logdir),
        'train_global_steps': 0
    }

    # ------------------- Evaluation -------------------
    if use_autoencoder:
        eval_body = evaluate.EvalBody()
        eval_upper = evaluate.EvalUpperBody()
        eval_lower = evaluate.EvalLowerBody()

    checkpoint_dir = os.path.join(logdir, 'checkpoints')
    # ------------------- run the model -------------------
    for epoch in range(args.epochs):
        with torch.autograd.set_detect_anomaly(True):
            LOGGER.info(f'====Training epoch {epoch}====')
            losses = AverageMeter()
            batch_time = AverageMeter()

            if use_resnet:
                resnet.train()
            if use_autoencoder:
                autoencoder.train()

            end = time.time()
            for it, (img, p2d, p3d, heatmap, action) in enumerate(train_data_loader):

                img = img.to(device)
                p2d = p2d.to(device)
                p3d = p3d.to(device)
                heatmap = heatmap.to(device)

                if use_resnet:
                    heatmap2d_hat = resnet(img)  # e.g. torch.Size([16, 15, 48, 48])
                    loss2d = Loss2D(heatmap, heatmap2d_hat).mean()
                else:
                    # 3D-only training feeds the ground-truth heatmaps.
                    heatmap2d_hat = heatmap
                if use_autoencoder:
                    p3d_hat, heatmap2d_recon = autoencoder(heatmap2d_hat)

                if args.training_type == 'Train2d':
                    loss = loss2d
                else:
                    # No 3D / reconstruction objective is defined in this
                    # function; fail explicitly instead of hitting an
                    # unbound ``loss`` at backward().
                    raise NotImplementedError(
                        f"Loss for training type '{args.training_type}' "
                        "is not implemented")

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                losses.update(loss.item(), img.size(0))

                if it % config.train.PRINT_FREQ == 0:
                    # logging messages
                    msg = 'Epoch: [{0}][{1}/{2}]\t' \
                          'Batch Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                          'Speed {speed:.1f} samples/s\t' \
                          'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                        epoch, it, len(train_data_loader), batch_time=batch_time,
                        speed=img.size(0) / batch_time.val,  # averaged within batch
                        loss=losses)
                    LOGGER.info(msg)
                end = time.time()
            scheduler.step()

            # ------------------- validation -------------------

            if use_resnet:
                resnet.eval()
            if use_autoencoder:
                autoencoder.eval()

            if use_autoencoder:
                # NOTE(review): this scores only the last training batch of
                # the epoch, not the validation loader.
                y_output = p3d_hat.data.cpu().numpy()
                y_target = p3d.data.cpu().numpy()

                eval_body.eval(y_output, y_target, action)
                eval_upper.eval(y_output, y_target, action)
                eval_lower.eval(y_output, y_target, action)

            # ------------------- Save results -------------------
            # Write into the directory that the log message announces; the
            # checkpoint previously landed in the current working directory.
            os.makedirs(checkpoint_dir, exist_ok=True)
            LOGGER.info('=> saving checkpoint to {}'.format(checkpoint_dir))
            states = dict()
            if use_resnet:
                states['resnet_state_dict'] = resnet.state_dict()
            if use_autoencoder:
                states['autoencoder_state_dict'] = autoencoder.state_dict()
            states['optimizer_state_dict'] = optimizer.state_dict()
            # The resume path reads the scheduler state, so it must be saved.
            states['scheduler'] = scheduler.state_dict()

            torch.save(states,
                       os.path.join(checkpoint_dir, f'checkpoint_{epoch}.tar'))

            if use_autoencoder:
                # The 3D evaluators only exist when the auto-encoder is used.
                res = {'FullBody': eval_body.get_results(),
                       'UpperBody': eval_upper.get_results(),
                       'LowerBody': eval_lower.get_results()}

                utils_io.write_json(config.eval.output_file, res)

    LOGGER.info('Done.')
Exemplo n.º 5
0
def main():
    """Evaluate the full 2D-to-3D pipeline on a dataset split.

    Loads the ResNet and auto-encoder weights from ``args.load_model``, runs
    both networks over the split selected by ``args.data``, accumulates the
    full/upper/lower-body metrics, and writes visualisations (a grid of up
    to 32 3D poses plus heatmap images for the first 32 batches) into the
    logger's output directory.

    Raises:
        ValueError: if no checkpoint was given or the path is not a file.
    """
    args = parse_args()
    print('Starting demo...')
    device = torch.device(f"cuda:{args.gpu}")
    LOGGER.info((args))

    # ------------------- Data loader -------------------

    data_transform = transforms.Compose([
        trsf.ImageTrsf(),  # normalize
        trsf.Joints3DTrsf(),  # centerize
        trsf.ToTensor()
    ])  # to tensor

    data = Mocap(config.dataset[args.data],
                 SetType.TEST,
                 transform=data_transform)
    data_loader = DataLoader(data,
                             batch_size=16,
                             shuffle=config.data_loader.shuffle,
                             num_workers=8)

    # ------------------- Evaluation -------------------

    eval_body = evaluate.EvalBody()
    eval_upper = evaluate.EvalUpperBody()
    eval_lower = evaluate.EvalLowerBody()

    # ------------------- Model -------------------
    with open('model/model.yaml') as fin:
        model_cfg = edict(yaml.safe_load(fin))
    resnet = pose_resnet.get_pose_net(model_cfg, False)
    autoencoder = encoder_decoder.AutoEncoder()

    # A demo without trained weights is meaningless, so both cases are fatal.
    if not args.load_model:
        raise ValueError("No checkpoint!")
    if not os.path.isfile(args.load_model):
        raise ValueError(f"No checkpoint found at {args.load_model}")
    checkpoint = torch.load(args.load_model, map_location=device)
    resnet.load_state_dict(checkpoint['resnet_state_dict'])
    autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    resnet.cuda(device)
    autoencoder.cuda(device)
    resnet.eval()
    autoencoder.eval()

    # ------------------- Read dataset frames -------------------
    fig = plt.figure(figsize=(19.2, 10.8))
    plt.axis('off')
    grid_rows, grid_cols = 4, 8
    max_vis = grid_rows * grid_cols  # 32 pose subplots / heatmap batches
    subplot_idx = 1

    with torch.no_grad():
        for it, (img, p2d, p3d, heatmap, action) in enumerate(data_loader):

            print('Iteration: {}'.format(it))
            print('Images: {}'.format(img.shape))
            print('p2ds: {}'.format(p2d.shape))
            print('p3ds: {}'.format(p3d.shape))
            print('Actions: {}'.format(action))

            img = img.to(device)
            p3d = p3d.to(device)
            # NOTE(review): heatmap is left on the CPU while img is on the
            # device; confirm draw2Dpred_and_gt accepts that combination.

            heatmap2d_hat = resnet(img)  # e.g. torch.Size([16, 15, 48, 48])
            p3d_hat, heatmap2d_recon = autoencoder(heatmap2d_hat)

            # Evaluate results using different evaluation metrics
            y_output = p3d_hat.data.cpu().numpy()
            y_target = p3d.data.cpu().numpy()

            eval_body.eval(y_output, y_target, action)
            eval_upper.eval(y_output, y_target, action)
            eval_lower.eval(y_output, y_target, action)

            # ------------------- Visualize 3D pose -------------------
            if subplot_idx <= max_vis:
                # First sample of the batch: target and prediction share one
                # set of axes (the boolean presumably selects the ground-truth
                # style -- confirm in show3Dpose).
                ax1 = fig.add_subplot(grid_rows, grid_cols, subplot_idx,
                                      projection='3d')
                show3Dpose(p3d[0].cpu().numpy(), ax1, True)
                show3Dpose(p3d_hat[0].detach().cpu().numpy(), ax1, False)

                subplot_idx += 1

            # ------------------- Visualize 2D heatmap -------------------

            if it < max_vis:

                # gt
                img_grid = draw2Dpred_and_gt(img, heatmap,
                                             (368, 368))  # tensor
                img_grid = img_grid.numpy().transpose(1, 2, 0)
                cv2.imwrite(os.path.join(LOGGER.logfile_dir, f'gt_{it}.jpg'),
                            img_grid)

                # 2d reconstruction
                img_grid_recon = draw2Dpred_and_gt(img, heatmap2d_recon,
                                                   (368, 368), p2d.clone())
                img_grid_recon = img_grid_recon.numpy().transpose(1, 2, 0)
                cv2.imwrite(
                    os.path.join(LOGGER.logfile_dir, f"recon_{it}.jpg"),
                    img_grid_recon)

                # 2d prediction
                img_grid_hat = draw2Dpred_and_gt(img,
                                                 heatmap2d_hat, (368, 368),
                                                 p2d.clone())  # tensor
                img_grid_hat = img_grid_hat.numpy().transpose(1, 2, 0)
                cv2.imwrite(os.path.join(LOGGER.logfile_dir, f'pred_{it}.jpg'),
                            img_grid_hat)

    # Save the pose grid once, after the loop. Saving inside the loop only
    # fired when exactly 32 subplots existed (never for shorter datasets)
    # and then rewrote the same file on every later iteration.
    if subplot_idx > 1:
        plt.savefig(os.path.join(LOGGER.logfile_dir, 'vis.png'))

    # ------------------- Save results -------------------

    LOGGER.info('Saving evaluation results...')
    res = {
        'FullBody': eval_body.get_results(),
        'UpperBody': eval_upper.get_results(),
        'LowerBody': eval_lower.get_results()
    }

    LOGGER.info(pprint.pformat(res))
    print('Done.')