Example 1
import sys
import os.path as osp

import numpy as np
import torch
from torch.backends import cudnn

# Project-specific helpers (Logger, get_data, InceptionNet, OIMLoss,
# TripletLoss, DistanceMetric, Evaluator, Trainer, load_checkpoint,
# save_checkpoint) are assumed importable from the surrounding package.


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    if args.loss == 'triplet':
        assert args.num_instances > 1, 'TripletLoss requires num_instances > 1'
        assert args.batch_size % args.num_instances == 0, \
            'num_instances should divide batch_size'
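        # (With identity-balanced sampling this implies each batch holds
        # batch_size / num_instances identities with num_instances images
        # each, letting the triplet loss mine positive pairs in-batch.)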
    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir,
                 args.batch_size, args.workers, args.num_instances,
                 combine_trainval=args.combine_trainval)

    # Create model
    if args.loss == 'xentropy':
        model = InceptionNet(num_classes=num_classes,
                             num_features=args.features,
                             dropout=args.dropout)
    elif args.loss == 'oim':
        model = InceptionNet(num_features=args.features,
                             norm=True,
                             dropout=args.dropout)
    elif args.loss == 'triplet':
        model = InceptionNet(num_features=args.features, dropout=args.dropout)
    else:
        raise ValueError("Cannot recognize loss type:", args.loss)
    model = torch.nn.DataParallel(model).cuda()
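    # DataParallel wraps the model, so its attributes (e.g. num_features)
    # are reached through model.module below.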

    # Load from checkpoint
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        args.start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> start epoch {}  best top1 {:.1%}".format(
            args.start_epoch, best_top1))
    else:
        best_top1 = 0

    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        metric.train(model, train_loader)
        print("Validation:")
        evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
        print("Test:")
        evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
        return

    # Criterion
    if args.loss == 'xentropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.loss == 'oim':
        criterion = OIMLoss(model.module.num_features,
                            num_classes,
                            scalar=args.oim_scalar,
                            momentum=args.oim_momentum)
    elif args.loss == 'triplet':
        criterion = TripletLoss(margin=args.triplet_margin)
    else:
        raise ValueError("Cannot recognize loss type:", args.loss)
    criterion.cuda()

    # Optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        raise ValueError("Cannot recognize optimizer type:", args.optimizer)

    # Trainer
    trainer = Trainer(model, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        if args.optimizer == 'sgd':
            # Step decay: divide the learning rate by 10 every 60 epochs.
            lr = args.lr * (0.1 ** (epoch // 60))
        elif args.optimizer == 'adam':
            # Constant for the first 100 epochs, then exponential decay
            # down to 0.001x over the next 50 epochs.
            lr = args.lr if epoch <= 100 else \
                args.lr * (0.001 ** ((epoch - 100) / 50.0))
        else:
            raise ValueError("Cannot recognize optimizer type: {}".format(
                args.optimizer))
        for g in optimizer.param_groups:
            g['lr'] = lr

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)
        top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val)

        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  top1: {:5.1%}  best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    metric.train(model, train_loader)
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
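The function above reads everything from a flat args namespace. As context, here is a minimal sketch of the kind of argparse entry point that could drive it; the flag names mirror the attributes main() reads, but every default value is an illustrative assumption rather than the original script's configuration.

import argparse

# Hypothetical entry point: the flags mirror the attributes main() reads
# above, but every default is an assumption for illustration.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training example')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--dataset', type=str, default='cuhk03')
    parser.add_argument('--split', type=int, default=0)
    parser.add_argument('--data-dir', type=str, default='data')
    parser.add_argument('--logs-dir', type=str, default='logs')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--num-instances', type=int, default=4)
    parser.add_argument('--combine-trainval', action='store_true')
    parser.add_argument('--loss', type=str, default='oim',
                        choices=['xentropy', 'oim', 'triplet'])
    parser.add_argument('--features', type=int, default=128)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--dist-metric', type=str, default='euclidean')
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=['sgd', 'adam'])
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--oim-scalar', type=float, default=30.0)
    parser.add_argument('--oim-momentum', type=float, default=0.5)
    parser.add_argument('--triplet-margin', type=float, default=0.5)
    parser.add_argument('--start-epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--evaluate', action='store_true')
    main(parser.parse_args())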
Example 2
import sys
import os.path as osp

import numpy as np
import torch
from torch.backends import cudnn

# Project-specific helpers (Logger, get_data, ResNetLSTM_btfu, OIMLoss,
# TripletLoss, DistanceMetric, Evaluator, SeqTrainer) are assumed
# importable from the surrounding package.


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    cudnn.benchmark = True

    # Redirect print to both console and log file
    # All printed information is stored in args.logs_dir

    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    if args.loss == 'triplet':
        assert args.num_instances > 1, 'TripletLoss requires num_instances > 1'
        assert args.batch_size % args.num_instances == 0, \
            'num_instances should divide batch_size'

    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir,
                 args.batch_size, args.seq_len, args.seq_srd,
                 args.workers, args.num_instances,
                 combine_trainval=True)

    # Create model
    if args.loss == 'xentropy':
        model = ResNetLSTM_btfu(args.depth,
                                pretrained=True,
                                num_features=args.features,
                                dropout=args.dropout)

    elif args.loss == 'oim':
        model = ResNetLSTM_btfu(args.depth,
                                pretrained=True,
                                num_features=args.features,
                                norm=True,
                                dropout=args.dropout)

    elif args.loss == 'triplet':
        model = ResNetLSTM_btfu(args.depth,
                                pretrained=True,
                                num_features=args.features,
                                dropout=args.dropout)

    else:
        raise ValueError("cannot recognize loss type:", args.loss)

    model = torch.nn.DataParallel(model).cuda()

    # Load from checkpoint
    # TODO: resuming from a checkpoint is not needed currently

    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Evaluator
    evaluator = Evaluator(model)

    # Criterion
    if args.loss == 'xentropy':
        criterion = torch.nn.CrossEntropyLoss()
    elif args.loss == 'oim':
        criterion = OIMLoss(model.module.num_features,
                            num_classes,
                            scalar=args.oim_scalar,
                            momentum=args.oim_momentum)
    elif args.loss == 'triplet':
        criterion = TripletLoss(margin=args.triplet_margin)
    else:
        raise ValueError("Cannot recognize loss type:", args.loss)
    criterion.cuda()

    # Optimizer
    if args.optimizer == 'sgd':
        if args.loss == 'xentropy':
            base_param_ids = set(map(id, model.module.base.parameters()))
            new_params = [
                p for p in model.parameters() if id(p) not in base_param_ids
            ]
            param_groups = [{
                'params': model.module.base.parameters(),
                'lr_mult': 0.1
            }, {
                'params': new_params,
                'lr_mult': 1.0
            }]
        else:
            param_groups = model.parameters()
        optimizer = torch.optim.SGD(param_groups,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    else:
        raise ValueError("Cannot recognize optimizer type:", args.optimizer)

    # Trainer
    trainer = SeqTrainer(model, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        if args.optimizer == 'sgd':
            # Step decay: divide the learning rate by 10 every 40 epochs.
            lr = args.lr * (0.1 ** (epoch // 40))
        elif args.optimizer == 'adam':
            # Constant for the first 100 epochs, then exponential decay
            # down to 0.001x over the next 50 epochs.
            lr = args.lr if epoch <= 100 else \
                args.lr * (0.001 ** ((epoch - 100) / 50.0))
        else:
            raise ValueError("Cannot recognize optimizer type: {}".format(
                args.optimizer))
        for g in optimizer.param_groups:
            # Scale by the per-group multiplier attached when the optimizer
            # was built (0.1x for the pretrained base, 1.0x for new layers).
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)

        top1 = evaluator.evaluate(test_loader,
                                  dataset.query,
                                  dataset.gallery,
                                  multi_shot=True)
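Example 2's SGD branch builds two parameter groups and tags each with a custom lr_mult key so the pretrained backbone is updated at a tenth of the base learning rate; torch.optim keeps unknown keys in param_groups, which is what lets adjust_lr read them back. The self-contained sketch below isolates that pattern; the two-layer Net is an illustrative stand-in, not the repository's model.

import torch
import torch.nn as nn


class Net(nn.Module):
    # Stand-in model: 'base' plays the pretrained backbone, 'classifier'
    # the newly initialized head. Both layers are illustrative assumptions.
    def __init__(self):
        super(Net, self).__init__()
        self.base = nn.Linear(16, 8)
        self.classifier = nn.Linear(8, 4)


model = Net()

# Tag each group with a custom 'lr_mult' key; torch.optim preserves unknown
# keys, so adjust_lr can read them back from optimizer.param_groups.
param_groups = [
    {'params': model.base.parameters(), 'lr_mult': 0.1},
    {'params': model.classifier.parameters(), 'lr_mult': 1.0},
]
optimizer = torch.optim.SGD(param_groups, lr=0.01, momentum=0.9)


def adjust_lr(epoch, base_lr=0.01):
    # Same schedule as Example 2: divide by 10 every 40 epochs, then
    # scale by each group's multiplier.
    lr = base_lr * (0.1 ** (epoch // 40))
    for g in optimizer.param_groups:
        g['lr'] = lr * g.get('lr_mult', 1)


adjust_lr(40)
print([g['lr'] for g in optimizer.param_groups])  # ~[0.0001, 0.001]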