コード例 #1
0
ファイル: eval.py プロジェクト: deskmel/SimCLR
def load_model(checkpoints_folder, device):
    """Build a ``ResNetSimCLR`` and load its weights from a checkpoint folder.

    The network is constructed from the module-level ``config['model']``
    mapping, switched to eval mode, populated from
    ``<checkpoints_folder>/model.pth`` (deserialized onto the CPU first) and
    finally moved to ``device``.

    Args:
        checkpoints_folder: Directory that contains ``model.pth``.
        device: Target device passed to ``Module.to``.

    Returns:
        The loaded model, in eval mode, on ``device``.
    """
    network = ResNetSimCLR(**config['model'])
    network.eval()

    weights_path = os.path.join(checkpoints_folder, 'model.pth')
    # Deserialize on CPU so the checkpoint loads regardless of where it was saved.
    weights = torch.load(weights_path, map_location=torch.device('cpu'))
    network.load_state_dict(weights)

    return network.to(device)
コード例 #2
0
ファイル: run.py プロジェクト: ShahabBakht/SimCLR
def main():
    """Parse CLI arguments and run SimCLR training, optionally resuming.

    Selects the device, builds the contrastive dataset/loader, the model,
    the Adam optimizer and a cosine LR schedule, restores a checkpoint when
    ``args.resume`` names an existing file, and hands everything to
    ``SimCLR.train``.

    ``args.start_epoch`` and ``args.start_iter`` are always defined when
    training starts: they default to 0 and are only overwritten when a
    checkpoint is actually loaded.
    """
    args = parser.parse_args()
    assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
    # check if gpu training is available
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
        # NOTE(review): deterministic and benchmark are both enabled here;
        # confirm this combination gives the reproducibility that is wanted.
        cudnn.deterministic = True
        cudnn.benchmark = True
    else:
        args.device = torch.device('cpu')
        args.gpu_index = -1

    dataset = ContrastiveLearningDataset(args.data)

    train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    model = ResNetSimCLR(base_model=args.arch,
                         out_dim=args.out_dim).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

    # Start from scratch unless a checkpoint is successfully loaded below.
    # Previously a --resume path that did not exist left both attributes unset.
    args.start_epoch = 0
    args.start_iter = 0

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            # Load onto CPU; load_state_dict copies into the existing tensors.
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            args.start_iter = checkpoint['iteration']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))

    #  It’s a no-op if the 'gpu_index' argument is a negative integer or None.
    with torch.cuda.device(args.gpu_index):
        simclr = SimCLR(model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        args=args)
        simclr.train(train_loader)
コード例 #3
0
ファイル: encode.py プロジェクト: webMan1/SimCLR
def _collect_encodings(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and concatenate the results.

    Only the first element of the model's two-element output is kept. A
    1-D output is promoted to a single-row batch so that ``torch.cat`` along
    dim 0 works for every batch.
    """
    encodings = []
    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        for x, _ in loader:
            h, _ = model(x.to(device))
            if h.dim() == 1:
                h = h.unsqueeze(0)
            encodings.append(h.cpu())
    return torch.cat(encodings, dim=0)


def encode(save_root, model_file, data_folder, model_name='ca', dataset_name='celeba', batch_size=64, device='cuda:0', out_dim=256):
    """Encode a dataset's train/validation splits with a trained ResNetSimCLR.

    Saves two tensors under ``save_root``:
    ``{dataset_name}-{model_name}model-train_encodings.pt`` and
    ``{dataset_name}-{model_name}model-valid_encodings.pt``.

    Args:
        save_root: Output directory (created if missing).
        model_file: Path to the saved ``state_dict`` of the model.
        data_folder: Dataset root (created if missing).
        model_name: Tag embedded in the output file names.
        dataset_name: One of ``'celeba'``, ``'stanfordCars'``, ``'compCars'``.
        batch_size: Batch size for both loaders.
        device: Device used for the model and the input batches.
        out_dim: ``out_dim`` passed to ``ResNetSimCLR``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Fail fast; an unknown name used to surface only later as an unbound
    # ``train_loader`` after the model had already been loaded.
    if dataset_name not in ('celeba', 'stanfordCars', 'compCars'):
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; "
            "expected 'celeba', 'stanfordCars' or 'compCars'.")

    os.makedirs(save_root, exist_ok=True)
    os.makedirs(data_folder, exist_ok=True)

    if dataset_name == 'celeba':
        train_loader = DataLoader(datasets.CelebA(data_folder, split='train', download=True, transform=transforms.ToTensor()),
                                    batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(datasets.CelebA(data_folder, split='valid', download=True, transform=transforms.ToTensor()),
                                    batch_size=batch_size, shuffle=False)
    elif dataset_name == 'stanfordCars':
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            # Replicate single-channel images to three channels.
            transforms.Lambda(lambda x: x.repeat(3,1,1) if x.shape[0] == 1 else x)
        ])
        train_data_dir = os.path.join(data_folder, 'cars_train/')
        train_annos = os.path.join(data_folder, 'devkit/cars_train_annos.mat')
        train_loader = DataLoader(CarsDataset(train_annos, train_data_dir, t), batch_size=batch_size, shuffle=False)
        valid_data_dir = os.path.join(data_folder, 'cars_test/')
        valid_annos = os.path.join(data_folder, 'devkit/cars_test_annos_withlabels.mat')
        valid_loader = DataLoader(CarsDataset(valid_annos, valid_data_dir, t), batch_size=batch_size, shuffle=False)
    else:  # 'compCars'
        t = transforms.Compose([
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor()
        ])
        train_loader = DataLoader(CompCars(data_folder, True, t), batch_size=batch_size, shuffle=False)
        valid_loader = DataLoader(CompCars(data_folder, False, t), batch_size=batch_size, shuffle=False)

    model = ResNetSimCLR('resnet50', out_dim)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model = model.to(device)
    model.eval()

    print('Starting on training data')
    train_encodings = _collect_encodings(model, train_loader, device)
    torch.save(train_encodings, os.path.join(save_root, f'{dataset_name}-{model_name}model-train_encodings.pt'))

    print('Starting on validation data')
    valid_encodings = _collect_encodings(model, valid_loader, device)
    torch.save(valid_encodings, os.path.join(save_root, f'{dataset_name}-{model_name}model-valid_encodings.pt'))
コード例 #4
0
ファイル: finetune.py プロジェクト: webMan1/SimCLR
def run(config):
    """Fine-tune a pretrained ResNetSimCLR with a linear head on Stanford Cars.

    Loads the encoder weights from ``config.model_file``, wraps the encoder
    and a fresh ``nn.Linear(2048, 196)`` head in ``FineTuner``, builds the
    train/validation loaders and delegates the training loop to ``CESolver``.

    Reads from ``config``: ``out_dim``, ``model_file``, ``device``,
    ``data_root``, ``batch_size``, ``save_root``, ``name``, ``num_epochs``.

    The optimizer and loss that used to be constructed here were never used
    (the solver is given only the model and loaders), so they were removed.
    """
    model = ResNetSimCLR('resnet50', config.out_dim)
    model.load_state_dict(
        torch.load(config.model_file, map_location=config.device))
    model = model.to(config.device)
    # presumably 2048 = resnet50 feature width and 196 = number of car
    # classes; verify against the dataset and model definition.
    clf = nn.Linear(2048, 196)
    full = FineTuner(model, clf)

    t = T.Compose([
        T.Resize(512),
        T.CenterCrop(512),
        T.ToTensor(),
        # Replicate single-channel images to three channels.
        T.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
    ])
    train_data_dir = os.path.join(config.data_root, 'cars_train/')
    train_annos = os.path.join(config.data_root, 'devkit/cars_train_annos.mat')
    train_dataset = StanfordCarsMini(train_annos, train_data_dir, t)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)

    valid_data_dir = os.path.join(config.data_root, 'cars_test/')
    valid_annos = os.path.join(config.data_root,
                               'devkit/cars_test_annos_withlabels.mat')
    valid_dataset = CarsDataset(valid_annos, valid_data_dir, t)
    valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size)

    solver = CESolver(full,
                      train_loader,
                      valid_loader,
                      config.save_root,
                      name=config.name,
                      device=config.device)
    solver.train(config.num_epochs)
コード例 #5
0
def _extract_features(model, loader, device):
    """Return ``(features, labels)`` as numpy arrays for all batches of ``loader``.

    Only the first element of the model's two-element output is kept.
    """
    feature_batches, label_batches = [], []
    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for x, y in loader:
            features, _ = model(x.to(device))
            feature_batches.append(features.cpu().numpy())
            label_batches.append(y.cpu().numpy())
    return np.concatenate(feature_batches), np.concatenate(label_batches)


def evaluation(checkpoints_folder, config, device):
    """Linear-evaluate a trained ResNetSimCLR on CIFAR-10.

    Loads ``<checkpoints_folder>/model.pth``, extracts features for the
    CIFAR-10 train and test splits, standardizes them with a scaler fitted
    on the training features and passes both splits to ``linear_model_eval``.

    Args:
        checkpoints_folder: Directory that contains ``model.pth``.
        config: Mapping whose ``'model'`` entry holds the ``ResNetSimCLR``
            constructor keyword arguments.
        device: Device used for the model and the input batches.
    """
    model = ResNetSimCLR(**config['model'])
    model.eval()
    # Deserialize on CPU so a GPU-saved checkpoint also loads on CPU-only
    # hosts; the model is moved to ``device`` right after.
    state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'),
                            map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model = model.to(device)

    # Same preprocessing for both splits.
    # NOTE(review): these look like the ImageNet channel statistics rather
    # than CIFAR-10 ones; confirm this matches what the model was trained with.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data/CIFAR10',
                                             transform=transform,
                                             train=True,
                                             download=True)
    test_set = torchvision.datasets.CIFAR10(root='../data/CIFAR10',
                                            transform=transform,
                                            train=False,
                                            download=True)

    # NOTE(review): drop_last=True silently discards the final partial batch
    # of each split (including test samples); kept as in the original.
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=48,
                                               drop_last=True,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=48,
                                              drop_last=True,
                                              shuffle=True)

    X_train_feature, label_train = _extract_features(model, train_loader,
                                                     device)
    X_test_feature, label_test = _extract_features(model, test_loader, device)

    scaler = preprocessing.StandardScaler()
    print('ok')
    # Fit on training features only, then apply the same scaling to both.
    scaler.fit(X_train_feature)
    linear_model_eval(scaler.transform(X_train_feature), label_train,
                      scaler.transform(X_test_feature), label_test)
コード例 #6
0
ファイル: estimate.py プロジェクト: isadrtdinov/SimCLR
def main():
    """Estimate per-sample prediction statistics across saved checkpoints.

    For every checkpoint named ``args.estimate_checkpoint`` found under
    ``experiments/<experiment_group>/wandb``, the model is loaded and run on
    ``args.estimate_batches`` batches of ``args.batch_size`` randomly drawn
    training samples. For each sample it records the softmax probability of
    the label returned by the trainer and whether the argmax prediction
    matches it. The result is saved with ``torch.save`` to
    ``experiments/<experiment_group>/<out_file>`` as a dict with keys
    ``'indices'``, ``'prob'`` and ``'argmax'``.

    Raises:
        InvalidTrainingMode: If ``args.mode`` is neither ``'simclr'`` nor
            ``'supervised'``.
        FileNotFoundError: If no matching checkpoint file is found.
    """
    args = parser.parse_args()
    assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
    # check if gpu training is available
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
        cudnn.deterministic = True
        cudnn.benchmark = True
    else:
        args.device = torch.device('cpu')
        args.gpu_index = -1

    if args.mode == 'simclr':
        dataset = ContrastiveLearningDataset(args.data)
        train_dataset = dataset.get_dataset(args.dataset_name,
                                            args.n_views,
                                            train=True)
        model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
        trainer_class = SimCLRTrainer
    elif args.mode == 'supervised':
        dataset = SupervisedLearningDataset(args.data)
        train_dataset = dataset.get_dataset(args.dataset_name,
                                            args.supervised_augments,
                                            train=True)
        model = ResNetSimCLR(base_model=args.arch,
                             out_dim=len(train_dataset.classes))
        trainer_class = SupervisedTrainer
    else:
        raise InvalidTrainingMode()

    if args.target_shuffle is not None:
        # Seeded in-place shuffle of the targets.
        random.seed(args.target_shuffle)
        random.shuffle(train_dataset.targets)

    checkpoint_root = os.path.join('experiments', args.experiment_group,
                                   'wandb')
    checkpoints = []
    for root, _, files in os.walk(checkpoint_root):
        if args.estimate_checkpoint in files:
            checkpoints.append(os.path.join(root, args.estimate_checkpoint))
    if not checkpoints:
        # Previously this surfaced much later as an opaque torch.stack error.
        raise FileNotFoundError(
            "no checkpoint named '{}' found under '{}'".format(
                args.estimate_checkpoint, checkpoint_root))

    set_random_seed(args.seed)
    sample_indices = torch.randint(len(train_dataset),
                                   size=(args.batch_size *
                                         args.estimate_batches, ))

    estimated_prob, estimated_argmax = [], []
    #  It’s a no-op if the 'gpu_index' argument is a negative integer or None.
    with torch.cuda.device(args.gpu_index):
        for file in checkpoints:
            # Load onto CPU; load_state_dict copies into the model's tensors.
            state = torch.load(file, map_location=torch.device('cpu'))
            model.load_state_dict(state['model'])
            model.eval()
            trainer = trainer_class(model=model,
                                    optimizer=None,
                                    scheduler=None,
                                    args=args)

            checkpoint_prob, checkpoint_argmax = [], []
            for i in range(args.estimate_batches):
                if args.fixed_augments:
                    # Re-seed so every batch sees the same augmentation RNG.
                    set_random_seed(args.seed)

                # Batch i covers its own disjoint slice of the sampled
                # indices. The old slice [i:i + batch_size] produced
                # overlapping windows shifted by one, so the saved 'indices'
                # did not line up with the saved statistics.
                start = i * args.batch_size
                batch_indices = sample_indices[start:start +
                                               args.batch_size].tolist()

                if args.mode == 'simclr':
                    # Each example's first element holds the two views.
                    views = [train_dataset[index][0]
                             for index in batch_indices]
                    images = [
                        torch.stack([view[0] for view in views], dim=0),
                        torch.stack([view[1] for view in views], dim=0),
                    ]
                    labels = None
                elif args.mode == 'supervised':
                    examples = [train_dataset[index]
                                for index in batch_indices]
                    images = torch.stack([example[0] for example in examples],
                                         dim=0)
                    labels = torch.tensor(
                        [example[1] for example in examples],
                        dtype=torch.long)

                with torch.no_grad():
                    logits, labels = trainer.calculate_logits(images, labels)

                    # Probability assigned to each row's own label.
                    prob = torch.softmax(logits,
                                         dim=1)[torch.arange(labels.shape[0]),
                                                labels]
                    checkpoint_prob += [prob.detach().cpu()]

                    argmax = (torch.argmax(logits,
                                           dim=1) == labels).to(torch.int)
                    checkpoint_argmax += [argmax.detach().cpu()]

            checkpoint_prob = torch.cat(checkpoint_prob, dim=0)
            estimated_prob += [checkpoint_prob]

            checkpoint_argmax = torch.cat(checkpoint_argmax, dim=0)
            estimated_argmax += [checkpoint_argmax]

    # One row per checkpoint.
    estimated_prob = torch.stack(estimated_prob, dim=0)
    estimated_argmax = torch.stack(estimated_argmax, dim=0)
    torch.save(
        {
            'indices': sample_indices,
            'prob': estimated_prob,
            'argmax': estimated_argmax
        }, os.path.join('experiments', args.experiment_group, args.out_file))
コード例 #7
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training entry point: set up the process, model and optimizer, then train.

    Steps, in order: record the GPU on ``args``, optionally silence ``print``,
    initialise the distributed process group, build the model and wrap it for
    (distributed) GPU execution, build the data loader, optimizer, scheduler
    and loss, optionally resume from a checkpoint, optionally enable apex
    mixed precision, and finally call ``train``.

    Args:
        gpu: GPU index for this process, or ``None``.
        ngpus_per_node: Number of GPUs per node; used to derive the global
            rank and to split the batch size and worker count per process.
        args: Parsed arguments namespace. It is mutated in place (``gpu``,
            ``rank``, ``batch_size``, ``num_workers``, ``start_epoch``).
    """
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        # The parameter name shadows the outer ``args`` only inside this stub.
        def print_pass(*args):
            pass

        # Replaces the builtin ``print`` for this process.
        builtins.print = print_pass

    if args.gpu is not None:
        print(args.gpu)
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        # With the env:// init method and no explicit rank, read it from RANK.
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))

    # NOTE(review): ``args.gpu`` may be None here (see the checks below);
    # confirm ``.to(None)`` behaves as intended in that case.
    model = ResNetSimCLR(base_model=args.arch,
                         out_dim=args.out_dim).to(args.gpu)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            # Ceiling division of the worker count across the node's GPUs.
            args.num_workers = int(
                (args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
            # Convert BatchNorm layers to SyncBatchNorm before wrapping in DDP.
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        # Non-distributed single-GPU path.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        #raise NotImplementedError("Only DistributedDataParallel is supported.")
    #else:
    # AllGather implementation (batch shuffle, queue update, etc.) in
    # this code only supports DistributedDataParallel.
    #raise NotImplementedError("Only DistributedDataParallel is supported.")

    # Data loader
    # Uses the (possibly per-process adjusted) batch size and worker count.
    train_loader, train_sampler = data_loader(args.dataset,
                                              args.data_path,
                                              args.batch_size,
                                              args.num_workers,
                                              download=args.download,
                                              distributed=args.distributed,
                                              supervised=False)

    #optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=args.weight_decay)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    # Cosine schedule whose period is the number of epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0,
                                                           last_epoch=-1)

    # NOTE(review): the meaning of the positional arguments (in particular the
    # trailing ``True``) is defined by NTXentLoss elsewhere; confirm there.
    criterion = NTXentLoss(args.gpu, args.batch_size, args.temperature,
                           True).cuda(args.gpu)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map the checkpoint's tensors onto this process's GPU.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Only model and optimizer state are restored here; the scheduler
            # state is not loaded from the checkpoint.
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Optional apex mixed precision, applied after any checkpoint restore.
    if apex_support and args.fp16_precision:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O2',
                                          keep_batchnorm_fp32=True)

    cudnn.benchmark = True

    train(model, train_loader, train_sampler, criterion, optimizer, scheduler,
          args, ngpus_per_node)
コード例 #8
0
ファイル: validate.py プロジェクト: dy0607/SimCLR
        res = torch.argmax(out, 1)
        correct += (res == labels).sum().item()
        tot += labels.size(0)

    return 100 - 100 * correct / tot, test_loss


# Locations of the pretrained run's artefacts and of the dataset root.
config_path = "./best_run/config.yaml"
model_path = "./best_run/model.pth"
data_path = "./data"

# Read the config inside a context manager so the file handle is closed
# instead of being leaked by a bare open().
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = ResNetSimCLR(**config["model"]).to(device)
# map_location lets a checkpoint saved on a GPU load when ``device`` is 'cpu'.
model.load_state_dict(torch.load(model_path, map_location=device))

# Hyper-parameters for training the classifier on top of the encoder.
lr = 0.001
num_epoch = 90
batch_size = 512
num_classes = 10
weight_decay = 1e-6

transform = transforms.Compose([
    transforms.ToTensor(),
])

linear = Linear_Model(512, num_classes)  # Linear_Model or MLP
linear.to(device)

# Only the classifier's parameters are handed to the optimizer.
optimizer = optim.Adam(linear.parameters(), lr=lr, weight_decay=weight_decay)
コード例 #9
0

# Locations of the pretrained run's artefacts and of the dataset root.
config_path = "./best_run/config.yaml"
model_path = "./best_run/model.pth"
data_path = "./data"

# Read the config inside a context manager so the file handle is closed
# instead of being leaked by a bare open().
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# NOTE(review): the second GPU is hard-coded; on a single-GPU host
# 'cuda:1' is not a valid device. Confirm this is intentional.
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
model = ResNetSimCLR(**config["model"]).to(device)

# map_location lets a checkpoint saved on another device load onto ``device``.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

lr = 0.001
num_epoch = 200
batch_size = 512
num_classes = 10

transform = transforms.Compose([
    transforms.ToTensor(),
])  # may need to rewrite

trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)

# Shuffle only the training data; keep test order fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=3)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=3)
コード例 #10
0
ファイル: finetune.py プロジェクト: cloughurd/SimCLR
def run(config):
    """Fine-tune a pretrained ResNetSimCLR plus linear head on Stanford Cars.

    When ``config.encodings_file_prefix`` is truthy, a temporary linear
    classifier is first trained on pre-computed encodings
    (``<prefix>-train_encodings.pt`` / ``<prefix>-valid_encodings.pt``) and
    the weights stored at ``<save_root>/<name>-clf.pth`` are loaded into the
    head. The encoder and head are then wrapped in ``FineTuner`` and trained
    end-to-end on images by ``CESolver``.
    """
    encoder = ResNetSimCLR('resnet50', config.out_dim)
    encoder_weights = torch.load(config.model_file,
                                 map_location=config.device)
    encoder.load_state_dict(encoder_weights)
    encoder = encoder.to(config.device)
    clf = nn.Linear(2048, 196)

    data_root = config.data_root
    train_data_dir = os.path.join(data_root, 'cars_train/')
    train_annos = os.path.join(data_root, 'devkit/cars_train_annos.mat')
    valid_data_dir = os.path.join(data_root, 'cars_test/')
    valid_annos = os.path.join(data_root,
                               'devkit/cars_test_annos_withlabels.mat')

    if config.encodings_file_prefix:
        # Stage 1: warm up a classifier on pre-computed encodings.
        prefix = config.encodings_file_prefix
        encoded_train = EncodedStanfordCarsDataset(
            train_annos, prefix + '-train_encodings.pt')
        encoded_train_loader = DataLoader(
            encoded_train,
            batch_size=config.encodings_batch_size,
            shuffle=True)

        encoded_valid = EncodedStanfordCarsDataset(
            valid_annos, prefix + '-valid_encodings.pt')
        encoded_valid_loader = DataLoader(
            encoded_valid, batch_size=config.encodings_batch_size)

        warmup_clf = nn.Linear(2048, 196)
        warmup_solver = CESolver(warmup_clf,
                                 encoded_train_loader,
                                 encoded_valid_loader,
                                 config.save_root,
                                 name=config.name + '-clf',
                                 device=config.device)
        warmup_solver.train(config.encodings_num_epochs)

        # Load the classifier weights stored under the '-clf' name.
        saved_clf_path = os.path.join(config.save_root,
                                      f'{config.name}-clf.pth')
        saved_clf_weights = torch.load(saved_clf_path,
                                       map_location=config.device)
        clf.load_state_dict(saved_clf_weights)

    full = FineTuner(encoder, clf)

    # Stage 2: end-to-end fine-tuning on images.
    image_transform = T.Compose([
        T.Resize(512),
        T.CenterCrop(512),
        T.ToTensor(),
        # Replicate single-channel images to three channels.
        T.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
    ])
    image_train_loader = DataLoader(
        StanfordCarsMini(train_annos, train_data_dir, image_transform),
        batch_size=config.batch_size,
        shuffle=True)
    image_valid_loader = DataLoader(
        CarsDataset(valid_annos, valid_data_dir, image_transform),
        batch_size=config.batch_size)

    solver = CESolver(full,
                      image_train_loader,
                      image_valid_loader,
                      config.save_root,
                      name=config.name,
                      device=config.device)
    solver.train(config.num_epochs)
コード例 #11
0
def main():
    """Train SimCLR, then a classification head, then build a t-SNE projection.

    ``args.head_only`` skips the contrastive training stage and
    ``args.tsne_only`` skips both training stages; the t-SNE projection of
    the test head loader is always computed.
    """
    args = parser.parse_args()
    assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."

    # Pick the device; without CUDA the gpu_index is forced to -1.
    use_cuda = not args.disable_cuda and torch.cuda.is_available()
    if use_cuda:
        print("Using GPU")
        args.device = torch.device('cuda')
        cudnn.deterministic = True
        cudnn.benchmark = True
    else:
        args.device = torch.device('cpu')
        args.gpu_index = -1

    # Contrastive (two-view) training data.
    contrastive_data = ContrastiveLearningDataset(args.data)
    contrastive_train_set = contrastive_data.get_dataset(
        args.dataset_name, args.n_views)
    train_loader = torch.utils.data.DataLoader(
        contrastive_train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=worker_init_fn)

    # One input channel for "mnist", three for every other dataset name.
    in_channels = 1 if args.dataset_name == "mnist" else 3
    model = ResNetSimCLR(base_model=args.arch,
                         out_dim=args.out_dim,
                         in_channels=in_channels)

    if args.model_path is not None:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        model.load_state_dict(checkpoint['state_dict'])

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

    # Labelled data for the classification head.
    head_dataset = HeadDataset(args.data)
    train_head_dataset = head_dataset.get_dataset(args.dataset_name,
                                                  train=True,
                                                  split="train")
    test_head_dataset = head_dataset.get_dataset(args.dataset_name,
                                                 train=False,
                                                 split="test")
    args.num_classes = head_dataset.get_num_classes(args.dataset_name)

    # Both head loaders share the same settings.
    head_loader_options = dict(batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=args.workers,
                               pin_memory=True,
                               drop_last=True)
    train_head_loader = torch.utils.data.DataLoader(train_head_dataset,
                                                    **head_loader_options)
    test_head_loader = torch.utils.data.DataLoader(test_head_dataset,
                                                   **head_loader_options)

    #  It’s a no-op if the 'gpu_index' argument is a negative integer or None.
    with torch.cuda.device(args.gpu_index):
        if not args.head_only and not args.tsne_only:
            simclr = SimCLR(model=model,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            args=args)
            model = simclr.train(train_loader)
        model.eval()
        if not args.tsne_only:
            headsimclr = SimCLRHead(model=model, args=args)
            headsimclr.train(train_head_loader, test_head_loader)

        tsne_plot = TSNE_project(model, test_head_loader, args)