Example #1
def main():

    global args, best_prec1
    args = parser.parse_args()

    #  init seed
    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedstart + 1):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           domain=args.domain,
                                           args=args)
        model = torch.nn.DataParallel(model).cuda()
        print('Number of learnable params',
              get_learnable_para(model) / 1000000., " M")

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        aug_test = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(), normalize])

        # load dataset
        # import datasets.fundus_amd_syn_crossvalidation as medicaldata
        import datasets.fundus_amd_syn_crossvalidation_ind as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            drop_last=True if args.multiaug else False,
            worker_init_fn=lambda _: random.seed(my_whole_seed))  # must be a callable; random.seed(...) here would just pass None

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            worker_init_fn=lambda _: random.seed(my_whole_seed))  # see note above: pass a callable, not None

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()

        if args.multitaskposrot:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        if args.multitaskposrot:
            print("running multi task with miccai")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        elif args.synthesis:
            print("running synthesis")
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running cvpr")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            f = open("savemodels/result.txt", "a+")
            f.write("auc: %.4f\n" % (auc))
            f.write("acc: %.4f\n" % (acc))
            f.write("pre: %.4f\n" % (precision))
            f.write("recall: %.4f\n" % (recall))
            f.write("f1score: %.4f\n" % (f1score))
            f.close()
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [1000, 2000])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, lemniscate, criterion,
                         cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # save checkpoint
            if epoch % 200 == 0 or (epoch in [1600, 1800, 2000]):
                auc, acc, precision, recall, f1score = kNN(
                    args, model, lemniscate, train_loader, val_loader, 100,
                    args.nce_t, 2)
                # save to txt
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)
                f = open(args.result + "/result.txt", "a+")
                f.write("epoch " + str(epoch) + "\n")
                f.write("auc: %.4f\n" % (auc))
                f.write("acc: %.4f\n" % (acc))
                f.write("pre: %.4f\n" % (precision))
                f.write("recall: %.4f\n" % (recall))
                f.write("f1score: %.4f\n" % (f1score))
                f.close()
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'lemniscate': lemniscate,
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/fold" + str(args.seedstart) +
                    "-epoch-" + str(epoch) + ".pth.tar")
Example #2
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    # if args.pretrained:
    #     print("=> using pre-trained model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](pretrained=True, finetune=args.finetune, low_dim= args.low_dim)
    # else:
    #     print("=> creating model '{}'".format(args.arch))
    #
    #     model = models.__dict__[args.arch](low_dim=args.low_dim)

    # Data loading code

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # train_dataset = datasets.CombinedMaskDataset(
    #     other_data_path = '/home/saschaho/Simcenter/found_label_imgs',
    #     csv_root_folder='/home/saschaho/Simcenter/Floor_Elevation_Data/Streetview_Irma/Streetview_Irma/images',
    #     data_csv='/home/saschaho/Simcenter/Building_Information_Prediction/all_bims_train.csv',
    #     transform = transforms.Compose([
    #         transforms.RandomResizedCrop(224, scale=(0.2,1.)),
    #         transforms.RandomGrayscale(p=0.2),
    #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),attribute = 'first_floor_elevation_ft', mask_images=True)

    # val_dataset = datasets.CombinedMaskDataset(
    #         csv_root_folder='/home/saschaho/Simcenter/Floor_Elevation_Data/Streetview_Irma/Streetview_Irma/images',
    #         data_csv='/home/saschaho/Simcenter/Building_Information_Prediction/all_bims_val.csv',
    #     transform=transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),
    #attribute = 'first_floor_elevation_ft', mask_images=True)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.3, 1.)),
        transforms.RandomGrayscale(p=0.5),
        transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(), normalize
    ])

    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(), normalize])

    train_dataset = First_Floor_Binary(args.attribute_name,
                                       args.train_data,
                                       args.image_folder,
                                       transform=train_transform,
                                       regression=args.regression,
                                       mask_buildings=args.mask_buildings,
                                       softmask=args.softmask)
    val_dataset = First_Floor_Binary(args.attribute_name,
                                     args.val_data,
                                     args.image_folder,
                                     transform=val_transform,
                                     regression=args.regression,
                                     mask_buildings=args.mask_buildings,
                                     softmask=args.softmask)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    model = ResidualAttentionModel_92_Small(args.low_dim, dropout=False)
    model = torch.nn.DataParallel(model).cuda()

    print('Train dataset instances: {}'.format(len(train_loader.dataset)))
    print('Val dataset instances: {}'.format(len(val_loader.dataset)))
    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    #optimizer = RAdam(model.parameters())

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']

            keyname = [
                keyname for keyname in model.state_dict().keys()
                if 'fc.weight' in keyname
            ][0]
            lat_vec_len_model = model.state_dict()[keyname].shape[0]
            lat_vec_len_checkpoint = checkpoint['state_dict'][keyname].shape[0]

            low_dim_differ = False
            if lat_vec_len_model != lat_vec_len_checkpoint:
                low_dim_differ = True
                print(
                    'Warning: Latent vector sizes do not match. Assuming finetuning'
                )
                print(
                    'Lemniscate will be trained from scratch with new optimizer.'
                )
                del checkpoint['state_dict'][keyname]
                del checkpoint['state_dict'][keyname.replace('weight', 'bias')]

            missing_keys, unexpected_keys = model.load_state_dict(
                checkpoint['state_dict'], strict=False)
            if len(missing_keys) or len(unexpected_keys):
                print('Warning: Missing or unexpected keys found.')
                print('Missing: {}'.format(missing_keys))
                print('Unexpected: {}'.format(unexpected_keys))

            if not low_dim_differ:
                # The memory bank will be trained from scratch if
                # the low dim is different. Maybe later repopulated
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # if args.distributed:
        #     train_sampler.set_epoch(epoch)
        #adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.name)
    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
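
save_checkpoint is called here with (state, is_best, args.name) but never defined; a minimal sketch under that assumed signature (the filename scheme is only illustrative):

import shutil

import torch


def save_checkpoint(state, is_best, name, filename='checkpoint.pth.tar'):
    # Always write the latest state; keep an extra copy whenever it is the best so far.
    path = '{}_{}'.format(name, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, '{}_model_best.pth.tar'.format(name))
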
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)


    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2,1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolderInstance(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'lemniscate': lemniscate,
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best)
    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
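
The train function used by Examples #2 and #3 is imported rather than shown; a simplified sketch of one instance-discrimination training step over (input, target, index) batches, with metering and logging omitted:

def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
    model.train()
    for i, (input, target, index) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)

        feature = model(input)               # low_dim embedding for each image
        output = lemniscate(feature, index)  # similarities against the memory bank
        loss = criterion(output, index)      # every image is treated as its own class

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
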
Example #4
# NOTE: the snippet is cut off here; this call opening (the testloader/testset names) is an assumed reconstruction.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')
ndata = trainset.__len__()

print('==> Building model..')
net = models.__dict__['ResNet34'](low_dim=args.low_dim)
# define lemniscate
if args.nce_k > 0:
    lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                            args.nce_m)
else:
    lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m)

if device == 'cuda':
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

# Model
if args.test_only or len(args.resume) > 0:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/' + args.resume)
    net.load_state_dict(checkpoint['net'])
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
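
LinearAverage, the non-parametric memory bank imported throughout these examples, is never shown; a simplified sketch of its core idea (the real implementation wraps the update in a custom autograd Function and takes its hyper-parameters from the call sites above):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleLinearAverage(nn.Module):
    def __init__(self, low_dim, ndata, T=0.07, momentum=0.5):
        super().__init__()
        self.T = T
        self.momentum = momentum
        stdv = 1. / (low_dim ** 0.5)
        self.register_buffer('memory', torch.rand(ndata, low_dim) * 2 * stdv - stdv)

    def forward(self, x, index):
        # Similarity of the batch to every stored instance, scaled by the temperature T.
        out = torch.mm(x, self.memory.t()) / self.T
        # Momentum update of the memory slots belonging to the current batch.
        with torch.no_grad():
            updated = self.memory[index] * self.momentum + x * (1 - self.momentum)
            self.memory[index] = F.normalize(updated, dim=1)
        return out
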
Example #5
def main():
    global args, best_prec1, best_prec1_past, best_prec1_future
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.to(get_device(args.gpu))
        else:
            model = torch.nn.DataParallel(model).to(get_device(args.gpu))
    else:
        model.to(get_device(args.gpu))
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = Dataset(traindir, n_frames)
    val_dataset = Dataset(valdir, n_frames)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,  #(train_sampler is None), 
        num_workers=args.workers)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.gpu, args.low_dim, ndata, args.nce_k,
                                args.nce_t,
                                args.nce_m).to(get_device(args.gpu))
        criterion = NCECriterion(ndata).to(get_device(args.gpu))
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).to(get_device(args.gpu))
        criterion = nn.CrossEntropyLoss().to(get_device(args.gpu))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1, prec1_past, prec1_future = NN(epoch, model, lemniscate,
                                             train_loader, val_loader)

        add_epoch_score('epoch_scores.txt', epoch, prec1)
        add_epoch_score('epoch_scores_past.txt', epoch, prec1_past)
        add_epoch_score('epoch_scores_future.txt', epoch, prec1_future)

        # Sascha: This is a bug because it seems prec1 or best_prec1 is a vector at some point with
        # more than one entry
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, epoch)

        is_best_past = prec1_past > best_prec1_past
        best_prec1_past = max(prec1_past, best_prec1_past)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1_past': best_prec1_past,
                'optimizer': optimizer.state_dict(),
            },
            is_best_past,
            epoch,
            best_mod='_past')

        is_best_future = prec1_future > best_prec1_future
        best_prec1_future = max(prec1_future, best_prec1_future)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1_future': best_prec1_future,
                'optimizer': optimizer.state_dict(),
            },
            is_best_future,
            epoch,
            best_mod='_future')
    # evaluate KNN after last epoch
    kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
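
get_device and add_epoch_score are helpers Example #5 assumes; minimal sketches of what they plausibly do:

import torch


def get_device(gpu):
    # Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
    return torch.device('cuda:{}'.format(gpu) if torch.cuda.is_available() else 'cpu')


def add_epoch_score(path, epoch, score):
    # Append one "epoch score" line per evaluation.
    with open(path, 'a+') as f:
        f.write('{} {:.4f}\n'.format(epoch, float(score)))
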
Example #6
# Model
if args.test_only or len(args.resume) > 0:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/' + args.resume)
    net = checkpoint['net']
    lemniscate = checkpoint['lemniscate']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = models.__dict__['ResNet50'](low_dim=args.low_dim)
    # define lemniscate
    lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum)

# define loss function
criterion = NCACrossEntropy(torch.LongTensor(trainloader.dataset.targets))

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    lemniscate.cuda()
    criterion.cuda()
    cudnn.benchmark = True

if args.test_only:
    acc = kNN(0, net, lemniscate, trainloader, testloader, 30, args.temperature)
    sys.exit(0)
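
The imported kNN evaluator is not shown in any of these examples; a simplified sketch of weighted k-nearest-neighbour classification against the memory bank (the real function also recomputes train features under the test transform and vectorises the class vote):

import torch


def simple_knn(net, lemniscate, trainloader, testloader, K, sigma):
    # Classify each test image by a weighted vote over its K most similar memory-bank entries.
    net.eval()
    train_labels = torch.LongTensor(trainloader.dataset.targets).cuda()
    memory = lemniscate.memory  # (ndata, low_dim) L2-normalised train features
    num_classes = int(train_labels.max()) + 1
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets, _ in testloader:
            features = net(inputs.cuda())
            dist = torch.mm(features, memory.t())  # cosine similarity to every instance
            yd, yi = dist.topk(K, dim=1)           # K most similar instances
            candidates = train_labels[yi]          # and their class labels
            weights = (yd / sigma).exp()
            preds = []
            for labels_row, weights_row in zip(candidates, weights):
                scores = torch.zeros(num_classes, device=weights_row.device)
                scores.scatter_add_(0, labels_row, weights_row)
                preds.append(scores.argmax())
            preds = torch.stack(preds)
            correct += (preds == targets.cuda()).sum().item()
            total += targets.size(0)
    return correct / total
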
Example #7
def full_training(args):
    if not os.path.isdir(args.expdir):
        os.makedirs(args.expdir)
    elif os.path.exists(args.expdir + '/results.npy'):
        return

    if 'ae' in args.task:
        os.mkdir(args.expdir + '/figs/')

    train_batch_size = args.train_batch_size // 4 if args.task == 'rot' else args.train_batch_size
    test_batch_size = args.test_batch_size // 4 if args.task == 'rot' else args.test_batch_size
    yield_indices = (args.task == 'inst_disc')
    datadir = args.datadir + args.dataset
    trainloader, valloader, num_classes = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
        train_on_10_percent=args.train_on_10,
        train_on_half_classes=args.train_on_half)
    _, testloader, _ = general_dataset_loader.prepare_data_loaders(
        datadir,
        image_dim=args.image_dim,
        yield_indices=yield_indices,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
    )

    args.num_classes = num_classes
    if args.task == 'rot':
        num_classes = 4
    elif args.task == 'inst_disc':
        num_classes = args.low_dim

    if args.task == 'ae':
        net = models.AE([args.code_dim], image_dim=args.image_dim)
    elif args.task == 'jigsaw':
        net = JigsawModel(num_perms=args.num_perms,
                          code_dim=args.code_dim,
                          gray_prob=args.gray_prob,
                          image_dim=args.image_dim)
    else:
        net = models.resnet26(num_classes,
                              mlp_depth=args.mlp_depth,
                              normalize=(args.task == 'inst_disc'))
    if args.task == 'inst_disc':
        train_lemniscate = LinearAverage(args.low_dim,
                                         trainloader.dataset.__len__(),
                                         args.nce_t, args.nce_m)
        train_lemniscate.cuda()
        args.train_lemniscate = train_lemniscate
        test_lemniscate = LinearAverage(args.low_dim,
                                        valloader.dataset.__len__(),
                                        args.nce_t, args.nce_m)
        test_lemniscate.cuda()
        args.test_lemniscate = test_lemniscate
    if args.source:
        try:
            old_net = torch.load(args.source)
        except Exception:
            print("Falling back to latin1 encoding for unpickling")
            from functools import partial
            import pickle
            pickle.load = partial(pickle.load, encoding="latin1")
            pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
            old_net = torch.load(args.source,
                                 map_location=lambda storage, loc: storage,
                                 pickle_module=pickle)

        # net.load_state_dict(old_net['net'].state_dict())
        old_net = old_net['net']
        if hasattr(old_net, "module"):
            old_net = old_net.module
        old_state_dict = old_net.state_dict()
        new_state_dict = net.state_dict()
        for key, weight in old_state_dict.items():
            if 'linear' not in key:
                new_state_dict[key] = weight
            elif key == 'linears.0.weight' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.weight'] = weight
            elif key == 'linears.0.bias' and weight.shape[0] == num_classes:
                new_state_dict['linears.0.0.bias'] = weight
        net.load_state_dict(new_state_dict)

        del old_net
    net = torch.nn.DataParallel(net).cuda()
    start_epoch = 0
    if args.task in ['ae', 'inst_disc']:
        best_acc = np.inf
    else:
        best_acc = -1
    results = np.zeros((4, start_epoch + args.nb_epochs))

    net.cuda()
    cudnn.benchmark = True

    if args.task in ['ae']:
        args.criterion = nn.MSELoss()
    else:
        args.criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       net.parameters()),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.wd)

    print("Start training")
    train_func = eval('utils_pytorch.train_' + args.task)
    test_func = eval('utils_pytorch.test_' + args.task)
    if args.test_first:
        with torch.no_grad():
            test_func(0, valloader, net, best_acc, args, optimizer)
    for epoch in range(start_epoch, start_epoch + args.nb_epochs):
        utils_pytorch.adjust_learning_rate(optimizer, epoch, args)
        st_time = time.time()

        # Training and validation
        train_acc, train_loss = train_func(epoch, trainloader, net, args,
                                           optimizer)
        test_acc, test_loss, best_acc = test_func(epoch, valloader, net,
                                                  best_acc, args, optimizer)

        # Record statistics
        results[0:2, epoch] = [train_loss, train_acc]
        results[2:4, epoch] = [test_loss, test_acc]
        np.save(args.expdir + '/results.npy', results)
        print('Epoch lasted {0}'.format(time.time() - st_time))
        sys.stdout.flush()
        if (args.task == 'rot') and (train_acc >= 98) and args.early_stopping:
            break
    if args.task == 'inst_disc':
        args.train_lemniscate = None
        args.test_lemniscate = None
    else:
        best_net = torch.load(args.expdir + '/checkpoint.t7')['net']
        if args.task in ['ae', 'inst_disc']:
            best_acc = np.inf
        else:
            best_acc = -1
        final_acc, final_loss, _ = test_func(0, testloader, best_net, best_acc,
                                             args, None)
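
The task-specific train/test functions in Example #7 are resolved with eval('utils_pytorch.train_' + args.task); getattr performs the same lookup without evaluating an arbitrary string, as in this small equivalent sketch:

import utils_pytorch  # module already assumed by the example above


def get_task_functions(task):
    # Resolve e.g. utils_pytorch.train_rot and utils_pytorch.test_rot by name.
    train_func = getattr(utils_pytorch, 'train_' + task)
    test_func = getattr(utils_pytorch, 'test_' + task)
    return train_func, test_func
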
Example #8
def main():

    global args, best_prec1
    args = parser.parse_args()

    # Initialize distributed processing
    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           low_dim=args.low_dim)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet stats
        std=[0.229, 0.224, 0.225])
    #    normalize = transforms.Normalize(mean=[0.234, 0.191, 0.159],  # xView stats
    #                                     std=[0.173, 0.143, 0.127])

    print("Creating datasets")
    cj = args.color_jit
    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.Resize((224, 224)),
            #            transforms.Grayscale(3),
            #            transforms.ColorJitter(cj, cj, cj, cj), #transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(45),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    elif args.balanced_sampling:

        print("Using balanced sampling")
        # Here's where we compute the weights for WeightedRandomSampler
        class_counts = {v: 0 for v in train_dataset.class_to_idx.values()}
        for path, ndx in train_dataset.samples:
            class_counts[ndx] += 1
        total = float(np.sum([v for v in class_counts.values()]))
        class_probs = [
            class_counts[ndx] / total for ndx in range(len(class_counts))
        ]

        # make a list of class probabilities corresponding to the entries in train_dataset.samples
        reciprocal_weights = [
            class_probs[idx]
            for i, (_, idx) in enumerate(train_dataset.samples)
        ]

        # weights are the reciprocal of the above
        weights = (1 / torch.Tensor(reciprocal_weights))

        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(train_dataset), replacement=True)
    else:
        # If args.red_data < 1, train on a random subsample of the data; otherwise use the full dataset.
        data_size = len(train_dataset)
        sub_index = np.random.randint(0, data_size,
                                      round(args.red_data * data_size))
        sub_index.sort()
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(sub_index)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    print("Training on", len(train_dataset.imgs),
          "images. Training batch size:", args.batch_size)

    if len(train_dataset.imgs) % args.batch_size != 0:
        print(
            "Warning: batch size doesn't divide the # of training images so ",
            len(train_dataset.imgs) % args.batch_size,
            "images will be skipped per epoch.")
        print("If you don't want to skip images, use a batch size in:",
              get_factors(len(train_dataset.imgs)))

    val_dataset = datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]))

    val_bs = [
        factor for factor in get_factors(len(val_dataset)) if factor < 500
    ][-1]
    val_bs = 100
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=val_bs,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print("Validating on", len(val_dataset), "images. Validation batch size:",
          val_bs)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m)
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.static_loss,
                                       verbose=False)
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch']
            #           best_prec1 = checkpoint['best_prec1']
            lemniscate = checkpoint['lemniscate']
            if args.select_load:
                pred = checkpoint['prediction']
            print("=> loaded checkpoint '{}' (epoch {}, best_prec1 )".format(
                args.resume,
                checkpoint['epoch']))  #, checkpoint['best_prec1']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # optionally fine-tune a model trained on a different dataset
    elif args.fine_tune:
        print("=> loading checkpoint '{}'".format(args.fine_tune))
        checkpoint = torch.load(args.fine_tune)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.fine_tune, checkpoint['epoch']))
    else:
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)

    # Optionally recompute memory. If fine-tuning, then we must recompute memory
    if args.recompute_memory or args.fine_tune:

        # Aaron - Experiments show that iterating over torch.utils.data.DataLoader will skip the last few
        # unless the batch size evenly divides size of the data set. This shouldn't be the case
        # according to documentation, there's even a flag for drop_last, but it's not working

        # compute a good batch size for re-computing memory
        memory_bs = [
            factor for factor in get_factors(len(train_loader.dataset))
            if factor < 500
        ][-1]
        print("Recomputing memory using", train_dataset.root,
              "with a batch size of", memory_bs)
        transform_bak = train_loader.dataset.transform
        train_loader.dataset.transform = val_loader.dataset.transform
        temploader = torch.utils.data.DataLoader(
            train_loader.dataset,
            batch_size=memory_bs,
            shuffle=False,
            num_workers=train_loader.num_workers,
            pin_memory=True)
        lemniscate.memory = torch.zeros(len(train_loader.dataset),
                                        args.low_dim).cuda()
        model.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets,
                            indexes) in enumerate(tqdm.tqdm(temploader)):
                batchSize = inputs.size(0)
                features = model(inputs)
                lemniscate.memory[batch_idx * batchSize:batch_idx * batchSize +
                                  batchSize, :] = features.data
        train_loader.dataset.transform = transform_bak
        model.train()

    cudnn.benchmark = True

    if args.evaluate:
        kNN(model, lemniscate, train_loader, val_loader, args.K, args.nce_t)
        return

    begin_train_time = datetime.datetime.now()

    #    my_knn(model, lemniscate, train_loader, val_loader, args.K, args.nce_t, train_dataset, val_dataset)
    if args.tsne:
        labels = idx_to_name(train_dataset, args.graph_labels)
        tsne(lemniscate, args.tsne, labels)
    if args.pca:
        labels = idx_to_name(train_dataset, args.graph_labels)
        pca(lemniscate, labels)
    if args.view_knn:
        my_knn(model, lemniscate, train_loader, val_loader, args.K, args.nce_t,
               train_dataset, val_dataset)
    if args.kmeans:
        kmeans, yi = kmean(lemniscate, args.kmeans, 500, args.K, train_dataset)
        D, I = kmeans.index.search(lemniscate.memory.data.cpu().numpy(), 1)

        cent_group = {}
        data_cent = {}
        for n, i in enumerate(I):
            if i[0] not in cent_group.keys():
                cent_group[i[0]] = []
            cent_group[i[0]].append(n)
            data_cent[n] = i[0]

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            cent_group[0])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

#        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m)
#        criterion = NCECriterion(ndata).cuda()

#    lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m)

    if args.tsne_grid:
        tsne_grid(val_loader, model)
    if args.h_cluster:
        for size in range(2, 3):
            #        size = 20
            kmeans, topk = kmean(lemniscate, size, 500, 10, train_dataset)
            respred = torch.tensor([]).cuda()
            lab, idx = [[] for i in range(2)]
            num = 0
            '''
            for p,index,label in pred:
                respred = torch.cat((respred,p))
                if num == 0:
                    lab = label
                else:
                    lab += label
                idx.append(index)
                num+=1
            '''
            h_cluster(lemniscate, train_dataset, kmeans, topk,
                      size)  #, respred, lab, idx)

#    axis_explore(lemniscate, train_dataset)

#    kmeans_opt(lemniscate, 5)

    if args.select:
        if not args.select_load:
            pred = []

            if args.select_size:
                size = int(args.select_size * ndata)
            else:
                size = round(ndata / 100.0)

            sub_sample = np.random.randint(0, ndata, size=size)
            train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
                sub_sample)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)

            pred = div_train(train_loader, model, 0, pred)

        pred_features = []
        pred_labels = []
        pred_idx = []

        for inst in pred:
            feat, idx, lab = list(inst)
            pred_features.append(feat)
            pred_labels.append(lab)
            pred_idx.append(idx.data.cpu())

        if args.select_save:

            save_checkpoint(
                {
                    'epoch': args.start_epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'prediction': pred,
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                }, 'select.pth.tar')

        min_idx = selection(pred_features, pred_idx, train_dataset,
                            args.select_num, args.select_thresh)

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(min_idx)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        lemniscate = NCEAverage(args.low_dim, ndata, 20, args.nce_t,
                                args.nce_m)

        optimizer = torch.optim.SGD(model.parameters(),
                                    0.1,
                                    momentum=0.1,
                                    weight_decay=0.00001)

        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)

        for epoch in range(50):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch)

            if epoch % 1 == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                })

            train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(sub_index)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss,
                                   verbose=False)

    if args.kmeans_opt:
        kmeans_opt(lemniscate, 500)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        if epoch % 1 == 0:
            # evaluate on validation set
            #prec1 = NN(epoch, model, lemniscate, train_loader, train_loader) # was evaluating on train
            #            prec1 = kNN(model, lemniscate, train_loader, val_loader, args.K, args.nce_t)
            # prec1 really should be renamed to prec5 as kNN now returns top5 score, but
            # it wouldn't be backwards compatible, as earlier models were saved with "best_prec1"

            # remember best prec@1 and save checkpoint
            #            is_best = prec1 > best_prec1
            #            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                #                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            })  # , is_best)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

#        kmeans,cent = kmeans()
#        group_train(train_loader, model, lemniscate, criterion, optimizer, epoch, kmeans, cent)

    # print elapsed time
    end_train_time = datetime.datetime.now()
    d = end_train_time - begin_train_time
    print(
        "Trained for %d epochs. Elapsed time: %s days, %.2dh: %.2dm: %.2ds" %
        (len(range(args.start_epoch, args.epochs)), d.days, d.seconds // 3600,
         (d.seconds // 60) % 60, d.seconds % 60))
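
get_factors is used in Example #8 to pick validation and memory-recomputation batch sizes that divide the dataset evenly; a minimal sketch consistent with that usage:

def get_factors(n):
    # All divisors of n, in ascending order.
    return [i for i in range(1, n + 1) if n % i == 0]
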
Example #9
def main(args):

    # Data
    print('==> Preparing data..')
    _size = 32
    transform_train = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.RandomResizedCrop(size=_size, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(size=_size),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = datasets.CIFAR10Instance(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

    testset = datasets.CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=4)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')
    ndata = trainset.__len__()

    print('==> Building model..')
    net = models.__dict__['ResNet18'](low_dim=args.low_dim)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = ICRcriterion()
    # define loss function: inner product loss within each mini-batch
    uel_criterion = BatchCriterion(args.batch_m, args.batch_t, args.batch_size,
                                   ndata)

    net.to(device)
    criterion.to(device)
    uel_criterion.to(device)
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.test_only or len(args.resume) > 0:
        # Load checkpoint.
        model_path = 'checkpoint/' + args.resume
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            args.model_dir), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    # define lemniscate
    if args.test_only and len(args.resume) > 0:

        trainFeatures, feature_index = compute_feature(trainloader, net,
                                                       len(trainset), args)
        lemniscate = LinearAverage(torch.tensor(trainFeatures), args.low_dim,
                                   ndata, args.nce_t, args.nce_m)

    else:

        lemniscate = LinearAverage(torch.tensor([]), args.low_dim, ndata,
                                   args.nce_t, args.nce_m)
    lemniscate.to(device)

    # define optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    # optimizer2 = torch.optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # test acc
    if args.test_only:
        acc = kNN(0,
                  net,
                  trainloader,
                  testloader,
                  200,
                  args.batch_t,
                  ndata,
                  low_dim=args.low_dim)
        exit(0)

    if len(args.resume) > 0:
        # keep best_acc loaded from the checkpoint and resume from the next epoch
        start_epoch = start_epoch + 1
    else:
        best_acc = 0  # best test accuracy
        start_epoch = 0  # start from epoch 0

    icr2 = ICRDiscovery(ndata)

    # init_cluster_num = 20000
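    # Outer schedule: 5 rounds of up to 200 epochs each. Features are
    # re-extracted and re-clustered at every epoch, and TensorBoard scalars
    # are logged at step epoch + round * 200 so all rounds share one axis.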
    for round in range(5):
        for epoch in range(start_epoch, 200):
            #### get Features

            # trainloader is shuffled, so feature_index records the dataset
            # index that each row of trainFeatures corresponds to
            trainFeatures, feature_index = compute_feature(
                trainloader, net, len(trainset), args)

            if round == 0:
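                # In the first round the number of clusters decays exponentially
                # with the epoch: cluster_num = 10 ** (log10(ndata) * (1 - epoch / 200)),
                # clamped below at args.nmb_cluster. Illustration: with ndata = 50,000
                # (the CIFAR-10 train set), epoch 0 gives roughly one cluster per
                # instance and epoch 100 gives about 10 ** 2.35 ~ 223 clusters.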
                y = -1 * math.log10(ndata) / 200 * epoch + math.log10(ndata)
                cluster_num = int(math.pow(10, y))
                if cluster_num <= args.nmb_cluster:
                    cluster_num = args.nmb_cluster

                print('cluster number: ' + str(cluster_num))

                ### clustering algorithm to use
                # faiss-based clustering
                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))

                #### cluster the extracted features
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)

                # clusters have different sizes, so build an object (ragged) array
                L = np.array(deepcluster.images_lists, dtype=object)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                # icr = ICRDiscovery(ndata)

                # if args.test_only and len(args.resume) > 0:
                # icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset,
                # cluster_ratio + epoch*((1-cluster_ratio)/250))
                icrtime = time.time()

                # icr = cluster_assign(epoch, L, trainFeatures, feature_index, 1, 1)
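                # During the first args.warm_epoch epochs, cluster_assign builds the
                # ICR neighbour structure directly from the raw cluster memberships;
                # after warm-up, PreScore additionally re-scores samples within each
                # cluster (controlled by high_ratio, alpha and beta) before assigning
                # neighbours.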
                if epoch < args.warm_epoch:
                    icr = cluster_assign(epoch, L, trainFeatures,
                                         feature_index, args.cluster_ratio, 1)
                else:
                    icr = PreScore(epoch, L, image_dict, trainFeatures,
                                   feature_index, trainset, args.high_ratio,
                                   args.cluster_ratio, args.alpha, args.beta)

                print('calculate ICR time is: {}'.format(time.time() -
                                                         icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)

            else:
                cluster_num = args.nmb_cluster
                print('cluster number: ' + str(cluster_num))

                ### clustering algorithm to use
                # faiss-based clustering
                deepcluster = clustering.__dict__[args.clustering](
                    int(cluster_num))

                #### cluster the extracted features
                clustering_loss = deepcluster.cluster(trainFeatures,
                                                      feature_index,
                                                      verbose=args.verbose)

                # clusters have different sizes, so build an object (ragged) array
                L = np.array(deepcluster.images_lists, dtype=object)
                image_dict = deepcluster.images_dict

                print('create ICR ...')
                # icr = ICRDiscovery(ndata)

                # if args.test_only and len(args.resume) > 0:
                # icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset,
                # cluster_ratio + epoch*((1-cluster_ratio)/250))
                icrtime = time.time()

                # icr = cluster_assign(epoch, L, trainFeatures, feature_index, 1, 1)
                icr = PreScore(epoch, L, image_dict, trainFeatures,
                               feature_index, trainset, args.high_ratio,
                               args.cluster_ratio, args.alpha, args.beta)

                print('calculate ICR time is: {}'.format(time.time() -
                                                         icrtime))
                writer.add_scalar('icr_time', (time.time() - icrtime),
                                  epoch + round * 200)

            # else:
            #     icr = cluster_assign(icr, L, trainFeatures, feature_index, trainset, 0.2 + epoch*0.004)

            # print(icr.neighbours)

            icr2 = train(epoch, net, optimizer, lemniscate, criterion,
                         uel_criterion, trainloader, icr, icr2,
                         args.stage_update, args.lr, device, round)

            print('----------Evaluation---------')
            start = time.time()
            acc = kNN(0,
                      net,
                      trainloader,
                      testloader,
                      200,
                      args.batch_t,
                      ndata,
                      low_dim=args.low_dim)
            print("Evaluation Time: '{}'s".format(time.time() - start))

            writer.add_scalar('nn_acc', acc, epoch + round * 200)

            if acc > best_acc:
                print('Saving..')
                state = {
                    'net': net.state_dict(),
                    'acc': acc,
                    'epoch': epoch,
                }
                if not os.path.isdir(args.model_dir):
                    os.mkdir(args.model_dir)
                torch.save(
                    state,
                    os.path.join(args.model_dir,
                                 'ckpt_best_round_{}.t7'.format(round)))

                best_acc = acc

            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            os.makedirs(args.model_dir, exist_ok=True)
            torch.save(
                state,
                os.path.join(args.model_dir,
                             'ckpt_last_round_{}.t7'.format(round)))

            print(
                '[Round]: {} [Epoch]: {} \t accuracy: {}% \t (best acc: {}%)'.
                format(round, epoch, acc, best_acc))
Exemplo n.º 10
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224, scale=(0.3, 1.)),
            transforms.RandomGrayscale(p=0.5),
            transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize])

    val_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize])

    train_dataset = Foundation_Type_Binary(args.train_data, transform=train_transform, mask_buildings=args.mask_buildings)
    val_dataset = Foundation_Type_Binary(args.val_data, transform=val_transform, mask_buildings=args.mask_buildings)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    model = resnet50(low_dim=args.low_dim)
    model = torch.nn.DataParallel(model).cuda()

    print('Train dataset instances: {}'.format(len(train_loader.dataset)))
    print('Val dataset instances: {}'.format(len(val_loader.dataset)))

    ndata = train_dataset.__len__()
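    # Two variants of the non-parametric classifier over the instance memory bank:
    # with nce_k > 0, NCEAverage + NCECriterion train by noise-contrastive
    # estimation against nce_k sampled negatives; with nce_k == 0, LinearAverage
    # computes the full softmax over all ndata memory entries and a plain
    # cross-entropy loss is used.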
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, ndata, args.nce_k, args.nce_t, args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t, args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            args.epochs = args.start_epoch + args.epochs
            best_prec1 = checkpoint['best_prec1']

            missing_keys, unexpected_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
            if len(missing_keys) or len(unexpected_keys):
                print('Warning: Missing or unexpected keys found.')
                print('Missing: {}'.format(missing_keys))
                print('Unexpected: {}'.format(unexpected_keys))

            low_dim_checkpoint = checkpoint['lemniscate'].memory.shape[1]
            if low_dim_checkpoint == args.low_dim:
                lemniscate = checkpoint['lemniscate']
            else:
                print('Chosen low_dim does not match the checkpoint. Assuming fine-tuning; not loading the memory bank.')
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except ValueError:
                print('Optimizer state in the checkpoint does not match the current optimizer. Assuming fine-tuning; initializing the optimizer from scratch.')

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    prec1 = NN(0, model, lemniscate, train_loader, val_loader)
    print('Start out precision: {}'.format(prec1))
    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, lemniscate, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'lemniscate': lemniscate,
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.name)
Exemplo n.º 11
0
def main():

    global args, best_prec1
    args = parser.parse_args()

    my_whole_seed = 111
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           args=args)
        #
        # from models.Gresnet import ResNet18
        # model = ResNet18(low_dim=args.low_dim, multitask=args.multitask)
        model = torch.nn.DataParallel(model).cuda()

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                             normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_kaggle_dr as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=8,
            drop_last=True if args.multiaug else False,
            worker_init_fn=random.seed(my_whole_seed))

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 test_type="amd",
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_gon = medicaldata.traindataset(root=args.data,
                                                     transform=aug_test,
                                                     train=False,
                                                     test_type="gon",
                                                     args=args)
        val_loader_gon = torch.utils.data.DataLoader(
            valid_dataset_gon,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_pm = medicaldata.traindataset(root=args.data,
                                                    transform=aug_test,
                                                    train=False,
                                                    test_type="pm",
                                                    args=args)
        val_loader_pm = torch.utils.data.DataLoader(
            valid_dataset_pm,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        local_lemniscate = None

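        # Criterion selection: the flags below choose a batch-wise contrastive
        # loss variant (judging by their names, BatchCriterionRot adds a
        # rotation-based positive term, BatchCriterionFour handles the four-type
        # domain setting, and BatchCriterion uses plain multi-augmentation
        # positives); with none set, training falls back to cross-entropy over
        # the memory-bank logits.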
        if args.multitaskposrot:
            print("running multi task with positive")
            criterion = BatchCriterionRot(1, 0.1, args.batch_size, args).cuda()
        elif args.domain:
            print("running domain with four types--unify ")
            from lib.BatchAverageFour import BatchCriterionFour
            # criterion = BatchCriterionTriple(1, 0.1, args.batch_size, args).cuda()
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running multi task")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        if args.multitask:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [100, 200])
            writer.add_scalar("lr", lr, epoch)

            # # train for one epoch
            loss = train(train_loader, model, lemniscate, local_lemniscate,
                         criterion, cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # gap_int = 10
            # if (epoch) % gap_int == 0:
            #     knn_num = 100
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader, knn_num, args.nce_t, 2)
            #     writer.add_scalar("test_auc", auc, epoch)
            #     writer.add_scalar("test_acc", acc, epoch)
            #     writer.add_scalar("test_precision", precision, epoch)
            #     writer.add_scalar("test_recall", recall, epoch)
            #     writer.add_scalar("test_f1score", f1score, epoch)
            #
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_gon,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("gon/test_auc", auc, epoch)
            #     writer.add_scalar("gon/test_acc", acc, epoch)
            #     writer.add_scalar("gon/test_precision", precision, epoch)
            #     writer.add_scalar("gon/test_recall", recall, epoch)
            #     writer.add_scalar("gon/test_f1score", f1score, epoch)
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_pm,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("pm/test_auc", auc, epoch)
            #     writer.add_scalar("pm/test_acc", acc, epoch)
            #     writer.add_scalar("pm/test_precision", precision, epoch)
            #     writer.add_scalar("pm/test_recall", recall, epoch)
            #     writer.add_scalar("pm/test_f1score", f1score, epoch)

            # save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                },
                filename=args.result + "/fold" + str(args.seedstart) +
                "-epoch-" + str(epoch) + ".pth.tar")
Exemplo n.º 12
0
def build_model():
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    if args.architecture == 'resnet18':
        net = models.__dict__['resnet18_cifar'](low_dim=args.low_dim)
    elif args.architecture == 'wrn-28-2':
        net = models.WideResNet(depth=28,
                                num_classes=args.low_dim,
                                widen_factor=2,
                                dropRate=0).to(args.device)
    elif args.architecture == 'wrn-28-10':
        net = models.WideResNet(depth=28,
                                num_classes=args.low_dim,
                                widen_factor=10,
                                dropRate=0).to(args.device)

    # define lemniscate
    if args.nce_k > 0:
        lemniscate = NCEAverage(args.low_dim, args.ndata, args.nce_k,
                                args.nce_t, args.nce_m)
    else:
        lemniscate = LinearAverage(args.low_dim, args.ndata, args.nce_t,
                                   args.nce_m)

    if args.device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.weight_decay,
                          nesterov=True)
    # Model
    if args.test_only or len(args.resume) > 0:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lemniscate = checkpoint['lemniscate']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1

    if args.lr_scheduler == 'multi-step':
        if args.epochs == 200:
            steps = [60, 120, 160]
        elif args.epochs == 600:
            steps = [180, 360, 480, 560]
        else:
            raise RuntimeError(
                f"need to config steps for epoch = {args.epochs} first.")
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             steps,
                                             gamma=0.2,
                                             last_epoch=start_epoch - 1)
    elif args.lr_scheduler == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   args.epochs,
                                                   eta_min=0.00001,
                                                   last_epoch=start_epoch - 1)
    elif args.lr_scheduler == 'cosine-with-restart':
        scheduler = CosineAnnealingLRWithRestart(optimizer,
                                                 eta_min=0.00001,
                                                 last_epoch=start_epoch - 1)
    else:
        raise ValueError("not supported")

    # define loss function
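    # In the reference lemniscate implementation NCEAverage stores the number of
    # noise samples as self.K, so hasattr(lemniscate, 'K') distinguishes the NCE
    # memory bank (paired with NCECriterion) from LinearAverage (paired with
    # cross-entropy).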
    if hasattr(lemniscate, 'K'):
        criterion = NCECriterion(args.ndata)
    else:
        criterion = nn.CrossEntropyLoss()

    net.to(args.device)
    lemniscate.to(args.device)
    criterion.to(args.device)

    return net, lemniscate, optimizer, criterion, scheduler, best_acc, start_epoch
Exemplo n.º 13
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](low_dim=args.low_dim)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolderInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_labels = torch.tensor(train_dataset.targets).long().cuda()
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=None)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolderInstance(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define lemniscate and loss function (criterion)
    ndata = train_dataset.__len__()
    lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                               args.nce_m).cuda()
    rlb = ReliableSearch(ndata, args.low_dim, args.threshold_1,
                         args.threshold_2, args.batch_size).cuda()
    criterion = ReliableCrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = 0
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        kNN(0, model, lemniscate, train_loader, val_loader, 200, args.nce_t)
        return

    for rnd in range(args.start_round, args.rounds):

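        # From the second round on, recompute the full feature memory and let
        # ReliableSearch re-estimate which neighbour pairs count as reliable
        # positives at its two thresholds; the printed consistency values
        # presumably measure how well those pairs agree with the ground-truth
        # labels and are used for monitoring only.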
        if rnd > 0:
            memory = recompute_memory(model, lemniscate, train_loader,
                                      val_loader, args.batch_size,
                                      args.workers)
            num_reliable_1, consistency_1, num_reliable_2, consistency_2 = rlb.update(
                memory, train_labels)
            print(
                'Round [%02d/%02d]\tReliable1: %.12f\tReliable2: %.12f\tConsistency1: %.12f\tConsistency2: %.12f'
                % (rnd, args.rounds, num_reliable_1, num_reliable_2,
                   consistency_1, consistency_2))

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train(train_loader, model, lemniscate, rlb, criterion, optimizer,
                  epoch)

            # evaluate on validation set
            prec1 = NN(epoch, model, lemniscate, train_loader, val_loader)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                    #}, is_best, filename='ckpts/%02d-%04d-checkpoint.pth.tar'%(rnd+1, epoch + 1))
                },
                is_best)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'lemniscate': lemniscate,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best=False,
            filename='ckpts/%02d-checkpoint.pth.tar' % (rnd + 1))

        # evaluate KNN after last epoch
        top1, top5 = kNN(0, model, lemniscate, train_loader, val_loader, 200,
                         args.nce_t)
        print('Round [%02d/%02d]\tTop1: %.2f\tTop5: %.2f' %
              (rnd + 1, args.rounds, top1, top5))
Exemplo n.º 14
0
def main():
    global args, best_prec1, min_avgloss
    args = parser.parse_args()
    input("Begin the {} time's training".format(args.train_time))
    writer_log_dir = "/data/fhz/unsupervised_recommendation/idfe_runs/idfe_train_time:{}".format(
        args.train_time)
    writer = SummaryWriter(log_dir=writer_log_dir)
    if args.dataset == "lung":
        # build the train dataloader; val_dloader will be built in the test function
        model = idfe.IdFe3d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist, test_datalist = multi_cross_validation()
        ndata = len(train_datalist)
    elif args.dataset == "gland":
        dataset_path = "/data/fhz/MICCAI2015/npy"
        model = idfe.IdFe2d(feature_dim=args.latent_dim)
        model.encoder = torch.nn.DataParallel(model.encoder)
        model.linear_map = torch.nn.DataParallel(model.linear_map)
        model = model.cuda()
        train_datalist = glob(path.join(dataset_path, "train", "*.npy"))
        ndata = len(train_datalist)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    if args.nce_k > 0:
        """
        Here we use NCE to calculate loss
        """
        lemniscate = NCEAverage(args.latent_dim, ndata, args.nce_k, args.nce_t,
                                args.nce_m).cuda()
        criterion = NCECriterion(ndata).cuda()
    else:
        lemniscate = LinearAverage(args.latent_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            min_avgloss = checkpoint['min_avgloss']
            model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            model.linear_map.load_state_dict(
                checkpoint['linear_map_state_dict'])
            lemniscate = checkpoint['lemniscate']
            optimizer.load_state_dict(checkpoint['optimizer'])
            train_datalist = checkpoint['train_datalist']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.dataset == "lung":
        train_dset = LungDataSet(data_path_list=train_datalist,
                                 augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    elif args.dataset == "gland":
        train_dset = GlandDataset(data_path_list=train_datalist,
                                  need_seg_label=False,
                                  augment_prob=args.aug_prob)
        train_dloader = DataLoader(dataset=train_dset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)
    else:
        raise FileNotFoundError("Dataset {} Not Found".format(args.dataset))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        epoch_loss = train(train_dloader,
                           model=model,
                           lemniscate=lemniscate,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           writer=writer,
                           dataset=args.dataset)
        if (epoch + 1) % 5 == 0:
            if args.dataset == "lung":
                """
                Here we define the best point as the minimum average epoch loss
                
                """
                accuracy = list([])
                # for i in range(5):
                #     train_feature = lemniscate.memory.clone()
                #     test_datalist = train_datalist[five_cross_idx[i]:five_cross_idx[i + 1]]
                #     test_feature = train_feature[five_cross_idx[i]:five_cross_idx[i + 1], :]
                #     train_indices = [train_datalist.index(d) for d in train_datalist if d not in test_datalist]
                #     tmp_train_feature = torch.index_select(train_feature, 0, torch.tensor(train_indices).cuda())
                #     tmp_train_datalist = [train_datalist[i] for i in train_indices]
                #     test_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in test_datalist], dtype=np.float)
                #     tmp_train_label = np.array(
                #         [int(eval(re.match("(.*)_(.*)_annotations.npy", path.basename(raw_cube_path)).group(2)) > 3)
                #          for raw_cube_path in tmp_train_datalist], dtype=np.float)
                #     accuracy.append(
                #         kNN(tmp_train_feature, tmp_train_label, test_feature, test_label, K=20, sigma=1 / 10))
                # accuracy = mean(accuracy)
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist
                    }, is_best)
                # knn_text = "In epoch :{} the five cross validation accuracy is :{}".format(epoch, accuracy * 100.0)
                # # print(knn_text)
                # writer.add_text("knn/text", knn_text, epoch)
                # writer.add_scalar("knn/accuracy", accuracy, global_step=epoch)
            elif args.dataset == "gland":
                is_best = (epoch_loss < min_avgloss)
                min_avgloss = min(epoch_loss, min_avgloss)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        "train_time": args.train_time,
                        "encoder_state_dict": model.encoder.state_dict(),
                        "linear_map_state_dict": model.linear_map.state_dict(),
                        'lemniscate': lemniscate,
                        'min_avgloss': min_avgloss,
                        'dataset': args.dataset,
                        'optimizer': optimizer.state_dict(),
                        'train_datalist': train_datalist,
                    }, is_best)