Example #1
def main():
    global args, max_meter
    args = parser.parse_args()

    # create export dir if it doesn't exist
    directory = "{}".format(args.training_dataset)
    directory += "_{}".format(args.arch)
    directory += "_{}".format(args.pool)
    if args.local_whitening:
        directory += "_lwhiten"
    if args.regional:
        directory += "_r"
    if args.whitening:
        directory += "_whiten"
    if not args.pretrained:
        directory += "_notpretrained"
    directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
    directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr,
                                                args.weight_decay)
    directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num,
                                                  args.query_size,
                                                  args.pool_size)
    directory += "_bsize{}_imsize{}".format(args.batch_size, args.image_size)

    args.directory = os.path.join(args.directory, directory)
    print(">> Creating directory if it does not exist:\n>> '{}'".format(
        args.directory))
    if not os.path.exists(args.directory):
        os.makedirs(args.directory)

    # set random seeds
    # TODO: maybe pass as argument in future implementation?
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    # initialize model
    if args.pretrained:
        print(">> Using pre-trained model '{}'".format(args.arch))
    else:
        print(">> Using model from scratch (random weights) '{}'".format(
            args.arch))
    model_params = {}
    model_params['architecture'] = args.arch
    model_params['pooling'] = args.pool
    model_params['local_whitening'] = args.local_whitening
    model_params['regional'] = args.regional
    model_params['whitening'] = args.whitening
    # model_params['mean'] = ...  # will use default
    # model_params['std'] = ...  # will use default
    model_params['pretrained'] = args.pretrained
    model = init_network(model_params)

    if args.network_path is not None:
        state = torch.load(args.network_path)
        extra_keys = ["pool.p", "whiten.weight", "whiten.bias"]
        for k in extra_keys:
            if k in model.state_dict() and k not in state['state_dict']:
                state['state_dict'][k] = model.state_dict()[k]
        model.load_state_dict(state['state_dict'])

        print(">>>> loaded network from", args.network_path)
        print(model.meta_repr())

    # move network to gpu
    model.cuda()

    # define loss function (criterion) and optimizer
    if args.loss == 'contrastive':
        criterion = ContrastiveLoss(margin=args.loss_margin).cuda()
    else:
        raise RuntimeError("Loss {} not available!".format(args.loss))

    # parameters split into features, pool, whitening
    # IMPORTANT: no weight decay for pooling parameter p in GeM or regional-GeM
    parameters = []
    # add feature parameters
    parameters.append({'params': model.features.parameters()})
    # add local whitening if exists
    if model.lwhiten is not None:
        parameters.append({'params': model.lwhiten.parameters()})
    # add pooling parameters (or regional whitening which is part of the pooling layer!)
    if not args.regional:
        # global, only pooling parameter p weight decay should be 0
        parameters.append({
            'params': model.pool.parameters(),
            'lr': args.lr * 10,
            'weight_decay': 0
        })
    else:
        # regional, pooling parameter p weight decay should be 0,
        # and we want to add regional whitening if it is there
        parameters.append({
            'params': model.pool.rpool.parameters(),
            'lr': args.lr * 10,
            'weight_decay': 0
        })
        if model.pool.whiten is not None:
            parameters.append({'params': model.pool.whiten.parameters()})
    # add final whitening if exists
    if model.whiten is not None:
        parameters.append({'params': model.whiten.parameters()})

    # define optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters,
                                     args.lr,
                                     weight_decay=args.weight_decay)

    # define learning rate decay schedule
    # TODO: maybe pass as argument in future implementation?
    exp_decay = math.exp(-0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=exp_decay)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(args.directory, args.resume)
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            max_meter = checkpoint['max_meter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            # important not to forget scheduler updating
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch'] - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Data loading code
    normalize = transforms.Normalize(mean=model.meta['mean'],
                                     std=model.meta['std'])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = TuplesDataset(name=args.training_dataset,
                                  mode='train',
                                  imsize=args.image_size,
                                  nnum=args.neg_num,
                                  qsize=args.query_size,
                                  poolsize=args.pool_size,
                                  transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=True,
                                               collate_fn=collate_tuples)
    if args.val:
        val_dataset = TuplesDataset(name=args.training_dataset,
                                    mode='val',
                                    imsize=args.image_size,
                                    nnum=args.neg_num,
                                    qsize=float('Inf'),
                                    poolsize=float('Inf'),
                                    transform=transform)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True,
                                                 collate_fn=collate_tuples)

    # initialize is_best so checkpoints saved on epochs without a test run are well defined
    is_best = False
    for epoch in range(start_epoch, args.epochs):

        # set manual seeds per epoch
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        # evaluate on validation set
        if args.val:

            with torch.no_grad():
                loss = validate(val_loader, model, criterion, epoch)

        # evaluate on test datasets every test_freq epochs
        if epoch % args.test_freq == 0:
            with torch.no_grad():
                cur_meter = test(args.test_datasets,
                                 model,
                                 image_size=args.image_size)

            # remember best test metric and save checkpoint
            is_best = cur_meter > max_meter
            max_meter = max(cur_meter, max_meter)

        save_checkpoint(
            {
                'epoch': epoch,
                'meta': model.meta,
                'state_dict': model.state_dict(),
                'max_meter': max_meter,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            args.directory,
            save_regular=((args.save_freq and epoch % args.save_freq == 0)
                          or epoch == 0))

        # train for one epoch on train set
        loss = train(train_loader, model, criterion, optimizer, epoch)

        # adjust learning rate for each epoch
        scheduler.step()
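
This and the following examples pass collate_fn=collate_tuples to their DataLoaders because each dataset item is a tuple of several images plus a target vector rather than a single tensor. collate_tuples itself is not shown in the snippets; a minimal sketch, assuming each item is an (images, target) pair (the actual implementation in the repository may differ), is:

def collate_tuples(batch):
    # each element is assumed to be (list_of_image_tensors, target_tensor);
    # keep the tuples grouped instead of stacking them into one batch tensor
    if len(batch) == 1:
        return [batch[0][0]], [batch[0][1]]
    return [batch[i][0] for i in range(len(batch))], \
           [batch[i][1] for i in range(len(batch))]
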
Example #2
def main():
    global args, min_loss
    args = parser.parse_args()

    # manually check if there are unknown test datasets
    for dataset in args.test_datasets.split(','):
        if dataset not in test_datasets_names:
            raise ValueError('Unsupported or unknown test dataset: {}!'.format(dataset))

    # check if the test datasets are downloaded
    # and download them if they are not
    # download_train(get_data_root())
    # download_test(get_data_root())
    # download_train('/media/hxq/ExtractedDatasets/Datasets/retrieval-SfM')
    # download_test('/media/hxq/ExtractedDatasets/Datasets/retrieval-SfM')

    # create export dir if it doesn't exist
    directory = "{}".format(args.training_dataset)
    directory += "_{}".format(args.arch)
    directory += "_{}".format(args.pool)
    if args.local_whitening:
        directory += "_lwhiten"
    if args.regional:
        directory += "_r"
    if args.whitening:
        directory += "_whiten"
    if not args.pretrained:
        directory += "_notpretrained"
    directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
    directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr, args.weight_decay)
    directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num, args.query_size, args.pool_size)
    directory += "_bsize{}_imsize{}".format(args.batch_size, args.image_size)

    args.directory = os.path.join(args.directory, directory)
    print(">> Creating directory if it does not exist:\n>> '{}'".format(args.directory))
    if not os.path.exists(args.directory):
        os.makedirs(args.directory)

    # set cuda visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # set random seeds
    # TODO: maybe pass as argument in future implementation?
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    # initialize model
    if args.pretrained:
        print(">> Using pre-trained model '{}'".format(args.arch))
    else:
        print(">> Using model from scratch (random weights) '{}'".format(args.arch))
    model_params = {}
    model_params['architecture'] = args.arch
    model_params['pooling'] = args.pool
    model_params['local_whitening'] = args.local_whitening
    model_params['regional'] = args.regional
    model_params['whitening'] = args.whitening
    # model_params['mean'] = ...  # will use default
    # model_params['std'] = ...  # will use default
    model_params['pretrained'] = args.pretrained
    model_params['multi_layer_cat'] = args.multi_layer_cat
    model = init_network(model_params)
    print(">>>> loaded model: ")
    print(model.meta_repr())

    # move network to gpu
    model.cuda()

    # define loss function (criterion) and optimizer
    if args.loss == 'contrastive':
        criterion = ContrastiveLoss(margin=args.loss_margin).cuda()
    else:
        raise RuntimeError("Loss {} not available!".format(args.loss))

    # parameters split into features, pool, whitening
    # IMPORTANT: no weight decay for pooling parameters p in GeM or regional-GeM
    parameters = []
    # add feature parameters
    parameters.append({'params': model.features.parameters()})
    # add local whitening if exists
    if model.lwhiten is not None:
        parameters.append({'params': model.lwhiten.parameters()})
    # add pooling parameters (or regional whitening which is part of the pooling layer!)
    if not args.regional:
        # global, only pooling parameters p weight decay should be 0
        parameters.append({'params': model.pool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
    else:
        # regional, pooling parameters p weight decay should be 0,
        # and we want to add regional whitening if it is there
        parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
        if model.pool.whiten is not None:
            parameters.append({'params': model.pool.whiten.parameters()})
    # add final whitening if exists
    if model.whiten is not None:
        parameters.append({'params': model.whiten.parameters()})

    # define optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay)

    # define learning rate decay schedule
    # TODO: maybe pass as argument in future implementation?
    exp_decay = math.exp(-0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(args.directory, args.resume)
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            # hxq added, for whitening test
            # print('>> load the parameters of supervised whitening for the FC layer initialization')
            # whiten_fn = '/media/iap205/Data/Export/cnnimageretrieval-google_landmark_retrieval/trained_network/' \
            #             'R101_O_GL_FC/google-landmarks-dataset-resize_resnet101_gem_whiten_contrastive_m0.85_' \
            #             'adam_lr5.0e-07_wd1.0e-04_nnum5_qsize2000_psize22000_bsize5_imsize362/' \
            #             'model_epoch114.pth.tar_google-landmarks-dataset_whiten_ms.pth'
            # Lw = torch.load(whiten_fn)
            # P = Lw['P']
            # m = Lw['m']
            # P = torch.from_numpy(P).float()
            # m = torch.from_numpy(m).float()
            # checkpoint['state_dict']['whiten.weight'] = P
            # checkpoint['state_dict']['whiten.bias'] = -torch.mm(P, m).squeeze()

            min_loss = checkpoint['min_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            # important not to forget scheduler updating
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch']-1)
            # print(">> Use new scheduler")
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Data loading
    # hxq added, but it can be generated in advance
    # if args.training_dataset == 'google-landmarks-dataset-v2':
    #     data_table_path = os.path.join(get_data_root(), 'train.csv')
    #     img_check_path = os.path.join(get_data_root(), 'train')
    #     train_val_file_output_path = os.path.join(get_data_root())
    #     test_file_output_path = os.path.join(get_data_root(), 'test', args.training_dataset+'-test')
    #     val_set_size = 10000
    #     test_set_size = 10000
    # elif args.training_dataset == 'google-landmarks-dataset-resize':
    #     data_table_path = os.path.join(get_data_root(), 'train.csv')
    #     img_check_path = os.path.join(get_data_root(), 'resize_train_image')
    #     train_val_file_output_path = os.path.join(get_data_root())
    #     test_file_output_path = os.path.join(get_data_root(), 'test', args.training_dataset+'-test')
    #     val_set_size = 10000
    #     test_set_size = 10000
    # elif args.training_dataset == 'google-landmarks-dataset':
    #     pass
    # if args.training_dataset == 'google-landmarks-dataset-v2'\
    #     or args.training_dataset == 'google-landmarks-dataset-resize'\
    #     or args.training_dataset == 'google-landmarks-dataset':
    #     if not (os.path.isfile(os.path.join(train_val_file_output_path, '{}.pkl'.format(args.training_dataset)))
    #             and os.path.isfile(os.path.join(test_file_output_path, 'gnd_{}-test.pkl'.format(args.training_dataset)))):
    #         gen_train_val_test_pkl(args.training_dataset, data_table_path, img_check_path, val_set_size, test_set_size,
    #                                train_val_file_output_path, test_file_output_path)

    normalize = transforms.Normalize(mean=model.meta['mean'], std=model.meta['std'])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = TuplesDataset(
        name=args.training_dataset,
        mode='train',
        imsize=args.image_size,
        nnum=args.neg_num,
        qsize=args.query_size,
        poolsize=args.pool_size,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None,
        drop_last=True, collate_fn=collate_tuples
    )
    if args.val:
        val_dataset = TuplesDataset(
            name=args.training_dataset,
            mode='val',
            imsize=args.image_size,
            nnum=args.neg_num,
            qsize=float('Inf'),
            poolsize=float('Inf'),
            transform=transform
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
            drop_last=True, collate_fn=collate_tuples
        )

    # evaluate the network before starting
    # this might not be necessary?
    # test(args.test_datasets, model)

    # add by hxq
    # save_checkpoint({
    #     'epoch': 0,
    #     'meta': model.meta,
    #     'state_dict': model.state_dict(),
    #     'min_loss': min_loss,
    #     'optimizer': optimizer.state_dict(),
    # }, 0, args.directory)

    # hxq added, save trained network as epoch
    # network_path = {'vgg16': '/home/iap205/Datasets/retrieval-SfM/networks/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
    #                 'resnet101': '/home/iap205/Datasets/retrieval-SfM/networks/retrievalSfM120k-resnet101-gem-b80fb85.pth'}
    # model_trained = torch.load(network_path['resnet101'])
    # save_checkpoint({
    #     'epoch': 0,
    #     'meta': model.meta,
    #     'state_dict': model_trained['state_dict'],
    #     'min_loss': min_loss,
    #     'optimizer': optimizer.state_dict(),
    # }, 0, args.directory)

    for epoch in range(start_epoch, args.epochs):

        # set manual seeds per epoch
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        # adjust learning rate for each epoch
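        # note: stepping the scheduler before training follows the pre-1.1 PyTorch
        # convention; newer PyTorch releases expect scheduler.step() to be called
        # after the optimizer updates for the epoch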
        scheduler.step()
        # # debug printing to check if everything ok
        # lr_feat = optimizer.param_groups[0]['lr']
        # lr_pool = optimizer.param_groups[1]['lr']
        # print('>> Features lr: {:.2e}; Pooling lr: {:.2e}'.format(lr_feat, lr_pool))

        # train for one epoch on train set
        loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if args.val:
            with torch.no_grad():
                loss = validate(val_loader, model, criterion, epoch)

        # evaluate on test datasets every test_freq epochs
        if (epoch + 1) % args.test_freq == 0:
            with torch.no_grad():
                test(args.test_datasets, model)

        # remember best loss and save checkpoint
        is_best = loss < min_loss
        min_loss = min(loss, min_loss)

        save_checkpoint({
            'epoch': epoch + 1,
            'meta': model.meta,
            'state_dict': model.state_dict(),
            'min_loss': min_loss,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.directory)
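
The save_checkpoint helper called at the end of each epoch is not part of these snippets. A minimal sketch consistent with how it is invoked here (a state dict, an is_best flag, and a target directory; the file names below are assumptions, not the project's actual layout) could be:

import os
import shutil
import torch

def save_checkpoint(state, is_best, directory):
    # hypothetical file naming; adjust to the actual project layout
    filename = os.path.join(directory, 'model_epoch{}.pth.tar'.format(state['epoch']))
    torch.save(state, filename)
    if is_best:
        # keep a separate copy of the best checkpoint so far
        shutil.copyfile(filename, os.path.join(directory, 'model_best.pth.tar'))

Example #1 additionally passes a save_regular keyword argument, which this sketch omits.
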
Example #3
def main():
    global args, min_loss
    args = parser.parse_args()

    # manually check if there are unknown test datasets
    for dataset in args.test_datasets.split(','):
        if dataset not in test_datasets_names:
            raise ValueError('Unsupported or unknown test dataset: {}!'.format(dataset))

    # check if the test datasets are downloaded
    # and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # create export dir if it doesn't exist
    directory = "{}".format(args.training_dataset)
    directory += "_{}".format(args.arch)
    directory += "_{}".format(args.pool)
    if args.local_whitening:
        directory += "_lwhiten"
    if args.regional:
        directory += "_r"
    if args.whitening:
        directory += "_whiten"
    if not args.pretrained:
        directory += "_notpretrained"
    directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
    directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr, args.weight_decay)
    directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num, args.query_size, args.pool_size)
    directory += "_bsize{}_uevery{}_imsize{}".format(args.batch_size, args.update_every, args.image_size)

    args.directory = os.path.join(args.directory, directory)
    print(">> Creating directory if it does not exist:\n>> '{}'".format(args.directory))
    if not os.path.exists(args.directory):
        os.makedirs(args.directory)

    # set cuda visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    
    # set random seeds
    # TODO: maybe pass as argument in future implementation?
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    # initialize model
    if args.pretrained:
        print(">> Using pre-trained model '{}'".format(args.arch))
    else:
        print(">> Using model from scratch (random weights) '{}'".format(args.arch))
    model_params = {}
    model_params['architecture'] = args.arch
    model_params['pooling'] = args.pool
    model_params['local_whitening'] = args.local_whitening
    model_params['regional'] = args.regional
    model_params['whitening'] = args.whitening
    # model_params['mean'] = ...  # will use default
    # model_params['std'] = ...  # will use default
    model_params['pretrained'] = args.pretrained
    model = init_network(model_params)

    # move network to gpu
    model.cuda()

    # define loss function (criterion) and optimizer
    if args.loss == 'contrastive':
        criterion = ContrastiveLoss(margin=args.loss_margin).cuda()
    elif args.loss == 'triplet':
        criterion = TripletLoss(margin=args.loss_margin).cuda()
    else:
        raise RuntimeError("Loss {} not available!".format(args.loss))

    # parameters split into features, pool, whitening 
    # IMPORTANT: no weight decay for pooling parameter p in GeM or regional-GeM
    parameters = []
    # add feature parameters
    parameters.append({'params': model.features.parameters()})
    # add local whitening if exists
    if model.lwhiten is not None:
        parameters.append({'params': model.lwhiten.parameters()})
    # add pooling parameters (or regional whitening which is part of the pooling layer!)
    if not args.regional:
        # global, only pooling parameter p weight decay should be 0
        if args.pool == 'gem':
            parameters.append({'params': model.pool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
        elif args.pool == 'gemmp':
            parameters.append({'params': model.pool.parameters(), 'lr': args.lr*100, 'weight_decay': 0})
    else:
        # regional, pooling parameter p weight decay should be 0, 
        # and we want to add regional whitening if it is there
        if args.pool == 'gem':
            parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*10, 'weight_decay': 0})
        elif args.pool == 'gemmp':
            parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*100, 'weight_decay': 0})
        if model.pool.whiten is not None:
            parameters.append({'params': model.pool.whiten.parameters()})
    # add final whitening if exists
    if model.whiten is not None:
        parameters.append({'params': model.whiten.parameters()})

    # define optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay)

    # define learning rate decay schedule
    # TODO: maybe pass as argument in future implementation?
    exp_decay = math.exp(-0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(args.directory, args.resume)
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            min_loss = checkpoint['min_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            # important not to forget scheduler updating
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch']-1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Data loading code
    normalize = transforms.Normalize(mean=model.meta['mean'], std=model.meta['std'])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = TuplesDataset(
        name=args.training_dataset,
        mode='train',
        imsize=args.image_size,
        nnum=args.neg_num,
        qsize=args.query_size,
        poolsize=args.pool_size,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None,
        drop_last=True, collate_fn=collate_tuples
    )
    if args.val:
        val_dataset = TuplesDataset(
            name=args.training_dataset,
            mode='val',
            imsize=args.image_size,
            nnum=args.neg_num,
            qsize=float('Inf'),
            poolsize=float('Inf'),
            transform=transform
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
            drop_last=True, collate_fn=collate_tuples
        )

    # evaluate the network before starting
    # this might not be necessary?
    test(args.test_datasets, model)

    for epoch in range(start_epoch, args.epochs):

        # set manual seeds per epoch
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        # adjust learning rate for each epoch
        scheduler.step()
        # # debug printing to check if everything ok
        # lr_feat = optimizer.param_groups[0]['lr']
        # lr_pool = optimizer.param_groups[1]['lr']
        # print('>> Features lr: {:.2e}; Pooling lr: {:.2e}'.format(lr_feat, lr_pool))

        # train for one epoch on train set
        loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if args.val:
            with torch.no_grad():
                loss = validate(val_loader, model, criterion, epoch)

        # evaluate on test datasets every test_freq epochs
        if (epoch + 1) % args.test_freq == 0:
            with torch.no_grad():
                test(args.test_datasets, model)

        # remember best loss and save checkpoint
        is_best = loss < min_loss
        min_loss = min(loss, min_loss)

        save_checkpoint({
            'epoch': epoch + 1,
            'meta': model.meta,
            'state_dict': model.state_dict(),
            'min_loss': min_loss,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.directory)
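
A recurring detail in these variants is that the GeM pooling parameter p gets its own optimizer parameter group with a 10x learning rate (100x for 'gemmp') and zero weight decay. The per-group settings can be verified on the constructed optimizer; a small, purely illustrative check (assuming optimizer was built from the parameters list above) is:

# inspect the per-group settings of the optimizer built above
for i, group in enumerate(optimizer.param_groups):
    print('group {}: lr={:.2e}, weight_decay={}'.format(
        i, group['lr'], group['weight_decay']))
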
Example #4
def main():
    global args, min_loss
    args = parser.parse_args()

    # check if the test datasets are downloaded
    # and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # create export dir if it doesn't exist
    directory = "{}".format(args.training_dataset)
    directory += "_{}".format(args.arch)
    directory += "_{}".format(args.pool)
    if args.whitening:
        directory += "_whiten"
    if not args.pretrained:
        directory += "_notpretrained"
    directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
    directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr,
                                                args.weight_decay)
    directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num,
                                                  args.query_size,
                                                  args.pool_size)
    directory += "_bsize{}_imsize{}".format(args.batch_size, args.image_size)

    args.directory = os.path.join(args.directory, directory)
    print(">> Creating directory if it does not exist:\n>> '{}'".format(
        args.directory))
    if not os.path.exists(args.directory):
        os.makedirs(args.directory)

    # set cuda visible device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # set random seeds (maybe pass as argument)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)

    # create model
    if args.pretrained:
        print(">> Using pre-trained model '{}'".format(args.arch))
        model = init_network(model=args.arch,
                             pooling=args.pool,
                             whitening=args.whitening)
    else:
        print(">> Using model from scratch (random weights) '{}'".format(
            args.arch))
        model = init_network(model=args.arch,
                             pooling=args.pool,
                             whitening=args.whitening,
                             pretrained=False)

    # move network to gpu
    model.cuda()

    # define loss function (criterion) and optimizer
    if args.loss == 'contrastive':
        criterion = ContrastiveLoss(margin=args.loss_margin).cuda()
    else:
        raise RuntimeError("Loss {} not available!".format(args.loss))

    # parameters split into features and pool (no weight decay for pooling layer)
    parameters = [{
        'params': model.features.parameters()
    }, {
        'params': model.pool.parameters(),
        'lr': args.lr * 10,
        'weight_decay': 0
    }]
    if model.whiten is not None:
        parameters.append({'params': model.whiten.parameters()})

    # define optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters,
                                     args.lr,
                                     weight_decay=args.weight_decay)

    # define learning rate decay schedule
    exp_decay = math.exp(-0.01)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                       gamma=exp_decay)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        args.resume = os.path.join(args.directory, args.resume)
        if os.path.isfile(args.resume):
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            min_loss = checkpoint['min_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch'] - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Data loading code
    normalize = transforms.Normalize(mean=model.meta['mean'],
                                     std=model.meta['std'])
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = TuplesDataset(name=args.training_dataset,
                                  mode='train',
                                  imsize=args.image_size,
                                  nnum=args.neg_num,
                                  qsize=args.query_size,
                                  poolsize=args.pool_size,
                                  transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None,
                                               drop_last=True,
                                               collate_fn=collate_tuples)
    if args.val:
        val_dataset = TuplesDataset(name=args.training_dataset,
                                    mode='val',
                                    imsize=args.image_size,
                                    nnum=args.neg_num,
                                    qsize=float('Inf'),
                                    poolsize=float('Inf'),
                                    transform=transform)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True,
                                                 collate_fn=collate_tuples)

    # evaluate the network before starting
    test(args.test_datasets, model)

    for epoch in range(start_epoch, args.epochs):

        # set manual seeds per epoch
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        # adjust learning rate for each epoch
        scheduler.step()
        # lr_feat = optimizer.param_groups[0]['lr']
        # lr_pool = optimizer.param_groups[1]['lr']
        # print('>> Features lr: {:.2e}; Pooling lr: {:.2e}'.format(lr_feat, lr_pool))

        # train for one epoch on train set
        loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if args.val:
            loss = validate(val_loader, model, criterion, epoch)

        # evaluate on test datasets
        test(args.test_datasets, model)

        # remember best loss and save checkpoint
        is_best = loss < min_loss
        min_loss = min(loss, min_loss)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'meta': model.meta,
                'state_dict': model.state_dict(),
                'min_loss': min_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.directory)
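
Every variant reads its configuration from a module-level parser defined elsewhere in the file. The exact argument definitions are not part of these snippets; a minimal, illustrative sketch covering a subset of the attributes the main() functions access (option names and defaults are assumptions) might be:

import argparse

parser = argparse.ArgumentParser(description='CNN image retrieval training (sketch)')
parser.add_argument('--training-dataset', default='retrieval-SfM-120k')
parser.add_argument('--arch', default='resnet101')
parser.add_argument('--pool', default='gem')
parser.add_argument('--loss', default='contrastive')
parser.add_argument('--loss-margin', type=float, default=0.7)
parser.add_argument('--optimizer', default='adam')
parser.add_argument('--lr', type=float, default=1e-6)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--neg-num', type=int, default=5)
parser.add_argument('--query-size', type=int, default=2000)
parser.add_argument('--pool-size', type=int, default=20000)
parser.add_argument('--batch-size', type=int, default=5)
parser.add_argument('--image-size', type=int, default=362)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--resume', default='')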