def test_do_training_epoch(cuda_device, mpii_image_dir):
    """Smoke-test a single training epoch on a trimmed MPII training set.

    Nothing is asserted about the values returned by ``do_training_epoch``;
    the test passes as long as the call completes without raising.
    """
    net = hg1(pretrained=True).to(cuda_device)

    # Keep only the first 32 entries of the training list so the epoch
    # stays short.
    mpii_train = Mpii(mpii_image_dir, is_train=True)
    mpii_train.train_list = mpii_train.train_list[:32]

    adam = Adam(net.parameters())

    loader_options = {
        'batch_size': 8,
        'shuffle': False,
        'num_workers': 2,
        'pin_memory': True,
    }
    loader = DataLoader(mpii_train, **loader_options)

    do_training_epoch(loader, net, cuda_device, Mpii.DATA_INFO, adam,
                      quiet=True, acc_joints=Mpii.ACC_JOINTS)
# Example no. 2
0
def main(args):
    """Train and validate a model on data pulled from YogiDB.

    Fetches the image set named by ``args.image_set_name`` together with its
    'basic-thresholder' annotations, wraps them in a ``Generic`` dataset, and
    runs a train / validate loop. Every epoch appends a row to ``log.txt`` and
    saves a checkpoint under ``args.checkpoint``; after the last epoch the
    accuracy curves are plotted and saved to ``log.eps``.

    Args:
        args: Parsed options. Attributes read here: ``image_set_name``,
            ``crop``, ``inp_res``, ``out_res``, ``mode``, ``train_batch``,
            ``test_batch``, ``workers``, ``checkpoint``, ``arch``,
            ``optimizer``, ``lr``, ``momentum``, ``weight_decay``,
            ``data_identifier``, ``resume``, ``start_epoch``, ``epochs``,
            ``schedule``, ``gamma``, ``debug`` and ``snapshot``.
            ``start_epoch`` is overwritten from the checkpoint when
            ``resume`` is set.

    Raises:
        Exception: If ``args.arch`` is not 'hg1', 'hg2' or 'hg8'.
        AssertionError: If ``args.resume`` is set but is not an existing file.
    """
    # Function-scope import: needed only to give the validation loader its
    # own dataset object (see below).
    import copy

    annotations_source = 'basic-thresholder'

    # Get the data from yogi
    db_obj = YogiDB(config.db_url)
    imageset = db_obj.get_filtered(ImageSet, name=args.image_set_name)
    annotations = db_obj.get_annotations(image_set_name=args.image_set_name,
                                         annotation_source=annotations_source)

    # The model's class count is the size of the first dimension of the first
    # annotation's 'joint_self' entry.
    pts = torch.Tensor(annotations[0]['joint_self'])
    num_classes = pts.size(0)

    # A truthy args.crop both enables cropping and supplies the crop size;
    # otherwise cropping is off and the size falls back to 512.
    crop = bool(args.crop)
    crop_size = args.crop if crop else 512

    # Using the default RGB mean and std dev as 0.
    # NOTE(review): if Generic divides by rgb_stddev, an all-zero std dev
    # yields inf/NaN -- confirm against Generic's normalisation code.
    rgb_mean = torch.as_tensor([0.0, 0.0, 0.0])
    rgb_stddev = torch.as_tensor([0.0, 0.0, 0.0])

    # Dataset signature as recorded in the original notes (the call below
    # additionally passes mode / crop / crop_size):
    #   Generic(data.Dataset)(image_set, annotations,
    #                         is_train=True, inp_res=256, out_res=64, sigma=1,
    #                         scale_factor=0, rot_factor=0,
    #                         label_type='Gaussian',
    #                         rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)
    train_dataset = Generic(image_set=imageset,
                            inp_res=args.inp_res,
                            out_res=args.out_res,
                            annotations=annotations,
                            mode=args.mode,
                            crop=crop, crop_size=crop_size,
                            rgb_mean=rgb_mean, rgb_stddev=rgb_stddev)
    train_dataset.is_train = True
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch, shuffle=True,
                              num_workers=args.workers, pin_memory=True)

    # The validation loader gets a shallow copy rather than the same object:
    # with a shared object, setting is_train = False here would also flip the
    # flag seen by the training loader. The copy shares the underlying data.
    val_dataset = copy.copy(train_dataset)
    val_dataset.is_train = False
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    # Select the hardware device to use.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # NOTE(review): presumably do_training_epoch re-enables them -- confirm.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    architectures = {'hg1': hg1, 'hg2': hg2, 'hg8': hg8}
    if args.arch not in architectures:
        # The message uses args.arch, the attribute that was actually tested.
        raise Exception('unrecognised model architecture: ' + args.arch)
    model = architectures[args.arch](pretrained=False, num_classes=num_classes)
    model = DataParallel(model).to(device)

    if args.optimizer == "Adam":
        # torch.optim.Adam has no `momentum` keyword (it is configured through
        # `betas`); passing one raises TypeError, so it is not forwarded.
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        optimizer = RMSprop(model.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    best_acc = 0

    # optionally resume from a checkpoint
    title = args.data_identifier + ' ' + args.arch
    log_path = os.path.join(args.checkpoint, 'log.txt')
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only host. NOTE: torch.load unpickles the file, so only resume
        # from checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(log_path, title=title, resume=True)
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # train and eval
    lr = args.lr
    try:
        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
            print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

            # train for one epoch
            train_loss, train_acc = do_training_epoch(train_loader, model, device, optimizer)

            # evaluate on validation set; in debug mode extra arguments
            # (including the path to debug.csv) are passed through
            if args.debug == 1:
                valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                    val_loader, model, device, False, True,
                    os.path.join(args.checkpoint, 'debug.csv'), epoch + 1)
            else:
                valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                    val_loader, model, device, False)

            # append logger file
            logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

            # remember best acc and save checkpoint
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)
    finally:
        # Close the logger even when an epoch raises.
        logger.close()

    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
def main(args):
    """Train and validate an hg1/hg2/hg8 model on the ``Mpii`` dataset.

    Builds the model and an RMSprop optimiser, optionally restores both from
    ``args.resume``, then runs a train / validate loop. Every epoch prints a
    metrics line, appends a row to ``log.txt``, re-renders ``log.svg`` and
    saves a checkpoint under ``args.checkpoint``.

    Args:
        args: Parsed options. Attributes read here: ``checkpoint``, ``arch``,
            ``lr``, ``momentum``, ``weight_decay``, ``resume``,
            ``start_epoch``, ``image_path``, ``train_batch``, ``test_batch``,
            ``workers``, ``epochs``, ``schedule``, ``gamma`` and ``snapshot``.
            ``start_epoch`` is overwritten from the checkpoint when
            ``resume`` is set.

    Raises:
        Exception: If ``args.arch`` is not 'hg1', 'hg2' or 'hg8'.
        AssertionError: If ``args.resume`` is set but is not an existing file.
    """
    # Select the hardware device to use.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # NOTE(review): presumably do_training_epoch re-enables them -- confirm.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    architectures = {'hg1': hg1, 'hg2': hg2, 'hg8': hg8}
    if args.arch not in architectures:
        raise Exception('unrecognised model architecture: ' + args.arch)
    model = architectures[args.arch](pretrained=False)
    model = DataParallel(model).to(device)

    optimizer = RMSprop(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

    best_acc = 0

    # optionally resume from a checkpoint
    log_path = os.path.join(args.checkpoint, 'log.txt')
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only host. NOTE: torch.load unpickles the file, so only resume
        # from checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        logger = Logger(log_path, resume=True)
    else:
        logger = Logger(log_path)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # Everything from here on runs with the logger open; the finally clause
    # closes it even if dataset construction or an epoch raises.
    try:
        # create data loader
        train_dataset = Mpii(args.image_path, is_train=True)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)

        val_dataset = Mpii(args.image_path, is_train=False)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.test_batch,
                                shuffle=False,
                                num_workers=args.workers,
                                pin_memory=True)

        # train and eval
        lr = args.lr
        for epoch in trange(args.start_epoch,
                            args.epochs,
                            desc='Overall',
                            ascii=True):
            lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                      args.gamma)

            # train for one epoch
            train_loss, train_acc = do_training_epoch(
                train_loader,
                model,
                device,
                Mpii.DATA_INFO,
                optimizer,
                acc_joints=Mpii.ACC_JOINTS)

            # evaluate on validation set
            valid_loss, valid_acc, predictions = do_validation_epoch(
                val_loader,
                model,
                device,
                Mpii.DATA_INFO,
                False,
                acc_joints=Mpii.ACC_JOINTS)

            # print metrics (tqdm.write is used instead of print alongside
            # the trange progress bar)
            tqdm.write(
                f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
                f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
                f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

            # append logger file
            logger.append(
                [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
            logger.plot_to_file(os.path.join(args.checkpoint, 'log.svg'),
                                ['Train Acc', 'Val Acc'])

            # remember best acc and save checkpoint
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                },
                predictions,
                is_best,
                checkpoint=args.checkpoint,
                snapshot=args.snapshot)
    finally:
        logger.close()