def show_wcgan_data():
    """Visualize trained-WCGAN samples alongside real samples.

    Reads every dataset file under ``coedatas``, restores the generator
    from ``GANParameters/WCGAN/generator.pt`` and writes two figures:
    ``caches/gen.jpg``  — an n_classes x n_classes grid of generated
                          contour plots, one class per column, and
    ``caches/real.jpg`` — randomly drawn real samples in the same layout.
    """
    latent_dim = 20
    data_list = os.listdir('coedatas')
    data = [data_read('coedatas/' + path) for path in data_list]
    n_classes = len(data)

    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor

    # NOTE(review): the original hard-coded 5 classes here while the labels
    # below are generated from len(data) — a silent mismatch if the number
    # of files in 'coedatas' changes. Deriving n_classes from the data makes
    # the two consistent; state_dict loading will now fail loudly (shape
    # mismatch) instead of producing out-of-range label embeddings.
    generator = G_D_Module.GeneratorWCGAN(latent_dim, n_classes, (1, 32, 32))
    # map_location='cpu' so the checkpoint loads on CPU-only machines too
    # (everything below runs on CPU anyway).
    generator.load_state_dict(
        torch.load('GANParameters/WCGAN/generator.pt', map_location='cpu'))
    generator.eval()  # inference mode: freeze BatchNorm/Dropout behavior

    # One grid row per repetition; labels cycle 0..n_classes-1 so each
    # column shows a fixed class.
    noise = FloatTensor(np.random.normal(0, 1, (n_classes ** 2, latent_dim)))
    label = LongTensor(list(range(n_classes)) * n_classes)
    with torch.no_grad():  # no gradients needed for visualization
        gen_imags = generator(noise, label).cpu()

    # Generated-sample grid.
    for i in range(gen_imags.size(0)):
        plt.subplot(n_classes, n_classes, i + 1)
        plt.axis('off')
        plt.contourf(gen_imags[i][0].numpy())
    # Save once, after the whole grid is drawn.
    plt.savefig('caches/gen.jpg', bbox_inches='tight')
    # plt.show()
    plt.close()

    # Real-sample grid: column j always shows class j, each cell a
    # randomly chosen sample from that class.
    for i in range(n_classes):
        for j in range(n_classes):
            index = random.randint(0, data[j].shape[0] - 1)
            plt.subplot(n_classes, n_classes, i * n_classes + j + 1)
            plt.axis('off')
            plt.contourf(data[j][index])
    plt.savefig('caches/real.jpg', bbox_inches='tight')
    # plt.show()
    plt.close()
def ex_wcgan():
    """Train (or resume training of) the WCGAN on the project datasets.

    Loads the sequence data from ``source_files``, wraps it in a
    ``DataLoader`` and hands generator + discriminator to
    ``TrainFunction.train_wcgan`` with the hyper-parameters from ``opt``.
    """
    data_sets = ReWrite.load_data_in_seq(source_files)
    data_sets = ReWrite.MyDataSet(data_sets)
    data_loader = DataLoader(
        data_sets,
        batch_size=512,
        shuffle=True,
    )

    # The generator's input layer is built for this latent size; the same
    # value MUST be used when sampling noise during training. The original
    # passed opt.latent_dim to train_wcgan while constructing the generator
    # with a local latent_dim=20 (the stray "latent_dim should be 20"
    # comment pointed at this) — if the two differed, the noise fed to the
    # generator would not match its input dimension.
    latent_dim = 20
    generator = G_D_Module.GeneratorWCGAN(latent_dim, opt.n_classes, img_shape)
    discriminator = G_D_Module.DiscriminatorWCGAN(opt.n_classes, img_shape)

    # fist_train=False: resume from previously saved parameters
    # (keyword name [sic] belongs to TrainFunction's interface).
    TrainFunction.train_wcgan(generator, discriminator, data_loader,
                              opt.n_epochs, opt.lr, opt.b1, opt.b2,
                              latent_dim, opt.n_classes, cuda,
                              fist_train=False)