Beispiel #1
0
    def __init__(self, args):
        """Build the ResNet encoder, SimCLR model and NT-Xent loss from ``args``.

        Args:
            args: Run configuration. It is stored on the module and then read
                back for ``resnet``, ``projection_dim``, ``batch_size`` and
                ``temperature``.
        """
        super().__init__()

        # Store the configuration first; all reads below go through it.
        self.hparams = args
        hp = self.hparams

        # Backbone built by get_resnet without pretrained weights.
        self.encoder = get_resnet(hp.resnet, pretrained=False)
        # Input width of the backbone's final ``fc`` layer.
        self.n_features = self.encoder.fc.in_features
        self.model = SimCLR(self.encoder, hp.projection_dim, self.n_features)
        # Contrastive loss configured for a single process (world_size=1).
        self.criterion = NT_Xent(hp.batch_size, hp.temperature, world_size=1)
Beispiel #2
0
    def __init__(self, args):
        """Assemble the encoder, SimCLR model and NT-Xent loss from ``args``.

        The configuration is kept as ``self.args``. Attributes read here:
        ``resnet``, ``pretrain``, ``h_dim``, ``projection_dim``,
        ``n_classes``, ``batch_size`` and ``temperature``.
        """
        super().__init__()

        self.args = args

        # Backbone; ``args.pretrain`` is forwarded as get_resnet's
        # ``pretrained`` flag.
        self.encoder = get_resnet(args.resnet, pretrained=args.pretrain)
        # Input width of the backbone's final ``fc`` layer.
        self.n_features = self.encoder.fc.in_features
        self.model = SimCLR(
            self.encoder,
            args.h_dim,
            args.projection_dim,
            self.n_features,
            args.n_classes,
        )
        # Starts empty; presumably accumulated by test-time code elsewhere in
        # the class (not visible here).
        self.test_outputs = np.array([])
        # Contrastive loss configured for a single process (world_size=1).
        self.criterion = NT_Xent(args.batch_size, args.temperature,
                                 world_size=1)
        # NOTE(review): this excerpt starts mid-call; the enclosing function's
        # header and the opening of this DataLoader(...) call are not visible
        # here (presumably `train_loader = torch.utils.data.DataLoader(`).
        train_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )

    # Evaluation loader over test_dataset: same batch size and worker count as
    # the loader above, but without shuffling.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=False,
        # NOTE(review): drop_last=True discards the final partial batch, so up
        # to logistic_batch_size - 1 test samples are never evaluated; confirm
        # this is intended for the test split.
        drop_last=True,
        num_workers=args.workers,
    )

    # Rebuild the backbone without pretrained weights; its parameters are
    # presumably restored by the checkpoint load into simclr_model below.
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # load pre-trained model from checkpoint
    # The file is "<model_path>/checkpoint_<epoch_num>.tar"; map_location maps
    # the stored tensors onto the device type of args.device.
    simclr_model = SimCLR(encoder, args.projection_dim, n_features)
    model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.epoch_num))
    simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    simclr_model = simclr_model.to(args.device)
    # Inference mode; presumably the SimCLR model is only used as a frozen
    # feature extractor for the linear probe below.
    simclr_model.eval()

    ## Logistic Regression
    # NOTE(review): the class count is hard-coded rather than read from args;
    # it must match the label count of the dataset in use.
    n_classes = 10  # CIFAR-10 / STL-10
    model = LogisticRegression(simclr_model.n_features, n_classes)
    model = model.to(args.device)

    # Fixed learning rate (not taken from args).
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
Beispiel #4
0
def main(gpu, args):
    """Pre-train a SimCLR model on STL10 or CIFAR10 with the NT-Xent loss.

    Runs once per process. When ``args.nodes > 1`` a NCCL process group is
    initialised, the data is sharded with a ``DistributedSampler`` and the
    model is wrapped in DDP (unless ``args.dataparallel`` selects
    ``DataParallel`` instead).

    Args:
        gpu: Local GPU index of this process on its node.
        args: Parsed configuration. Fields read here include ``nr``, ``gpus``,
            ``nodes``, ``world_size``, ``seed``, ``dataset``, ``dataset_dir``,
            ``image_size``, ``batch_size``, ``workers``, ``resnet``,
            ``projection_dim``, ``reload``, ``model_path``, ``epoch_num``,
            ``device``, ``temperature``, ``dataparallel``, ``start_epoch`` and
            ``epochs``. ``args.global_step`` and ``args.current_epoch`` are
            written.

    Raises:
        NotImplementedError: If ``args.dataset`` is neither "STL10" nor
            "CIFAR10".
    """
    # Global rank: node index * GPUs per node + local GPU index.
    rank = args.nr * args.gpus + gpu

    if args.nodes > 1:
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    if args.nodes > 1:
        # Each process iterates over its own shard of the dataset.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=True)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader does not allow shuffle=True together with a sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(encoder, args.projection_dim, n_features)
    if args.reload:
        model_fp = os.path.join(args.model_path,
                                "checkpoint_{}.tar".format(args.epoch_num))
        model.load_state_dict(
            torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.world_size)

    # DDP / DP
    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)
    else:
        if args.nodes > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[gpu])

    model = model.to(args.device)

    # Only node 0 logs to TensorBoard; elsewhere writer stays None.
    writer = None
    if args.nr == 0:
        writer = SummaryWriter()

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        # DistributedSampler seeds its shuffle with the epoch number; without
        # set_epoch() every epoch would reuse the same ordering.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        lr = optimizer.param_groups[0]["lr"]
        loss_epoch = train(args, train_loader, model, criterion, optimizer,
                           writer)

        # Step the LR schedule in every process. Each DDP replica owns its own
        # optimizer, so stepping only on node 0 would leave the other nodes at
        # a different learning rate and let the replicas drift apart.
        if scheduler:
            scheduler.step()

        if args.nr == 0 and epoch % 10 == 0:
            save_model(args, model, optimizer)

        if args.nr == 0:
            avg_loss = loss_epoch / len(train_loader)
            writer.add_scalar("Loss/train", avg_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {avg_loss}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

    ## end training
    # Same guard as the periodic saves, so processes do not race on the file.
    if args.nr == 0:
        save_model(args, model, optimizer)
    # Flush pending TensorBoard events.
    if writer is not None:
        writer.close()
def main(gpu, args):
    """Pre-train a monochrome SimCLR model on patched scan data.

    The validation loss is measured before training and after every epoch.
    The loss trace is plotted once training ends.

    Args:
        gpu: Local GPU index. It is not used by this single-process variant
            and is kept only for signature compatibility.
        args: Parsed configuration. Fields read here include ``seed``,
            ``datasets_root``, ``scans_per_batch``, ``slices_per_scan``,
            ``single_scan_loss``, ``resnet``, ``reload``, ``start_epoch``,
            ``epoch_num``, ``model_path``, ``device``, ``temperature``,
            ``world_size``, ``dataparallel``, ``nr``, ``log_dir`` and
            ``epochs``. ``args.global_step`` and ``args.current_epoch`` are
            written.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # NOTE(review): ``tfms`` is consumed only by the commented-out
    # RandomSlicerDataset below; the active PatchedDataset ignores it. The
    # 'GeForce MX130' variant also omits ToColorTensor(); confirm that is
    # intended before re-enabling that dataset.
    tfms = A.Compose(
        [A.Resize(512, 512),
         A.RandomCrop(384, 384),
         ToColorTensor()])
    if torch.cuda.device_count() == 1:
        if torch.cuda.get_device_name() == 'GeForce MX130':
            tfms = A.Compose([A.Resize(256, 256), A.RandomCrop(192, 192)])

    # dataset = RandomSlicerDataset(
    #     args.datasets_root, img_tfm(tfms),
    #     args.slices_per_scan, args.inter_slice_distance
    # )
    dataset = PatchedDataset(args.datasets_root, target_size=(128, 128))

    # The same loader serves training and validation; the dataset's mode is
    # switched around each validation pass.
    loader = DataLoader(dataset, batch_size=args.scans_per_batch, shuffle=True)

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    # Replace the stem convolution with a 1-input-channel one so the network
    # accepts monochrome images, and re-initialise it (Kaiming normal).
    encoder.conv1 = Conv2d(1,
                           encoder.conv1.out_channels,
                           kernel_size=7,
                           stride=2,
                           padding=3,
                           bias=False)
    kaiming_normal_(encoder.conv1.weight, mode='fan_out', nonlinearity='relu')
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload or args.start_epoch:
        # A non-zero start_epoch takes precedence over args.epoch_num.
        epoch_n = args.start_epoch if args.start_epoch else args.epoch_num
        model_fp = os.path.join(args.model_path, f'checkpoint_{epoch_n}.tar')
        model.load_state_dict(
            torch.load(model_fp, map_location=args.device.type))
        print(f'Loaded from epoch #{epoch_n}')

    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    # The NT-Xent batch size is slices_per_scan when single_scan_loss is set,
    # and scans_per_batch * slices_per_scan otherwise.
    criterion = NT_Xent(
        args.slices_per_scan if args.single_scan_loss else
        args.scans_per_batch * args.slices_per_scan, args.temperature,
        args.device, args.world_size)

    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)

    model = model.to(args.device)

    # Only node 0 logs; each run gets a timestamped TensorBoard directory.
    writer = None
    if args.nr == 0:
        run_log = os.path.join(args.log_dir,
                               datetime.now().strftime('%Y%m%d-%H%M%S'))
        writer = SummaryWriter(run_log)

    args.global_step = 0
    args.current_epoch = 0

    # Baseline validation loss before any training.
    dataset.set_tst_mode()
    validation_loss_trc = [validate(loader, model, visualize=True)]
    dataset.set_trn_mode()
    print('validation loss before train:', validation_loss_trc[0])

    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]["lr"]
        t_start = datetime.now()
        loss_epoch = train(args, loader, model, criterion, optimizer, writer)
        train_time = datetime.now() - t_start

        # Validate after each epoch; visualise only on the last one.
        dataset.set_tst_mode()
        validation_loss_trc.append(
            validate(loader, model, epoch == args.epochs - 1))
        dataset.set_trn_mode()

        # The LR schedule must advance regardless of which node this is.
        if scheduler:
            scheduler.step()

        if args.nr == 0 and epoch % 10 == 0:
            save_model(args, model, optimizer)

        if args.nr == 0:
            avg_loss = loss_epoch / len(loader)
            writer.add_scalar("Loss/train", avg_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {avg_loss}\t lr: {round(lr, 5)}"
            )
            print('validation loss:', validation_loss_trc[-1],
                  '\t\ttrain_time:', train_time)
            # Short pause, presumably to let console/progress output settle.
            sleep(0.2)
            args.current_epoch += 1

    ## end training
    # Same guard as the periodic saves.
    if args.nr == 0:
        save_model(args, model, optimizer)
    # Flush TensorBoard events before the blocking plot window below.
    if writer is not None:
        writer.close()

    plt.plot(validation_loss_trc)
    plt.grid(True)
    plt.title('validation loss trace')
    plt.show()
Beispiel #6
0
def main(gpu, args):
    """Pre-train a SimCLR model on STL10 or CIFAR10 in a single process.

    Args:
        gpu: Local GPU index. It is not used by this variant and is kept only
            for signature compatibility.
        args: Parsed configuration. Fields read here include ``seed``,
            ``dataset``, ``dataset_dir``, ``image_size``, ``batch_size``,
            ``workers``, ``resnet``, ``reload``, ``start_epoch``,
            ``epoch_num``, ``model_path``, ``device``, ``temperature``,
            ``world_size``, ``dataparallel``, ``nr``, ``log_dir`` and
            ``epochs``. ``args.global_step`` and ``args.current_epoch`` are
            written.

    Raises:
        NotImplementedError: If ``args.dataset`` is neither "STL10" nor
            "CIFAR10".
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=args.workers)

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload or args.start_epoch:
        # A non-zero start_epoch takes precedence over args.epoch_num.
        epoch_n = args.start_epoch if args.start_epoch else args.epoch_num
        model_fp = os.path.join(args.model_path, f'checkpoint_{epoch_n}.tar')
        model.load_state_dict(
            torch.load(model_fp, map_location=args.device.type))
        print(f'Loaded from epoch #{epoch_n}')
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.device,
                        args.world_size)

    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)

    model = model.to(args.device)

    # Only node 0 logs; each run gets a timestamped TensorBoard directory.
    writer = None
    if args.nr == 0:
        run_log = os.path.join(args.log_dir,
                               datetime.now().strftime('%Y%m%d-%H%M%S'))
        writer = SummaryWriter(run_log)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]["lr"]
        loss_epoch = train(args, train_loader, model, criterion, optimizer,
                           writer)

        # The LR schedule must advance regardless of which node this is.
        if scheduler:
            scheduler.step()

        if args.nr == 0 and epoch % 10 == 0:
            save_model(args, model, optimizer)

        if args.nr == 0:
            avg_loss = loss_epoch / len(train_loader)
            writer.add_scalar("Loss/train", avg_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {avg_loss}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

    ## end training
    # Same guard as the periodic saves.
    if args.nr == 0:
        save_model(args, model, optimizer)
    # Flush pending TensorBoard events.
    if writer is not None:
        writer.close()