Beispiel #1
0
#################### DATALOADER SETUPS ##################
# data.give_dataloaders() returns a dictionary with 'training', 'testing' and
# 'evaluation' dataloaders.
# The 'testing' dataloader corresponds to the validation set. The 'evaluation'
# dataloader iterates over the training set, but under the same rules as the
# 'testing' dataloader, i.e. no shuffling and no random cropping.
dataloaders      = data.give_dataloaders(opt.dataset, opt)
# The number of supervised classes depends on the dataset, so it can only be
# stored on opt after the dataloaders have been initialised.
opt.num_classes  = len(dataloaders['training'].dataset.avail_classes)

"""============================================================================"""
#################### CREATE LOGGING FILES ###############
# Each dataset usually has a set of standard metrics to log. aux.metrics_to_examine()
# returns a dict listing the metrics to log for training ('train') and for
# validation/testing ('val').

metrics_to_log = aux.metrics_to_examine(opt.dataset, opt.k_vals)
# example output: {'train': ['Epochs', 'Time', 'Train Loss', 'Time'],
#                  'val': ['Epochs','Time','NMI','F1', 'Recall @ 1','Recall @ 2','Recall @ 4','Recall @ 8']}
# NOTE(review): the example is illustrative only; 'Time' appearing twice under
# 'train' looks like a typo -- confirm against aux.metrics_to_examine().

# Build a LOGGER instance from the metrics of interest.
# 'start_new' requests a new folder in which everything is stored,
# network weights included.
LOG = aux.LOGGER(opt, metrics_to_log, name='Base', start_new=True)
# If graphviz is installed on the system, a computational graph of the underlying
# network will be made as well (presumably by code outside this excerpt).

"""============================================================================"""
#################### LOSS SETUP ####################
# Depending on opt.loss and opt.sampling, the matching criterion is returned.
# If the loss has trainable parameters, to_optim is extended with them.
criterion, to_optim = losses.loss_select(opt.loss, opt, to_optim)
Beispiel #2
0
def _set_warmup_freeze(model, criterion, frozen):
    """Freeze or unfreeze every model parameter outside the warm-up set.

    The warm-up set -- the parameters that always stay trainable -- consists of
    the parameters of the model's ``embedding`` layer and of ``criterion``.

    Args:
        model: The network, either bare or wrapped in ``nn.DataParallel``.
        criterion: Loss module whose parameters always stay trainable.
        frozen: ``True`` to disable gradients for all remaining parameters,
            ``False`` to re-enable them.
    """
    # nn.DataParallel keeps the wrapped network under ``.module`` and does not
    # forward attribute access, so ``model.embedding`` would fail on the wrapper.
    backbone = model.module if isinstance(model, nn.DataParallel) else model

    # Compare by identity: parameters are tensors, and identity is exactly the
    # notion of "same parameter" that is needed here.
    always_trainable = {id(p) for p in backbone.embedding.parameters()}
    always_trainable.update(id(p) for p in criterion.parameters())

    for param in model.parameters():
        if id(param) not in always_trainable:
            param.requires_grad = not frozen


def main():
    """Train a BN-Inception embedding network with the ProxyGML loss.

    Parses the command-line arguments, builds the model, the train/test
    ``ImageFolder`` loaders, the logger, the criterion and an Adam optimizer,
    then runs the epoch loop. Training loss is logged every epoch; NMI and
    Recall@k are computed and logged for every epoch listed in
    ``args.epoch_to_test``.

    If ``args.warm > 0``, only the embedding layer and the criterion are
    trained during epochs ``0 .. args.warm - 1``; all other model parameters
    are frozen for that period and released at epoch ``args.warm``.
    """
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu).lstrip('[').rstrip(']')
    print("torch.cuda.current_device()_{}".format(torch.cuda.current_device()))

    # create model
    model = net.bninception(args.dim)
    args.device = "cuda"
    if len(args.gpu) > 1:
        # More than one GPU id requested: replicate the model across them.
        model = nn.DataParallel(model).to(args.device)
    else:
        model = model.to(args.device)

    # load data
    traindir = os.path.join(args.data, args.dataset, 'train')
    testdir = os.path.join(args.data, args.dataset, 'test')
    # Inputs are rescaled to [0, 255] and passed through RGB2BGR before the
    # per-channel mean is subtracted (std of 1 leaves the scale untouched).
    # NOTE(review): presumably the Caffe-style preprocessing the BN-Inception
    # weights expect -- confirm.
    normalize = transforms.Normalize(mean=[104., 117., 128.],
                                     std=[1., 1., 1.])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Lambda(RGB2BGR),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255)),
            normalize,
        ]))

    # Recall@k cut-offs per dataset. Any other dataset name leaves
    # args.k_vals exactly as the argument parser provided it (if at all).
    if args.dataset in ('cars196', 'cub200'):
        args.k_vals = [1, 2, 4, 8]
    elif args.dataset == 'online_products':
        args.k_vals = [1, 10, 100, 1000]

    # Number of training classes, stored on args for later consumers.
    args.cN = len(train_dataset.class_to_idx)

    metrics_to_log = aux.metrics_to_examine(args.dataset, args.k_vals)
    args.save_path = os.getcwd() + '/Training_Results'

    # Run name encoding the main hyper-parameters; the trailing underscore is
    # part of the original naming scheme and is kept on purpose.
    args.savename = (
        "ProxyGML_{}/dim{}_weight_lambda{}_N{}_r{}_bs{}_graph_lr{}_epoch_to_decay{}_"
    ).format(
        args.dataset, args.dim, args.weight_lambda, args.N, args.r,
        args.batch_size, args.centerlr, args.new_epoch_to_decay)

    LOG = aux.LOGGER(args, metrics_to_log, name='Base', start_new=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Deterministic evaluation pipeline: fixed resize + centre crop, no shuffling.
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(testdir, transforms.Compose([
            transforms.Lambda(RGB2BGR),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255)),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = loss.ProxyGML(args).to(args.device)

    # Separate parameter groups so the network and the criterion can use
    # different learning rates.
    optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": args.modellr},
                                  {"params": criterion.parameters(), "lr": args.centerlr}],
                                 eps=args.eps, weight_decay=args.weight_decay)
    cudnn.benchmark = True

    for epoch in range(args.start_epoch, args.epochs):

        args.cur_epoch = epoch
        adjust_learning_rate(optimizer, epoch, args)

        # Warmup: Train only new params, helps stabilize learning.
        # This has to happen *before* train() so that epoch 0 itself is a
        # warm-up epoch and the frozen window is exactly epochs 0..warm-1.
        if args.warm > 0:
            if epoch == 0:
                _set_warmup_freeze(model, criterion, frozen=True)
            if epoch == args.warm:
                _set_warmup_freeze(model, criterion, frozen=False)

        # train for one epoch
        start = time.time()
        mean_loss = train(train_loader, model, criterion, optimizer, args)
        LOG.log('train', LOG.metrics_to_log['train'], [epoch, np.round(time.time() - start, 4), mean_loss])

        # Evaluate only on the epochs explicitly requested (1-based numbering).
        if (epoch + 1) in args.epoch_to_test:
            start = time.time()
            nmi, recall = validate(test_loader, model, args)
            LOG.log('val', LOG.metrics_to_log['val'], [epoch, np.round(time.time() - start), nmi] + list(recall))
            print("\n")
            print(
                'Recall@ {kval}: {recall[0]:.3f}, {recall[1]:.3f}, {recall[2]:.3f}, {recall[3]:.3f}; NMI: {nmi:.3f} \n'
                    .format(kval=args.k_vals, recall=recall, nmi=nmi))