batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 drop_last=True)

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        #        encoder = BERTEncoder()
        encoder = BERTEncoder(opt.embed_dim)
        onehot_encoder = OneHotClassEmbedding(train_dataset.n_classes)
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    model_G = CDCGAN_G(z_dim=opt.z_dim,
                       embed_dim=opt.embed_dim,
                       n_filters=opt.n_filters)
    model_D = CDCGAN_D(embed_dim=opt.embed_dim, n_filters=opt.n_filters)
    #    model_G.weight_init(mean=0.0, std=0.02)
    #    model_D.weight_init(mean=0.0, std=0.02)
Exemplo n.º 2
0
def main(args=None):
    """Build the data pipeline and a FlowPlusPlus model, then train or evaluate.

    The function parses the command-line options, creates the train and
    validation datasets/dataloaders, instantiates the class embedder selected
    by ``opt.conditioning`` and a ``FlowPlusPlus`` model conditioned on that
    embedder's ``embed_size``, and finally either calls ``train`` (when
    ``opt.train`` is truthy) or loads a checkpoint, samples images and calls
    ``eval``.

    Args:
        args: Optional sequence of command-line tokens forwarded to
            ``parser.parse_args``. When falsy (``None`` or empty),
            ``parser.parse_args()`` is called without arguments.

    Raises:
        AssertionError: If ``opt.dataset`` is neither ``"imagenet32"`` nor
            ``"cifar10"``, if ``opt.conditioning`` is not one of
            ``'unconditional'``, ``"bert"`` or ``"one-hot"``, or if evaluation
            is requested without ``opt.model_checkpoint``.
    """
    if args:
        opt = parser.parse_args(args)
    else:
        opt = parser.parse_args()

    print(opt)

    print("loading dataset")
    # In debug mode both datasets are capped to a single sample (max_size=1).
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32DatasetDiscrete(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else opt.train_size)
        val_dataset = Imagenet32DatasetDiscrete(
            train=0,
            max_size=1 if opt.debug else opt.val_size,
            start_idx=opt.val_start_idx)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    # CUDA is used only when it is both available and requested.
    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    generative_model = FlowPlusPlus(
        scales=[(0, 4), (2, 3)],
        in_shape=train_dataset.image_shape,
        mid_channels=opt.n_filters,
        num_blocks=opt.num_blocks,
        num_dequant_blocks=opt.num_dequant_blocks,
        num_components=opt.num_components,
        use_attn=opt.use_attn,
        drop_prob=opt.drop_prob,
        condition_embd_size=encoder.embed_size)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    # Configure data loader

    print("dataloaders loaded")
    # Optimizers
    param_groups = util.get_param_groups(generative_model,
                                         opt.lr_decay,
                                         norm_suffix='weight_g')
    optimizer = torch.optim.Adam(param_groups, lr=opt.lr)
    warm_up = opt.warm_up * opt.batch_size

    def warm_up_factor(step):
        # Linear ramp of the LR multiplier from 0 to 1 over `warm_up`
        # scheduler steps. A non-positive warm-up means "no ramp": without
        # this guard the division below raises ZeroDivisionError for 0 and
        # produces negative multipliers for negative values.
        if warm_up <= 0:
            return 1.
        return min(1., step / warm_up)

    scheduler = lr_scheduler.LambdaLR(optimizer, warm_up_factor)
    # create output directory

    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"))

    # Reset the module-level step counter
    # (presumably read/updated by train/eval -- confirm against their code).
    global global_step
    global_step = 0

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
    else:
        # Evaluation-only path: a trained checkpoint is mandatory.
        assert opt.model_checkpoint is not None, 'no model checkpoint specified'
        print("Loading model from state dict...")
        load_model(opt.model_checkpoint, generative_model)
        print("Model loaded.")
        sample_images_full(generative_model,
                           encoder,
                           opt.output_dir,
                           dataloader=val_dataloader,
                           device=device)
        eval(model=generative_model,
             embedder=encoder,
             test_loader=val_dataloader,
             opt=opt,
             writer=writer,
             device=device)
                                                 batch_size=opt.batch_size,
                                                 shuffle=True,
                                                 drop_last=True)

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    model_G = CDCGAN_G(z_dim=opt.z_dim, embed_dim=768, n_filters=opt.n_filters)
    model_D = CDCGAN_D(n_filters=opt.n_filters, embed_dim=768)
    model_G.weight_init(mean=0.0, std=0.02)
    model_D.weight_init(mean=0.0, std=0.02)

    if cuda:
        model_G = model_G.cuda()
        model_D = model_D.cuda()
        encoder = encoder.cuda()
Exemplo n.º 4
0
def main(args=None):
    """Set up data, embedder and a ConditionalPixelCNNpp, then train or evaluate.

    Options come from ``parser``; ``args`` (when truthy) is forwarded to
    ``parser.parse_args``, otherwise ``parser.parse_args()`` is called with no
    arguments. With ``opt.train`` set the model is trained via ``train``;
    otherwise a checkpoint is loaded, images are sampled with
    ``sample_images_full`` and the model is scored with ``eval``.

    Raises:
        AssertionError: On an unsupported ``opt.dataset`` or
            ``opt.conditioning`` value, or when evaluating without
            ``opt.model_checkpoint``.
    """
    opt = parser.parse_args(args) if args else parser.parse_args()

    print(opt)

    print("loading dataset")
    if opt.dataset != "imagenet32":
        # Only CIFAR-10 is supported besides imagenet32.
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else -1,
        )
        val_dataset = CIFAR10Dataset(
            train=0,
            max_size=1 if opt.debug else -1,
        )
    else:
        # Debug runs keep a single sample per split.
        train_cap = 1 if opt.debug else opt.train_size
        train_dataset = Imagenet32Dataset(train=not opt.train_on_val,
                                          max_size=train_cap)
        val_cap = 1 if opt.debug else opt.val_size
        val_dataset = Imagenet32Dataset(train=0,
                                        max_size=val_cap,
                                        start_idx=opt.val_start_idx)

    print("creating dataloaders")
    make_loader = torch.utils.data.DataLoader
    train_dataloader = make_loader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True)
    val_dataloader = make_loader(val_dataset,
                                 batch_size=opt.batch_size,
                                 shuffle=True)

    n_train_batches = len(train_dataloader)
    n_val_batches = len(val_dataloader)
    print("Len train : {}, val : {}".format(n_train_batches, n_val_batches))

    # Fall back to the CPU unless CUDA is both present and requested.
    if torch.cuda.is_available() and opt.use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Pick the class embedder matching the requested conditioning mode.
    conditioning = opt.conditioning
    if conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    embed_size = encoder.embed_size
    image_shape = train_dataset.image_shape
    # 3 logistic mixture components when image_shape[0] == 1, 10 otherwise.
    if image_shape[0] == 1:
        n_logistic_mix = 3
    else:
        n_logistic_mix = 10
    generative_model = ConditionalPixelCNNpp(embd_size=embed_size,
                                             img_shape=image_shape,
                                             nr_resnet=opt.n_resnet,
                                             nr_filters=opt.n_filters,
                                             nr_logistic_mix=n_logistic_mix)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    print("dataloaders loaded")

    # Adam optimizer with a StepLR schedule (step_size=1, gamma=opt.lr_decay).
    optimizer = torch.optim.Adam(generative_model.parameters(), lr=opt.lr)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=1,
                                    gamma=opt.lr_decay)

    # Output layout: <output_dir>/models and <output_dir>/tensorboard.
    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    tensorboard_dir = os.path.join(opt.output_dir, "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tensorboard_dir)

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
        return

    # Evaluation-only path: restore weights, sample, then score.
    assert opt.model_checkpoint is not None, 'no model checkpoint specified'
    print("Loading model from state dict...")
    load_model(opt.model_checkpoint, generative_model)
    print("Model loaded.")
    sample_images_full(generative_model,
                       encoder,
                       opt.output_dir,
                       dataloader=val_dataloader,
                       device=device)
    eval(model=generative_model,
         embedder=encoder,
         test_loader=val_dataloader,
         opt=opt,
         writer=writer,
         device=device)