Example #1
0
def _make_torch_data_loaders(args):
    """
    Build the MPII datasets and the data loaders that wrap them.

    With ``args.use_horovod`` set, each loader draws its samples through a
    ``DistributedSampler`` configured from ``hvd.size()`` and ``hvd.rank()``.
    Otherwise the training loader shuffles and the validation loader does not.

    :param args: Command line arguments
    :return: Training dataset, training data loader, validation data loader.
    """
    annotation_file = 'stacked_hourglass/data/mpii/mpii_annotations.json'
    image_dir = 'stacked_hourglass/data/mpii/images'

    def build_loader(dataset, batch_size, shuffle):
        # Exactly one of `sampler` / `shuffle` is handed to the DataLoader,
        # depending on whether Horovod is in use.
        if args.use_horovod:
            ordering = {
                'sampler': torch.utils.data.distributed.DistributedSampler(
                    dataset, num_replicas=hvd.size(), rank=hvd.rank()),
            }
        else:
            ordering = {'shuffle': shuffle}
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=args.workers,
            pin_memory=True,
            **ordering)

    train_dataset = datasets.Mpii(
        annotation_file,
        image_dir,
        sigma=args.sigma,
        label_type=args.label_type,
        augment_data=args.augment_training_data,
        args=args)

    # The validation split is never augmented.
    val_dataset = datasets.Mpii(
        annotation_file,
        image_dir,
        sigma=args.sigma,
        label_type=args.label_type,
        train=False,
        augment_data=False,
        args=args)

    train_loader = build_loader(
        train_dataset, args.train_batch_size, shuffle=True)
    val_loader = build_loader(
        val_dataset, args.test_batch_size, shuffle=False)

    return train_dataset, train_loader, val_loader
Example #2
0
def main(args):
    """
    Train the model selected by ``args.arch`` on MPII, or only evaluate it.

    Steps, in order:

    1. Create ``args.checkpoint`` if it does not exist.
    2. Build the model, an MSE criterion and an RMSprop optimizer.
    3. If ``args.resume`` names an existing file, restore the epoch, best
       accuracy, model and optimizer state from it and reopen the log file.
       Otherwise (no resume requested, or the file is missing) start a
       fresh log file.
    4. If ``args.evaluate`` is set, run validation once, save the
       predictions and return.
    5. Otherwise train from ``args.start_epoch`` to ``args.epochs``,
       validating, logging and checkpointing after every epoch, then plot
       the accuracy curves to ``log.eps`` in the checkpoint directory.

    Updates the module-level ``best_acc``.

    :param args: Parsed command line arguments. ``args.start_epoch`` is
        overwritten when a checkpoint is resumed.
    :return: None
    """
    global best_acc
    from time import sleep

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes)

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(size_average=True).cuda()

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    if args.resume and isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # you trust.
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(log_path, title=title, resume=True)
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.resume))
        # A missing resume file previously left `logger` unbound, so the
        # first logger.append() after a full training epoch raised
        # NameError. Start a fresh log in that case as well.
        logger = Logger(log_path, title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images',
                      sigma=args.sigma, label_type=args.label_type),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json', 'data/mpii/images',
                      sigma=args.sigma, label_type=args.label_type, train=False),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        print('\nEvaluation only')
        # Only the predictions are needed here; loss and accuracy are dropped.
        _, _, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        return

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        # NOTE(review): the reason for this 2 s pause before every epoch is
        # not evident from this function; confirm before removing it.
        sleep(2)

        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma on both datasets so train and validation stay in step
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes,
                                                      args.debug, args.flip)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Example #3
0
def main(args):
    """
    Train the model selected by ``args.arch`` on MPII with ``UniLoss``,
    or only evaluate it.

    Steps, in order:

    1. Create ``args.checkpoint`` if needed and log the arguments.
    2. Build the model (one output class per entry of
       ``args.index_classes``), the ``UniLoss`` criterion and an RMSprop
       optimizer.
    3. If ``args.resume`` names an existing file, restore the epoch, best
       accuracy, model and optimizer state from it, then force every
       parameter group's learning rate back to ``args.lr``. A fresh
       training log is started whether or not a checkpoint was loaded.
    4. If ``args.evaluate`` is set, run validation once, save the
       predictions and return.
    5. Otherwise train from ``args.start_epoch`` to ``args.epochs``,
       validating, logging and checkpointing after every epoch, then plot
       the accuracy curves to ``log.eps`` in the checkpoint directory.

    Updates the module-level ``best_acc``. May append to ``args.schedule``
    (see the plateau check in the epoch loop).

    :param args: Parsed command line arguments. ``args.start_epoch`` is
        overwritten when a checkpoint is resumed.
    :return: None
    """
    global best_acc
    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=len(args.index_classes))

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = models.loss.UniLoss(a_points=args.a_points)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    if args.resume and isfile(args.resume):
        _logger.info("=> loading checkpoint '{}'".format(args.resume))
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # you trust.
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The command-line learning rate wins over the one stored in the
        # restored optimizer state.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            print(param_group['lr'])
        _logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        # Even when resuming, the log file is started fresh (resume=False).
        logger = Logger(log_path, title=title, resume=False)
    else:
        if args.resume:
            _logger.info("=> no checkpoint found at '{}'".format(args.resume))
        # A missing resume file previously left `logger` unbound, so the
        # first logger.append() after a full training epoch raised
        # NameError. Start a fresh log in that case as well.
        logger = Logger(log_path, title=title)
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    _logger.info('    Total params: %.2fM' %
                 (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      direct=True,
                      n_points=args.n_points),
        batch_size=args.train_batch,
        shuffle=True,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii('data/mpii/mpii_annotations.json',
                      'data/mpii/images',
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      train=False,
                      direct=True),
        batch_size=args.test_batch,
        shuffle=False,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    if args.evaluate:
        _logger.warning('\nEvaluation only')
        # Only the predictions are needed here; loss and accuracy are dropped.
        _, _, predictions = validate(val_loader,
                                     model,
                                     criterion,
                                     len(args.index_classes),
                                     False,
                                     args.flip,
                                     _logger,
                                     evaluate_only=True)
        save_pred(predictions, checkpoint=args.checkpoint)
        return

    # multi-thread
    # NOTE(review): these start empty and are handed to train() every epoch;
    # presumably train() fills and reuses them for worker threads. Confirm.
    inqueues = []
    outqueues = []
    valid_accs = []
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        _logger.warning('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma on both datasets so train and validation stay in step
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay
        # train for one epoch
        train_loss, train_acc = train(inqueues, outqueues, train_loader, model,
                                      criterion, optimizer, args.debug,
                                      args.flip, args.clip, _logger)
        # evaluate on validation set
        with torch.no_grad():
            valid_loss, valid_acc, predictions = validate(
                val_loader, model, criterion, len(args.index_classes),
                args.debug, args.flip, _logger)
        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        valid_accs.append(valid_acc)
        # Plateau check, enabled when the first schedule entry is -1: once
        # more than 8 accuracies have been collected, compare the mean of
        # the last four (scaled by 0.99) against the mean of the four before
        # them. If the recent mean is not clearly higher, append the next
        # epoch to args.schedule (which is passed to adjust_learning_rate)
        # and restart the history.
        if args.schedule[0] == -1:
            if len(valid_accs) > 8:
                recent_mean = sum(valid_accs[-4:]) / 4
                previous_mean = sum(valid_accs[-8:-4]) / 4
                if recent_mean * 0.99 < previous_mean:
                    args.schedule.append(epoch + 1)
                    valid_accs = []
        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=1)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))