# Imports assumed by all of the examples below; project-specific helpers
# (post_config_hook, load_model, train, test, save_model, TransformsSimCLR,
# NT_Xent, mask_correlated_samples, LogisticRegression, get_features, the
# *Dataset classes, etc.) come from the surrounding repository.
import argparse
import os

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter


def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"

    train_sampler = None

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root, split="unlabeled", download=True, transform=TransformsSimCLR()
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root, download=True, transform=TransformsSimCLR()
        )
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    model, optimizer, scheduler = load_model(args, train_loader)

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)

    mask = mask_correlated_samples(args)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]['lr']
        loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

        if scheduler:
            scheduler.step()

        if epoch % 10 == 0:
            save_model(args, model, optimizer)

        writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
        writer.add_scalar("Misc/learning_rate", lr, epoch)
        print(
            f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
        )
        args.current_epoch += 1

    ## end training
    save_model(args, model, optimizer)
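The `mask_correlated_samples` and `NT_Xent` helpers used above are defined elsewhere in the repository and are not shown on this page. A minimal sketch of the idea, assuming a batch of 2N projections in which samples i and i+N are the two augmented views of the same image (the repository's actual implementation may differ in details):

import torch
import torch.nn as nn
import torch.nn.functional as F

def mask_correlated_samples(args):
    # Boolean (2N x 2N) mask: False on the diagonal (self-similarity) and on
    # the two positive pairs (i, i+N) / (i+N, i); True for every negative.
    batch_size = args.batch_size
    N = 2 * batch_size
    mask = torch.ones((N, N), dtype=torch.bool)
    mask.fill_diagonal_(False)
    for i in range(batch_size):
        mask[i, batch_size + i] = False
        mask[batch_size + i, i] = False
    return mask

class NT_Xent(nn.Module):
    # Normalized temperature-scaled cross-entropy loss from the SimCLR paper.
    def __init__(self, batch_size, temperature, mask, device):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = mask.to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, z_i, z_j):
        # z_i, z_j: projections of the two views, each of shape (batch_size, dim).
        N = 2 * self.batch_size
        z = torch.cat((z_i, z_j), dim=0)
        sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)
        positives = torch.cat((sim_i_j, sim_j_i)).reshape(N, 1)
        negatives = sim[self.mask].reshape(N, -1)
        # Each row is [positive, negatives...]; the correct "class" is index 0.
        logits = torch.cat((positives, negatives), dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=self.device)
        return self.criterion(logits, labels) / N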
Example #2
def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    simclr_model = load_model(args, reload_model=True)
    simclr_model = simclr_model.to(args.device)
    simclr_model.eval()

    ## Logistic Regression
    n_classes = 10  # stl-10
    model = LogisticRegression(simclr_model.n_features, n_classes)
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    train_dataset = torchvision.datasets.STL10(
        root,
        split="train",
        download=True,
        transform=torchvision.transforms.ToTensor())

    test_dataset = torchvision.datasets.STL10(
        root,
        split="test",
        download=True,
        transform=torchvision.transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.logistic_batch_size,
        drop_last=True,
        num_workers=args.workers,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        drop_last=True,
        num_workers=args.workers,
    )

    for epoch in range(args.logistic_epochs):
        loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model,
                                           model, criterion, optimizer)
        print(
            f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}"
        )

    # final testing
    loss_epoch, accuracy_epoch = test(args, test_loader, simclr_model, model,
                                      criterion, optimizer)
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}"
    )
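`LogisticRegression` is not shown in these snippets; for linear evaluation it is typically nothing more than a single linear layer over the frozen encoder features (softmax is applied implicitly by `CrossEntropyLoss`). A minimal sketch under that assumption:

import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, n_features, n_classes):
        super().__init__()
        # One linear map from encoder features to class logits.
        self.model = nn.Linear(n_features, n_classes)

    def forward(self, x):
        return self.model(x)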
Example #3
def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    model = load_model(args)
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS

    train_sampler = None
    train_dataset = torchvision.datasets.STL10(root,
                                               split="unlabeled",
                                               download=True,
                                               transform=TransformsSimCLR())

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)

    mask = mask_correlated_samples(args)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        loss_epoch = train(args, train_loader, model, criterion, optimizer,
                           writer)

        writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
        if epoch % 10 == 0:
            save_model(args, model, optimizer)

        print(
            f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}"
        )
        args.current_epoch += 1

    ## end training
    save_model(args, model, optimizer)
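The `train` function these pretraining examples call is likewise defined elsewhere. A minimal sketch of one epoch, assuming `TransformsSimCLR` makes the loader yield a pair of augmented views per image and that the model returns encoder features and projections for both views (that output signature is an assumption, not shown on this page):

def train(args, train_loader, model, criterion, optimizer, writer):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        x_i = x_i.to(args.device)
        x_j = x_j.to(args.device)

        # Forward both augmented views; the contrastive loss acts on the projections.
        h_i, h_j, z_i, z_j = model(x_i, x_j)
        loss = criterion(z_i, z_j)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar("Loss/train_step", loss.item(), args.global_step)
        args.global_step += 1
        loss_epoch += loss.item()
    return loss_epoch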
Example #4
def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    train_sampler = None
    valid_sampler = None

    if args.dataset == "STL10":
        dataset = torchvision.datasets.STL10(
            root,
            split="train",
            download=True,
            transform=TransformsSimCLR(size=224).test_transform,
        )
        test_dataset = torchvision.datasets.STL10(
            root,
            split="test",
            download=True,
            transform=TransformsSimCLR(size=224).test_transform,
        )
    elif args.dataset == "CIFAR10":
        dataset = torchvision.datasets.CIFAR10(
            root,
            train=True,
            download=True,
            transform=TransformsSimCLR(size=224).test_transform,
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root,
            train=False,
            download=True,
            transform=TransformsSimCLR(size=224).test_transform,
        )
    elif args.dataset == "MATEK":
        dataset, train_sampler, valid_sampler = MatekDataset(
            root=root,
            transforms=TransformsSimCLR(size=128).test_transform,
            test_size=args.test_size).get_dataset()
        test_dataset = dataset  # train/valid samplers split this single dataset
    elif args.dataset == "JURKAT":
        dataset, train_sampler, valid_sampler = JurkatDataset(
            root=root,
            transforms=TransformsSimCLR(size=64).test_transform,
            test_size=args.test_size).get_dataset()
        test_dataset = dataset
    elif args.dataset == "PLASMODIUM":
        dataset, train_sampler, valid_sampler = PlasmodiumDataset(
            root=root,
            transforms=TransformsSimCLR(size=128).test_transform,
            test_size=args.test_size).get_dataset()
        test_dataset = dataset
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.logistic_batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        shuffle=(valid_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=valid_sampler,
    )

    simclr_model, _, _ = load_model(args, train_loader, reload_model=True)
    simclr_model = simclr_model.to(args.device)
    simclr_model.eval()

    # Logistic Regression
    n_classes = args.n_classes  # depends on the selected dataset
    model = LogisticRegression(simclr_model.n_features, n_classes)
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print("### Creating features from pre-trained context model ###")
    (train_X, train_y, test_X, test_y) = get_features(simclr_model,
                                                      train_loader,
                                                      test_loader, args.device)

    arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
        train_X, train_y, test_X, test_y, args.logistic_batch_size)

    for epoch in range(args.logistic_epochs):
        loss_epoch, accuracy_epoch = train(args, arr_train_loader,
                                           simclr_model, model, criterion,
                                           optimizer)
        print(
            f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(arr_train_loader)}\t Accuracy: {accuracy_epoch / len(arr_train_loader)}"
        )

    # final testing
    loss_epoch, accuracy_epoch, report = test(args, arr_test_loader,
                                              simclr_model, model, criterion,
                                              optimizer)
    print(
        f"[FINAL]\t Loss: {loss_epoch / len(arr_test_loader)}\t Accuracy: {accuracy_epoch / len(arr_test_loader)}"
    )

    print(report)
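`get_features` and `create_data_loaders_from_arrays` precompute the frozen encoder's outputs once, so the logistic head then trains on plain tensors instead of re-encoding images every epoch. A minimal sketch, assuming the same `(h_i, h_j, z_i, z_j)` model output as above:

import torch

def get_features(simclr_model, train_loader, test_loader, device):
    # Run the frozen encoder over each split and collect (features, labels).
    def inference(loader):
        feats, labels = [], []
        with torch.no_grad():
            for x, y in loader:
                h, _, _, _ = simclr_model(x.to(device), x.to(device))
                feats.append(h.cpu())
                labels.append(y)
        return torch.cat(feats), torch.cat(labels)

    train_X, train_y = inference(train_loader)
    test_X, test_y = inference(test_loader)
    return train_X, train_y, test_X, test_y

def create_data_loaders_from_arrays(train_X, train_y, test_X, test_y, batch_size):
    # Wrap the precomputed feature tensors in TensorDatasets.
    train = torch.utils.data.TensorDataset(train_X, train_y)
    test = torch.utils.data.TensorDataset(test_X, test_y)
    return (
        torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True),
        torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False),
    )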
Example #5
def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    train_sampler = None

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root=args.dataset_root,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=96))
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root=args.dataset_root,
            download=True,
            transform=TransformsSimCLR(size=32))
    elif args.dataset == "MATEK":
        train_dataset, _ = MatekDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=128)).get_dataset()
    elif args.dataset == "JURKAT":
        train_dataset, _ = JurkatDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=64)).get_dataset()
    elif args.dataset == "PLASMODIUM":
        train_dataset, _ = PlasmodiumDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=128)).get_dataset()
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    model, optimizer, scheduler = load_model(args, train_loader)

    print(f"Using {args.n_gpu}'s")
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
        model = convert_model(model)
        model = model.to(args.device)

    print(model)

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tb_dir)

    criterion = NT_Xent(args.batch_size, args.temperature, args.device)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]['lr']
        loss_epoch = train(args, train_loader, model, criterion, optimizer,
                           writer)

        if scheduler:
            scheduler.step()

        if epoch % 10 == 0:
            save_model(args, model, optimizer)

        writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
        writer.add_scalar("Misc/learning_rate", lr, epoch)
        print(
            f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
        )
        args.current_epoch += 1

    save_model(args, model, optimizer)
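All five examples follow the Sacred convention: `main(_run, _log)` receives the run object and a logger injected by the framework, and `_run.config` carries the merged configuration. A minimal sketch of how such a main is typically wired up (the experiment name, config path, and log directory here are illustrative, not taken from the repository):

from sacred import Experiment
from sacred.observers import FileStorageObserver

ex = Experiment("SimCLR")
ex.add_config("config/config.yaml")            # hypothetical config file
ex.observers.append(FileStorageObserver("./logs"))

@ex.automain
def main(_run, _log):
    _log.info("config: %s", _run.config)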