Example #1
def test(args, model):
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    elif args.dataset == "celeba":
        dataset_f = celeba
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

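    # The training delta is recovered from the checkpoint filename: take the
    # last ';'-separated segment, its first '_'-separated token, and the value
    # after '#'. (The exact naming scheme comes from string_args, which is not
    # shown in these excerpts; this reading is inferred from the parse below.)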
    args.delta = float(args.model_path.split(";")[-1].split("_")[0].split("#")[1])
    repr_args = string_args(args)
    f = open(f"./test/ll_per_point_{repr_args}_.txt", "w")
    train_loader, val_loader, train_val_loader, train_labels, val_labels = dataset_f(
        1, args.img_size, args.n_channels, return_y=True
    )
    with torch.no_grad():
        for ind, image in enumerate(train_loader):
            # TODO: figure out how to avoid this repeat
            image = image.repeat(100, 1, 1, 1)
            image = image.to(device)
            log_p, logdet, _ = model(image)
            for i in range(log_p.shape[0]):
                print(
                    ind,
                    args.delta,
                    log_p[i].item(),
                    logdet[i].item(),
                    train_labels[ind].item(),
                    file=f,
                )
            if ind >= 9999:
                break
    f.close()
def train(args, model, optimizer):
    if args.dataset == "mnist":
        dataset_f = memory_mnist
    elif args.dataset == "fashion_mnist":
        dataset_f = memory_fashion
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

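    # Fix one batch of latent draws up front so the samples rendered after
    # every epoch decode the same latents; args.temp is the sampling
    # temperature scaling the standard-normal draws.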
    z_sample = []
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    deltas = create_deltas_sequence(0.1, 0.005)
    args.delta = deltas[0]

    epoch_losses = []
    f_train_loss = open(f"losses/seq_losses_train_{repr_args}_.txt",
                        "w",
                        buffering=1)
    f_test_loss = open(f"losses/seq_losses_test_{repr_args}_.txt",
                       "w",
                       buffering=1)

    with tqdm(range(200)) as pbar:
        for i in pbar:
            args.delta = deltas[i]
            repr_args = string_args(args)
            train_loader, val_loader, train_val_loader = dataset_f(
                args.batch, args.img_size, args.n_channels)
            train_losses = []
            for image in train_loader:
                optimizer.zero_grad()
                image = image.to(device)
                noisy_image = image
                if args.tr_dq:
                    noisy_image += torch.rand_like(image) / n_bins
                noisy_image += torch.randn_like(image) * args.delta
                log_p, logdet, _ = model(noisy_image)
                logdet = logdet.mean()
                loss, log_p, log_det = calc_loss(log_p, logdet, args.img_size,
                                                 n_bins, args.n_channels)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            current_train_loss = np.mean(train_losses)
            print(f"{current_train_loss},{args.delta},{i + 1}",
                  file=f_train_loss)
            with torch.no_grad():
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/seq_sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )
                losses = []
                logdets = []
                logps = []
                for image in val_loader:
                    image = image.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p, logdet, _ = model(noisy_image)
                    logdet = logdet.mean()
                    loss, log_p, log_det = calc_loss(log_p, logdet,
                                                     args.img_size, n_bins,
                                                     args.n_channels)
                    losses.append(loss.item())
                    logdets.append(log_det.item())
                    logps.append(log_p.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                current_loss = np.mean(losses)
                print(f"{current_loss},{args.delta},{i + 1}", file=f_test_loss)
                epoch_losses.append(current_loss)
                if (i + 1) % 10 == 0:
                    torch.save(
                        model.state_dict(),
                        f"checkpoint/seq_model_{repr_args}_{i + 1}_.pt",
                    )

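                # Dump per-sample log p and logdet for each validation batch
                # alongside a held-out training batch, one row per sample.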
                f_ll = open(f"ll/seq_ll_{repr_args}_{i + 1}.txt", "w")
                train_loader, val_loader, train_val_loader = dataset_f(
                    args.batch, args.img_size, args.n_channels)
                train_val_loader = iter(train_val_loader)
                for image_val in val_loader:
                    image = image_val.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p_val, logdet_val, _ = model(noisy_image)

                    image = next(train_val_loader)
                    image = image.to(device)
                    noisy_image = image
                    if args.te_dq:
                        noisy_image += torch.rand_like(image) / n_bins
                    if args.te_noise:
                        noisy_image += torch.randn_like(image) * args.delta
                    log_p_train_val, logdet_train_val, _ = model(noisy_image)
                    for (
                            lpv,
                            ldv,
                            lptv,
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()
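
Both training loops lean on helpers that these excerpts never define. `create_deltas_sequence(0.1, 0.005)` is visible only through its call site (indexed per epoch over 200 iterations), and `calc_z_shapes` matches the multi-scale latent layout of standard Glow implementations such as rosinality's glow-pytorch. The sketches below are assumptions consistent with those call sites, not the author's code:

def create_deltas_sequence(start, step):
    # Hypothetical: an arithmetic schedule with one delta per epoch. Only the
    # arguments (0.1, 0.005) and the per-epoch indexing are actually known.
    return [start + step * i for i in range(200)]


def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    # Latent shape split off at each Glow block: the spatial size halves and
    # the channel count doubles per block, and the last block keeps all
    # remaining channels after its squeeze. (n_flow is unused; it is kept
    # only for signature parity with the call sites above.)
    z_shapes = []
    for _ in range(n_block - 1):
        input_size //= 2
        n_channel *= 2
        z_shapes.append((n_channel, input_size, input_size))
    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))
    return z_shapes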
Example #3
                print(
                    ind,
                    args.delta,
                    log_p[i].item(),
                    logdet[i].item(),
                    train_labels[ind].item(),
                    file=f,
                )
            if ind >= 9999:
                break
    f.close()


if __name__ == "__main__":
    args = parser.parse_args()
    print(string_args(args))
    device = args.device

    model_single = Glow(
        args.n_channels,
        args.n_flow,
        args.n_block,
        affine=args.affine,
        conv_lu=not args.no_lu,
    )
    model = model_single
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model = model.to(device)

    test(args, model)
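
`string_args` is likewise undefined here; it serializes the argument namespace into the tag embedded in every output filename. Given that test() above recovers delta by splitting on ';', '_' and '#', one hypothetical reconstruction is:

def string_args(args):
    # Hypothetical: one 'name#value' token per argument, ';'-separated, so
    # that filenames stay parseable. Only the delimiters exercised by the
    # parse in test() are actually known.
    return ";".join(f"{k}#{v}" for k, v in vars(args).items())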
Example #4
def train(args, model, optimizer):
    # Map dataset names to their loader factories; every name matches the
    # factory's identifier except the two memory_* shortcuts.
    dataset_fs = {
        "mnist": memory_mnist,
        "fashion_mnist": memory_fashion,
        "celeba": celeba,
        "ffhq_gan_32": ffhq_gan_32,
        "cifar_horses_40": cifar_horses_40,
        "ffhq_50": ffhq_50,
        "cifar_horses_20": cifar_horses_20,
        "cifar_horses_80": cifar_horses_80,
        "mnist_30": mnist_30,
        "mnist_gan_all": mnist_gan_all,
        "mnist_pad": mnist_pad,
        "cifar_horses_20_top": cifar_horses_20_top,
        "cifar_horses_40_top": cifar_horses_40_top,
        "cifar_horses_20_top_small_lr": cifar_horses_20_top_small_lr,
        "cifar_horses_40_top_small_lr": cifar_horses_40_top_small_lr,
        "arrows_small": arrows_small,
        "arrows_big": arrows_big,
        "cifar_20_picked_inds_2": cifar_20_picked_inds_2,
        "cifar_40_picked_inds_2": cifar_40_picked_inds_2,
        "cifar_40_picked_inds_3": cifar_40_picked_inds_3,
        "cifar_20_picked_inds_3": cifar_20_picked_inds_3,
    }
    if args.dataset not in dataset_fs:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    dataset_f = dataset_fs[args.dataset]

    repr_args = string_args(args)
    n_bins = 2.0**args.n_bits

    z_sample = []
    z_shapes = calc_z_shapes(args.n_channels, args.img_size, args.n_flow,
                             args.n_block)
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    epoch_losses = []
    f_train_loss = open(f"losses/losses_train_{repr_args}_.txt",
                        "a",
                        buffering=1)
    f_test_loss = open(f"losses/losses_test_{repr_args}_.txt",
                       "a",
                       buffering=1)

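    # Resume support: reload the latest checkpoint and its epoch counter when
    # they exist, otherwise start training from epoch 0.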
    last_model_path = f"checkpoint/model_{repr_args}_last_.pt"
    try:
        model.load_state_dict(torch.load(last_model_path))
        model.eval()
        f_epoch = open(f"checkpoint/last_epoch_{repr_args}.txt",
                       "r",
                       buffering=1)
        epoch_n = int(f_epoch.readline().strip())
        f_epoch.close()
    except FileNotFoundError:
        print("Training the model from scratch.")
        epoch_n = 0

    with tqdm(range(epoch_n, args.epochs + epoch_n)) as pbar:
        for i in pbar:
            repr_args = string_args(args)
            train_loader, val_loader, train_val_loader = dataset_f(
                args.batch, args.img_size, args.n_channels)
            train_losses = []
            for image in train_loader:
                if isinstance(image, list):
                    image = image[0]
                optimizer.zero_grad()
                image = image.to(device)
                noisy_image = image
                if args.tr_dq:
                    noisy_image += torch.rand_like(image) / n_bins
                noisy_image += torch.randn_like(image) * args.delta
                log_p, logdet, _ = model(noisy_image)

                logdet = logdet.mean()
                loss, log_p, log_det = calc_loss(log_p, logdet, args.img_size,
                                                 n_bins, args.n_channels)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            current_train_loss = np.mean(train_losses)
            print(f"{current_train_loss},{args.delta},{i + 1}",
                  file=f_train_loss)
            with torch.no_grad():
                utils.save_image(
                    model.reverse(z_sample).cpu().data,
                    f"sample/sample_{repr_args}_{str(i + 1).zfill(6)}.png",
                    normalize=True,
                    nrow=10,
                    range=(-0.5, 0.5),
                )
                losses = []
                logdets = []
                logps = []
                for image in val_loader:
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p, logdet, _ = model(image)
                    logdet = logdet.mean()
                    loss, log_p, log_det = calc_loss(log_p, logdet,
                                                     args.img_size, n_bins,
                                                     args.n_channels)
                    losses.append(loss.item())
                    logdets.append(log_det.item())
                    logps.append(log_p.item())
                pbar.set_description(
                    f"Loss: {np.mean(losses):.5f}; logP: {np.mean(logps):.5f}; logdet: {np.mean(logdets):.5f}; delta: {args.delta:.5f}"
                )
                current_loss = np.mean(losses)
                print(f"{current_loss},{args.delta},{i + 1}", file=f_test_loss)
                epoch_losses.append(current_loss)
                # Early stopping: break if the loss from 20 epochs ago beats
                # every loss recorded since.
                if len(epoch_losses) >= 20 and epoch_losses[-20] < min(
                        epoch_losses[-19:]):
                    break
                # Periodic numbered checkpoints disabled to save disk space:
                # if (i + 1) % 5 == 0:
                #     torch.save(
                #         model.state_dict(),
                #         f"checkpoint/model_{repr_args}_{i + 1}_.pt",
                #     )
                torch.save(model.state_dict(), last_model_path)
                f_epoch = open(f"checkpoint/last_epoch_{repr_args}.txt",
                               "w",
                               buffering=1)
                f_epoch.write(str(i + 1))
                f_epoch.close()

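                # As in the seq variant: dump per-sample log p and logdet for
                # each validation batch alongside a held-out training batch.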
                f_ll = open(f"ll/ll_{repr_args}_{i + 1}.txt", "w")
                train_loader, val_loader, train_val_loader = dataset_f(
                    args.batch, args.img_size, args.n_channels)
                train_val_loader = iter(train_val_loader)
                for image_val in val_loader:
                    image = image_val
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p_val, logdet_val, _ = model(image)

                    image = next(train_val_loader)
                    if isinstance(image, list):
                        image = image[0]
                    image = image.to(device)
                    log_p_train_val, logdet_train_val, _ = model(image)

                    for (
                            lpv,
                            ldv,
                            lptv,
                            ldtv,
                    ) in zip(log_p_val, logdet_val, log_p_train_val,
                             logdet_train_val):
                        print(
                            args.delta,
                            lpv.item(),
                            ldv.item(),
                            lptv.item(),
                            ldtv.item(),
                            file=f_ll,
                        )
                f_ll.close()
    f_train_loss.close()
    f_test_loss.close()
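
Every excerpt funnels its objective through `calc_loss`, which is also not shown. In Glow implementations of this lineage it is the bits-per-dimension negative log-likelihood; a minimal sketch under that assumption (generalized to n_channels, since the call sites pass it):

from math import log


def calc_loss(log_p, logdet, image_size, n_bins, n_channels):
    # Assumed implementation: NLL in bits per dimension. The -log(n_bins) *
    # n_pixel term accounts for dequantizing pixels to n_bins discrete
    # levels; dividing by log(2) converts nats to bits.
    n_pixel = image_size * image_size * n_channels
    loss = -log(n_bins) * n_pixel
    loss = loss + logdet + log_p
    return (
        (-loss / (log(2) * n_pixel)).mean(),
        (log_p / (log(2) * n_pixel)).mean(),
        (logdet / (log(2) * n_pixel)).mean(),
    )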