Example #1
def main():
    utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    print(args)

    # Basic Setup
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.cuda.set_device(2)
    cudnn.benchmark = True
    cudnn.enabled = True

    n_channels = 3
    n_bins = 2.**args.n_bits
    approx_samples = 4

    # Define model
    model_single = Network(n_channels,
                           args.n_flow,
                           args.n_block,
                           n_bins,
                           affine=args.affine,
                           conv_lu=not args.no_lu)
    model = nn.DataParallel(model_single, device_ids=[2, 3])
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    dataset = iter(sample_cifar10(args.batch, args.img_size))

    # Sample generated images
    z_sample = []
    z_shapes = calc_z_shapes(n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    with tqdm(range(args.iter)) as pbar:
        for i in pbar:
            # Training procedure
            model.train()

            # Get a random minibatch from the search queue with replacement
            input, _ = next(dataset)
            input = Variable(input,
                             requires_grad=False).cuda(non_blocking=True)
            input = input.repeat(approx_samples, 1, 1, 1)

            log_p, logdet, _ = model(input + torch.rand_like(input) / n_bins)

            loss, _, _ = likelihood_loss(log_p, logdet, args.img_size, n_bins)

            loss_variance = likelihood_loss_variance(log_p, logdet,
                                                     args.img_size, n_bins,
                                                     approx_samples)

            loss = loss + loss_variance

            # Optimize model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description("Loss: {}".format(loss.item()))

            # Save generated samples
            if i % 100 == 0:
                with torch.no_grad():
                    tvutils.save_image(
                        model_single.reverse(z_sample).cpu().data,
                        "{}/samples/{}.png".format(args.save,
                                                   str(i + 1).zfill(6)),
                        normalize=False,
                        nrow=10,
                    )

            # Save checkpoint
            if i % 1000 == 0:
                model_single.genotype()
                torch.save(
                    model.state_dict(),
                    "{}/checkpoint/model_{}.pt".format(args.save,
                                                       str(i + 1).zfill(6)))

            # Save latest weights
            utils.save(model, os.path.join(args.save, 'latest_weights.pt'))
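This example (and the ones that follow) builds its sampling latents with calc_z_shapes, which is not shown in the snippets. A minimal sketch, assuming the usual multi-scale Glow layout in which every block halves the spatial resolution and splits off half of the channels (n_flow is accepted only to match the call sites):

def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    # One latent shape per block: resolution halves at each block and the
    # channel count doubles, because half of the squeezed channels are split off.
    z_shapes = []
    for _ in range(n_block - 1):
        input_size //= 2
        n_channel *= 2
        z_shapes.append((n_channel, input_size, input_size))
    # The final block keeps all remaining channels (no split after the squeeze).
    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))
    return z_shapes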
Example #2
def main():
    utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    print(args)

    seed = random.randint(1, 100000000)
    print(seed)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    n_channels = 3
    n_bins = 2.**args.n_bits

    # Define model and loss criteria
    model = SearchNetwork(n_channels,
                          args.n_flow,
                          args.n_block,
                          n_bins,
                          affine=args.affine,
                          conv_lu=not args.no_lu)
    model = nn.DataParallel(model, [args.gpu])
    model.load_state_dict(
        torch.load("architecture.pt", map_location="cuda:{}".format(args.gpu)))
    model = model.module
    genotype = model.sample_architecture()

    with open(args.save + '/genotype.pkl', 'wb') as fp:
        pickle.dump(genotype, fp)

    model_single = EnsembleNetwork(n_channels,
                                   args.n_flow,
                                   args.n_block,
                                   n_bins,
                                   genotype,
                                   affine=args.affine,
                                   conv_lu=not args.no_lu)
    model = model_single
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

    dataset = iter(sample_cifar10(args.batch, args.img_size))

    # Sample generated images
    z_sample = []
    z_shapes = calc_z_shapes(n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    with tqdm(range(args.iter)) as pbar:
        for i in pbar:
            # Training procedure
            model.train()

            # Get a random minibatch from the search queue with replacement
            input, _ = next(dataset)
            input = Variable(input,
                             requires_grad=False).cuda(non_blocking=True)

            log_p, logdet, _ = model(input + torch.rand_like(input) / n_bins)

            logdet = logdet.mean()
            loss, _, _ = likelihood_loss(log_p, logdet, args.img_size, n_bins)

            # Optimize model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description("Loss: {}".format(loss.item()))

            # Save generated samples
            if i % 100 == 0:
                with torch.no_grad():
                    tvutils.save_image(
                        model_single.reverse(z_sample).cpu().data,
                        "{}/samples/{}.png".format(args.save,
                                                   str(i + 1).zfill(6)),
                        normalize=False,
                        nrow=10,
                    )

            # Save checkpoint
            if i % 1000 == 0:
                utils.save(model, os.path.join(args.save, 'latest_weights.pt'))
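Examples #1 and #2 call likelihood_loss, and Examples #3 and #4 call calc_loss; neither helper is shown. Both presumably implement the standard Glow objective, i.e. the negative log-likelihood of the uniformly dequantized data expressed in bits per dimension. A minimal sketch under that assumption (the n_channels argument is included only to match the calc_loss call sites; likelihood_loss_variance is not reconstructed here):

import math


def likelihood_loss(log_p, logdet, image_size, n_bins, n_channels=3):
    # Glow-style objective: the per-sample log-likelihood of the dequantized
    # image, corrected by the -log(n_bins) quantization term and converted to
    # bits per dimension.
    n_pixel = image_size * image_size * n_channels
    log_likelihood = -math.log(n_bins) * n_pixel + logdet + log_p
    scale = math.log(2) * n_pixel
    return (
        (-log_likelihood / scale).mean(),  # loss in bits/dim
        (log_p / scale).mean(),
        (logdet / scale).mean(),
    )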
Example #3
def train(args, model, optimizer):
    # Select the data-loading function for the requested dataset
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion

    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

    z_sample = []
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    # Anneal the Gaussian smoothing noise: one delta value per training epoch
    deltas = create_deltas_sequence(0.1, 0.005)
    args.delta = deltas[0]

    epoch_losses = []
    f_train_loss = open(f"losses/seq_losses_train_{repr_args}_.txt",
                        "w",
                        buffering=1)
    f_test_loss = open(f"losses/seq_losses_test_{repr_args}_.txt",
                       "w",
                       buffering=1)

    with tqdm(range(200)) as pbar:
        for i in pbar:
            args.delta = deltas[i]
            repr_args = string_args(args)
            train_loader, val_loader, train_val_loader = dataset_f(
                args.batch, args.img_size, args.n_channels)
            train_losses = []
            for image in train_loader:
                optimizer.zero_grad()
                image = image.to(device)
                # Start from the clean image, then add dequantization / smoothing noise
                noisy_image = image
                if args.tr_dq:
                    noisy_image += torch.rand_like(image) / n_bins
                noisy_image += torch.randn_like(image) * args.delta
                log_p, logdet, _ = model(noisy_image)
                logdet = logdet.mean()
                loss, log_p, log_det = calc_loss(log_p, logdet, args.img_size,
                                                 n_bins, args.n_channels)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            current_train_loss = np.mean(train_losses)
            print(f"{current_train_loss},{args.delta},{i + 1}",
                  file=f_train_loss)
            with torch.no_grad():
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/seq_sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )
                losses = []
                logdets = []
                logps = []
                for image in val_loader:
                    image = image.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p, logdet, _ = model(noisy_image)
                    logdet = logdet.mean()
                    loss, log_p, log_det = calc_loss(log_p, logdet,
                                                     args.img_size, n_bins,
                                                     args.n_channels)
                    losses.append(loss.item())
                    logdets.append(log_det.item())
                    logps.append(log_p.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                current_loss = np.mean(losses)
                print(f"{current_loss},{args.delta},{i + 1}", file=f_test_loss)
                epoch_losses.append(current_loss)
                if (i + 1) % 10 == 0:
                    torch.save(
                        model.state_dict(),
                        f"checkpoint/seq_model_{repr_args}_{i + 1}_.pt",
                    )

                f_ll = open(f"ll/seq_ll_{repr_args}_{i + 1}.txt", "w")
                train_loader, val_loader, train_val_loader = dataset_f(
                    args.batch, args.img_size, args.n_channels)
                train_val_loader = iter(train_val_loader)
                for image_val in val_loader:
                    image = image_val
                    image = image.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p_val, logdet_val, _ = model(noisy_image)

                    image = next(train_val_loader)
                    image = image.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p_train_val, logdet_train_val, _ = model(noisy_image)
                    for (
                            lpv,
                            ldv,
                            lptv,
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()
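Example #3 anneals the smoothing-noise level args.delta over its 200 epochs with create_deltas_sequence(0.1, 0.005), which is not shown and whose exact schedule is unknown. A purely illustrative stand-in that yields a 200-element schedule decaying linearly from the start value in steps of the given size, clipped at zero:

import numpy as np


def create_deltas_sequence(start, step, length=200):
    # Hypothetical stand-in: the real schedule is not part of the snippet.
    # Decreases linearly from `start` by `step` per epoch, never below zero.
    return np.clip(start - step * np.arange(length), 0.0, None)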
Example #4
def train(args, model, optimizer):
    # Select the data-loading function for the requested dataset
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    elif args.dataset == "celeba":
        dataset_f = celeba
    elif args.dataset == "ffhq_gan_32":
        dataset_f = ffhq_gan_32
    elif args.dataset == "cifar_horses_40":
        dataset_f = cifar_horses_40
    elif args.dataset == "ffhq_50":
        dataset_f = ffhq_50
    elif args.dataset == "cifar_horses_20":
        dataset_f = cifar_horses_20
    elif args.dataset == "cifar_horses_80":
        dataset_f = cifar_horses_80
    elif args.dataset == "mnist_30":
        dataset_f = mnist_30
    elif args.dataset == "mnist_gan_all":
        dataset_f = mnist_gan_all
    elif args.dataset == "mnist_pad":
        dataset_f = mnist_pad
    elif args.dataset == "cifar_horses_20_top":
        dataset_f = cifar_horses_20_top
    elif args.dataset == "cifar_horses_40_top":
        dataset_f = cifar_horses_40_top
    elif args.dataset == "cifar_horses_20_top_small_lr":
        dataset_f = cifar_horses_20_top_small_lr
    elif args.dataset == "cifar_horses_40_top_small_lr":
        dataset_f = cifar_horses_40_top_small_lr
    elif args.dataset == "arrows_small":
        dataset_f = arrows_small
    elif args.dataset == "arrows_big":
        dataset_f = arrows_big
    elif args.dataset == "cifar_20_picked_inds_2":
        dataset_f = cifar_20_picked_inds_2
    elif args.dataset == "cifar_40_picked_inds_2":
        dataset_f = cifar_40_picked_inds_2
    elif args.dataset == "cifar_40_picked_inds_3":
        dataset_f = cifar_40_picked_inds_3
    elif args.dataset == "cifar_20_picked_inds_3":
        dataset_f = cifar_20_picked_inds_3
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

    z_sample = []
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    epoch_losses = []
    f_train_loss = open(f"losses/losses_train_{repr_args}_.txt",
                        "a",
                        buffering=1)
    f_test_loss = open(f"losses/losses_test_{repr_args}_.txt",
                       "a",
                       buffering=1)

    last_model_path = f"checkpoint/model_{repr_args}_last_.pt"
    try:
        model.load_state_dict(torch.load(last_model_path))
        model.eval()
        f_epoch = open(f"checkpoint/last_epoch_{repr_args}.txt",
                       "r",
                       buffering=1)
        epoch_n = int(f_epoch.readline().strip())
        f_epoch.close()
    except FileNotFoundError:
        print("Training the model from scratch.")
        epoch_n = 0

    with tqdm(range(epoch_n, args.epochs + epoch_n)) as pbar:
        for i in pbar:
            repr_args = string_args(args)
            train_loader, val_loader, train_val_loader = dataset_f(
                args.batch, args.img_size, args.n_channels)
            train_losses = []
            for image in train_loader:
                if isinstance(image, list):
                    image = image[0]
                optimizer.zero_grad()
                image = image.to(device)
                noisy_image = image
                if args.tr_dq:
                    noisy_image += torch.rand_like(image) / n_bins
                noisy_image += torch.randn_like(image) * args.delta
                log_p, logdet, _ = model(noisy_image)

                logdet = logdet.mean()
                loss, log_p, log_det = calc_loss(log_p, logdet, args.img_size,
                                                 n_bins, args.n_channels)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            current_train_loss = np.mean(train_losses)
            print(f"{current_train_loss},{args.delta},{i + 1}",
                  file=f_train_loss)
            with torch.no_grad():
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )
                losses = []
                logdets = []
                logps = []
                for image in val_loader:
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p, logdet, _ = model(image)
                    logdet = logdet.mean()
                    loss, log_p, log_det = calc_loss(log_p, logdet,
                                                     args.img_size, n_bins,
                                                     args.n_channels)
                    losses.append(loss.item())
                    logdets.append(log_det.item())
                    logps.append(log_p.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                current_loss = np.mean(losses)
                print(f"{current_loss},{args.delta},{i + 1}", file=f_test_loss)
                epoch_losses.append(current_loss)
                # early stopping
                if len(epoch_losses) >= 20 and epoch_losses[-20] < min(
                        epoch_losses[-19:]):
                    break
                '''
                too much space
                if (i + 1) % 5 == 0:
                    torch.save(
                        model.state_dict(), f"checkpoint/model_{repr_args}_{i + 1}_.pt"
                    )
                '''
                torch.save(model.state_dict(), last_model_path)
                f_epoch = open(f"checkpoint/last_epoch_{repr_args}.txt",
                               "w",
                               buffering=1)
                f_epoch.write(str(i + 1))
                f_epoch.close()

                f_ll = open(f"ll/ll_{repr_args}_{i + 1}.txt", "w")
                train_loader, val_loader, train_val_loader = dataset_f(
                    args.batch, args.img_size, args.n_channels)
                train_val_loader = iter(train_val_loader)
                for image_val in val_loader:
                    image = image_val
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p_val, logdet_val, _ = model(image)

                    image = next(train_val_loader)
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p_train_val, logdet_train_val, _ = model(image)

                    for (
                            lpv,
                            ldv,
                            lptv,
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()
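Examples #3 and #4 expect the caller to have built the model and optimizer and to provide the module-level device. A hypothetical driver, with the learning-rate argument name borrowed from Examples #1 and #2, might look like this:

import torch


def run_training(args, model):
    # Hypothetical usage sketch: train() above only needs a ready model, an
    # optimizer, and the global `device` used throughout these snippets.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    train(args, model, optimizer)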