def main(_run, _log):
    """Run contrastive (NT-Xent) pre-training for one experiment run.

    Builds the training set selected by ``args.dataset`` (the STL10
    ``"unlabeled"`` split or CIFAR10, both downloaded under ``./datasets``),
    obtains model/optimizer/scheduler from ``load_model`` and trains for epochs
    ``args.start_epoch`` .. ``args.epochs - 1``. The per-epoch loss and learning
    rate go to TensorBoard under ``args.out_dir/<experiment name>``. A checkpoint
    is saved every 10 epochs and once more after the loop.

    Args:
        _run: experiment run object (looks like a Sacred ``Run``); its
            ``config`` becomes ``args`` and ``experiment_info["name"]`` names
            the TensorBoard directory.
        _log: logger supplied by the caller; not used here.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither "STL10" nor "CIFAR10".
        FileExistsError: if the TensorBoard directory already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"

    train_sampler = None

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root, split="unlabeled", download=True, transform=TransformsSimCLR()
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root, download=True, transform=TransformsSimCLR()
        )
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    model, optimizer, scheduler = load_model(args, train_loader)

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir)
    writer = SummaryWriter(log_dir=tb_dir)

    mask = mask_correlated_samples(args)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    args.global_step = 0
    # NOTE(review): counts epochs run by this call and starts at 0 even when
    # args.start_epoch > 0 -- confirm save_model/train expect that on resume.
    args.current_epoch = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            # Read before scheduler.step() so the logged lr is the one used
            # during this epoch.
            lr = optimizer.param_groups[0]['lr']
            loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

            if scheduler:
                scheduler.step()

            if epoch % 10 == 0:
                save_model(args, model, optimizer)

            # Presumably the mean per-batch loss, if train returns a sum.
            mean_loss = loss_epoch / len(train_loader)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

        ## end training
        save_model(args, model, optimizer)
    finally:
        # Flush queued TensorBoard events; without close() the final
        # scalars can be lost when the process exits or training fails.
        writer.close()
Ejemplo n.º 2
0
def main(_run, _log):
    """Run contrastive (NT-Xent) pre-training on the STL10 unlabeled split.

    Always uses ``torchvision.datasets.STL10(split="unlabeled")`` under
    ``./datasets`` regardless of any dataset setting. The model comes from
    ``load_model(args)`` and is trained with Adam at a fixed lr of 3e-4.
    Training runs for epochs ``args.start_epoch`` .. ``args.epochs - 1``. The
    per-epoch loss goes to TensorBoard under ``args.out_dir/<experiment name>``.
    A checkpoint is saved every 10 epochs and once more after the loop.

    Args:
        _run: experiment run object (looks like a Sacred ``Run``); its
            ``config`` becomes ``args`` and ``experiment_info["name"]`` names
            the TensorBoard directory.
        _log: logger supplied by the caller; not used here.

    Raises:
        FileExistsError: if the TensorBoard directory already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    model = load_model(args)
    model = model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS

    train_sampler = None
    train_dataset = torchvision.datasets.STL10(root,
                                               split="unlabeled",
                                               download=True,
                                               transform=TransformsSimCLR())

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir)
    writer = SummaryWriter(log_dir=tb_dir)

    mask = mask_correlated_samples(args)
    criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

    args.global_step = 0
    args.current_epoch = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            loss_epoch = train(args, train_loader, model, criterion, optimizer,
                               writer)

            # Presumably the mean per-batch loss, if train returns a sum.
            mean_loss = loss_epoch / len(train_loader)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            if epoch % 10 == 0:
                save_model(args, model, optimizer)

            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}"
            )
            args.current_epoch += 1

        ## end training
        save_model(args, model, optimizer)
    finally:
        # Flush queued TensorBoard events; without close() the final
        # scalars can be lost when the process exits or training fails.
        writer.close()
Ejemplo n.º 3
0
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

model, optimizer, scheduler = load_model(args, train_loader)

mask = mask_correlated_samples(args)
criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

args.global_step = 0
args.current_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]['lr']
    loss_epoch = train(args, train_loader, model, criterion, optimizer)

    if scheduler:
        scheduler.step()

    if epoch % 1 == 0:
        save_model(args, model, optimizer)

    #writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
    #writer.add_scalar("Misc/learning_rate", lr, epoch)
Ejemplo n.º 4
0
def main(_run, _log):
    """Run contrastive (NT-Xent) pre-training on one of several datasets.

    ``args.dataset`` selects the training set and SimCLR crop size:
    STL10 unlabeled (96), CIFAR10 (32), MATEK (128), JURKAT (64) or
    PLASMODIUM (128), all rooted at ``args.dataset_root``. The model,
    optimizer and scheduler come from ``load_model``. When more than one
    CUDA device is visible, the model is wrapped in ``DataParallel`` and
    passed through ``convert_model``. Training runs for epochs
    ``args.start_epoch`` .. ``args.epochs - 1``, logging loss and lr to
    TensorBoard under ``args.out_dir/<experiment name>``. A checkpoint is
    saved every 10 epochs and once more after the loop.

    Args:
        _run: experiment run object (looks like a Sacred ``Run``); its
            ``config`` becomes ``args`` and ``experiment_info["name"]`` names
            the TensorBoard directory.
        _log: logger supplied by the caller; not used here.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the names above.
        FileExistsError: if the TensorBoard directory already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    train_sampler = None

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            root=args.dataset_root,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=96))
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root=args.dataset_root,
            download=True,
            transform=TransformsSimCLR(size=32))
    elif args.dataset == "MATEK":
        train_dataset, _ = MatekDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=128)).get_dataset()
    elif args.dataset == "JURKAT":
        train_dataset, _ = JurkatDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=64)).get_dataset()
    elif args.dataset == "PLASMODIUM":
        train_dataset, _ = PlasmodiumDataset(
            root=args.dataset_root,
            transforms=TransformsSimCLR(size=128)).get_dataset()
    else:
        raise NotImplementedError

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    model, optimizer, scheduler = load_model(args, train_loader)

    print(f"Using {args.n_gpu} GPU(s)")
    if args.n_gpu > 1:
        # NOTE(review): convert_model is applied to the DataParallel wrapper,
        # not the bare model -- confirm the helper accepts a wrapped module.
        model = torch.nn.DataParallel(model)
        model = convert_model(model)
        model = model.to(args.device)

    print(model)

    tb_dir = os.path.join(args.out_dir, _run.experiment_info["name"])
    os.makedirs(tb_dir)
    writer = SummaryWriter(log_dir=tb_dir)

    criterion = NT_Xent(args.batch_size, args.temperature, args.device)

    args.global_step = 0
    args.current_epoch = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            # Read before scheduler.step() so the logged lr is the one used
            # during this epoch.
            lr = optimizer.param_groups[0]['lr']
            loss_epoch = train(args, train_loader, model, criterion, optimizer,
                               writer)

            if scheduler:
                scheduler.step()

            if epoch % 10 == 0:
                save_model(args, model, optimizer)

            # Presumably the mean per-batch loss, if train returns a sum.
            mean_loss = loss_epoch / len(train_loader)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

        save_model(args, model, optimizer)
    finally:
        # Flush queued TensorBoard events; without close() the final
        # scalars can be lost when the process exits or training fails.
        writer.close()
Ejemplo n.º 5
0
def main(gpu, args):
    """Train SimCLR in one process of a (possibly multi-node) job.

    The global rank is ``args.nr * args.gpus + gpu``. When ``args.nodes > 1``
    the function:

    - joins an NCCL process group and pins this process to ``gpu``;
    - shards data with a ``DistributedSampler``;
    - (unless ``args.dataparallel``) converts BatchNorm to SyncBatchNorm and
      wraps the model in DDP.

    With ``args.dataparallel`` the model goes through ``convert_model`` and
    ``DataParallel`` instead. torch and numpy are seeded with ``args.seed``.
    If ``args.reload`` is set, weights are loaded from
    ``<args.model_path>/checkpoint_<args.epoch_num>.tar``. Only processes with
    ``args.nr == 0`` write TensorBoard logs and checkpoints.

    Args:
        gpu: local GPU index of this process on its node.
        args: configuration namespace (device, dataset, batch_size, epochs, ...).

    Raises:
        NotImplementedError: if ``args.dataset`` is neither "STL10" nor "CIFAR10".
    """
    rank = args.nr * args.gpus + gpu

    if args.nodes > 1:
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    if args.nodes > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=True)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload:
        model_fp = os.path.join(args.model_path,
                                "checkpoint_{}.tar".format(args.epoch_num))
        model.load_state_dict(
            torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.device,
                        args.world_size)

    # DDP / DP
    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)
    else:
        if args.nodes > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[gpu])

    model = model.to(args.device)

    # Only node 0 logs and checkpoints. NOTE(review): every process on
    # node 0 has nr == 0, so with several GPUs per node more than one
    # process writes -- confirm whether rank == 0 was intended.
    is_logger = args.nr == 0
    writer = SummaryWriter() if is_logger else None

    args.global_step = 0
    args.current_epoch = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            if train_sampler is not None:
                # DistributedSampler derives its shuffle from the epoch;
                # without this every epoch iterates in the same order.
                train_sampler.set_epoch(epoch)

            # Read before scheduler.step() so the logged lr is the one used
            # during this epoch.
            lr = optimizer.param_groups[0]["lr"]
            loss_epoch = train(args, train_loader, model, criterion, optimizer,
                               writer)

            # Step on every process so all replicas keep the same lr;
            # stepping only on node 0 lets the other nodes' lr diverge.
            if scheduler:
                scheduler.step()

            if is_logger and epoch % 10 == 0:
                save_model(args, model, optimizer)

            if is_logger:
                # Presumably the mean per-batch loss, if train returns a sum.
                mean_loss = loss_epoch / len(train_loader)
                writer.add_scalar("Loss/train", mean_loss, epoch)
                writer.add_scalar("Misc/learning_rate", lr, epoch)
                print(
                    f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 5)}"
                )
            # Track the epoch counter on every process, not only node 0.
            args.current_epoch += 1

        ## end training
        # Same guard as the periodic saves: avoid every node writing the
        # same checkpoint concurrently.
        if is_logger:
            save_model(args, model, optimizer)
    finally:
        if writer is not None:
            # Flush queued TensorBoard events before exit.
            writer.close()
Ejemplo n.º 6
0
def main(gpu, args):
    """Train SimCLR together with an image decoder in one process of a
    (possibly multi-node) job.

    The setup matches plain SimCLR training:

    - global rank ``args.nr * args.gpus + gpu``;
    - seeding with ``args.seed``;
    - a STL10 / CIFAR10 dataset;
    - a ``DistributedSampler`` plus SyncBatchNorm/DDP when ``args.nodes > 1``,
      or ``convert_model`` + ``DataParallel`` when ``args.dataparallel``;
    - an optional reload from
      ``<args.model_path>/checkpoint_<args.epoch_num>.tar``.

    In addition, a ``Decoder(3, 3, args.image_size)`` is trained with its own
    Adam optimizer (lr=0.001). Each epoch runs ``train`` (returning contrastive
    loss, decoder loss and penalty sums) and then ``train_autoencoder`` with
    ``freeze_encoder=True``. Every 5 epochs, processes with ``args.nr == 0``
    save the model checkpoint and ``decoder<epoch>.pt``. Every 10 epochs the
    decoder and its optimizer are re-created from scratch, after that save.

    Args:
        gpu: local GPU index of this process on its node.
        args: configuration namespace (device, dataset, image_size, epochs, ...).

    Raises:
        NotImplementedError: if ``args.dataset`` is neither "STL10" nor "CIFAR10".
    """
    rank = args.nr * args.gpus + gpu

    if args.nodes > 1:
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    if args.nodes > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=True)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader forbids shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload:
        model_fp = os.path.join(args.model_path,
                                "checkpoint_{}.tar".format(args.epoch_num))
        print(model_fp)
        model.load_state_dict(
            torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.device,
                        args.world_size)

    # DDP / DP
    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)
    else:
        if args.nodes > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[gpu])

    model = model.to(args.device)

    # Only node 0 logs and checkpoints. NOTE(review): every process on
    # node 0 has nr == 0, so with several GPUs per node more than one
    # process writes -- confirm whether rank == 0 was intended.
    is_logger = args.nr == 0
    writer = SummaryWriter() if is_logger else None

    # added by @IvanKruzhilov
    decoder = Decoder(3, 3, args.image_size)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=0.001)
    #decoder.load_state_dict(torch.load('save/decoder_my_algorithm_augmented.pt'))
    decoder = decoder.to(args.device)

    # Constant per-epoch arguments for train(), hoisted out of the loop.
    scatter_radius = 0.2
    random_fake = None  # set in the train function now

    args.global_step = 0
    args.current_epoch = 0
    try:
        for epoch in range(args.start_epoch, args.epochs):
            if train_sampler is not None:
                # DistributedSampler derives its shuffle from the epoch;
                # without this every epoch iterates in the same order.
                train_sampler.set_epoch(epoch)

            # Read before scheduler.step() so the logged lr is the one used
            # during this epoch.
            lr = optimizer.param_groups[0]["lr"]

            loss_epoch, loss_epoch_decoder, penalty_epoch = train(
                args, train_loader, model, decoder, criterion, optimizer,
                optimizer_decoder, writer, random_fake, scatter_radius)

            loss_mean, bce_mean = train_autoencoder(
                model, decoder, train_loader, None, optimizer_decoder,
                freeze_encoder=True)

            # Step on every process so all replicas keep the same lr;
            # stepping only on node 0 lets the other nodes' lr diverge.
            if scheduler:
                scheduler.step()

            if is_logger and epoch % 5 == 0:
                save_model(args, model, optimizer)
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path, 'decoder{0}.pt'.format(epoch)))

            # Re-create the decoder and its optimizer every 10 epochs; this
            # runs after the save above, so the saved decoder is the old one.
            if epoch % 10 == 0:
                decoder = Decoder(3, 3, args.image_size)
                optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                                     lr=0.001)
                decoder = decoder.to(args.device)

            if is_logger:
                # Presumably per-batch means, if train returns sums.
                mean_loss = loss_epoch / len(train_loader)
                mean_loss_decoder = loss_epoch_decoder / len(train_loader)
                mean_penalty = penalty_epoch / len(train_loader)
                writer.add_scalar("Loss/train", mean_loss, epoch)
                writer.add_scalar("Misc/learning_rate", lr, epoch)
                # Adjacent f-strings instead of a backslash continuation
                # inside the literal, which embedded stray indentation.
                print(
                    f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t "
                    f"decoder loss: {mean_loss_decoder}\t "
                    f"penalty: {mean_penalty}\t lr: {round(lr, 5)}")
                # NOTE(review): bce_mean is printed under the label 'mse' --
                # confirm which loss train_autoencoder actually returns.
                print('loss: ', loss_mean, 'mse: ', bce_mean)
            # Track the epoch counter on every process, not only node 0.
            args.current_epoch += 1

        ## end training
        # Same guard as the periodic saves: avoid every node writing the
        # same checkpoint concurrently.
        if is_logger:
            save_model(args, model, optimizer)
    finally:
        if writer is not None:
            # Flush queued TensorBoard events before exit.
            writer.close()
Ejemplo n.º 7
0
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.model = SimCLR(hparams.projection_dim)
        self.criterion = NT_Xent(hparams.batch_size, hparams.temperature)