Пример #1
0
def ex_ponodcwcgan(*, batch_size=256, latent_dim=100, first_train=False):
    """Train the PONODCWCGAN generator/discriminator pair on ``source_files``.

    Loads the sequence data with ``ReWrite.load_data_in_seq``, wraps it in
    ``ReWrite.MyDataSet`` and a shuffled ``DataLoader``, builds the generator
    and discriminator from ``G_D_Module``, and hands everything to
    ``TrainFunction.train_ponodcwcgan`` along with the epoch count, learning
    rate and Adam betas taken from the module-level ``opt``.

    Args:
        batch_size: Mini-batch size for the DataLoader (defaults to the
            previously hard-coded 256).
        latent_dim: Length of the generator's noise vector (defaults to the
            previously hard-coded 100).
        first_train: Forwarded unchanged as the training routine's
            ``fist_train`` keyword (the callee's spelling). Defaults to the
            previously hard-coded ``False``.
    """
    data_sets = ReWrite.MyDataSet(ReWrite.load_data_in_seq(source_files))
    data_loader = DataLoader(
        data_sets,
        batch_size=batch_size,
        shuffle=True,
    )

    # NOTE(review): an earlier comment claimed latent_dim "should be 20",
    # but 100 is what is used here and in show_ponodcwcgan_data. Presumably
    # it must match the latent size of any saved generator checkpoint;
    # confirm.
    generator = G_D_Module.GeneratorPONODCWCGAN(
        latent_dim, opt.n_classes, img_shape)
    discriminator = G_D_Module.DiscriminatorPONODCWCGAN(
        opt.n_classes, img_shape)

    TrainFunction.train_ponodcwcgan(generator,
                                    discriminator,
                                    data_loader,
                                    opt.n_epochs,
                                    opt.lr,
                                    opt.b1,
                                    opt.b2,
                                    latent_dim,
                                    opt.n_classes,
                                    cuda,
                                    fist_train=first_train)
Пример #2
0
def _save_contour_grid(images, n_rows, n_cols, out_path):
    """Draw each array in *images* as a filled contour on an n_rows x n_cols grid (row-major) and save it to *out_path*."""
    # A fresh figure, so nothing already open gets drawn over.
    plt.figure()
    for pos, img in enumerate(images, start=1):
        plt.subplot(n_rows, n_cols, pos)
        plt.axis('off')
        plt.contourf(img)
    plt.savefig(out_path)
    plt.close()


def show_ponodcwcgan_data():
    """Save grids of generated and real samples for the PONODCWCGAN model.

    Every file in ``coedatas/`` is read with ``data_read``. With ``n`` such
    files, two ``n`` x ``n`` grids of filled contour plots are written:

    * ``caches/gen.jpg`` -- generator output. The cell in row ``i``,
      column ``j`` is conditioned on class label ``j``.
    * ``caches/real.jpg`` -- real data. The cell in row ``i``, column ``j``
      is a randomly chosen entry of the ``j``-th file's data.

    Generator weights are loaded from
    ``GANParameters/PONODCWCGAN/generator.pt``. The ``caches/`` directory is
    created if missing.
    """
    latent_dim = 100
    # The generator is built for 5 classes and (1, 32, 32) images,
    # presumably matching the saved checkpoint. Labels below run over
    # range(len(data)), so coedatas/ is expected to hold at most 5 files.
    # NOTE(review): confirm.
    n_classes = 5
    ckpt_img_shape = (1, 32, 32)

    # NOTE(review): os.listdir order is filesystem-dependent. The
    # label <-> file correspondence assumes it matches the class order used
    # during training; confirm.
    data = [data_read(os.path.join('coedatas', name))
            for name in os.listdir('coedatas')]
    n = len(data)

    generator = G_D_Module.GeneratorPONODCWCGAN(
        latent_dim, n_classes, ckpt_img_shape)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; all tensors below are CPU tensors anyway.
    generator.load_state_dict(
        torch.load('GANParameters/PONODCWCGAN/generator.pt',
                   map_location='cpu'))

    # n*n noise vectors. The labels cycle 0..n-1, so flat index i*n + j
    # gets label j.
    noise = torch.FloatTensor(np.random.normal(0, 1, (n ** 2, latent_dim)))
    label = torch.LongTensor(list(range(n)) * n)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        gen_imags = generator(noise, label).detach().cpu()

    os.makedirs('caches', exist_ok=True)

    # Channel 0 of each generated sample is plotted; presumably gen_imags is
    # (n*n, 1, H, W). TODO confirm.
    _save_contour_grid([img[0].numpy() for img in gen_imags],
                       n, n, 'caches/gen.jpg')

    # Row-major order, so the random draws happen in the same order as
    # before.
    real = [data[j][random.randint(0, data[j].shape[0] - 1)]
            for _ in range(n) for j in range(n)]
    _save_contour_grid(real, n, n, 'caches/real.jpg')