def ex_ponodcwcgan():
    """Train the PONODCWCGAN generator/discriminator pair on ``source_files``.

    The sequence data is loaded through ``ReWrite`` and wrapped in a shuffled
    ``DataLoader`` with batch size 256. Both networks are then built from the
    module-level settings ``opt`` and ``img_shape``, and everything is passed to
    ``TrainFunction.train_ponodcwcgan`` together with the module-level ``cuda``
    flag and ``fist_train=False``.
    """
    # NOTE(review): an earlier comment claimed the latent size "should be 20",
    # but 100 is what is actually used -- confirm the intended value.
    noise_dim = 100

    # The innermost call runs first: load the raw sequences, wrap them as a
    # Dataset, then batch them.
    loader = DataLoader(
        ReWrite.MyDataSet(ReWrite.load_data_in_seq(source_files)),
        batch_size=256,
        shuffle=True,
    )

    # Generator first, then discriminator (same order as before, so any
    # random weight initialisation consumes the RNG identically).
    gen_net = G_D_Module.GeneratorPONODCWCGAN(noise_dim, opt.n_classes, img_shape)
    disc_net = G_D_Module.DiscriminatorPONODCWCGAN(opt.n_classes, img_shape)

    # `fist_train` is the callee's own keyword name (presumably "first_train";
    # confirm its meaning in TrainFunction before changing it).
    TrainFunction.train_ponodcwcgan(
        gen_net,
        disc_net,
        loader,
        opt.n_epochs,
        opt.lr,
        opt.b1,
        opt.b2,
        noise_dim,
        opt.n_classes,
        cuda,
        fist_train=False,
    )
def _save_contour_grid(images, n, out_path):
    """Draw ``images`` row-major as filled contours on an n x n grid and save to ``out_path``."""
    # A dedicated figure keeps these plots off any figure the caller may have open.
    fig = plt.figure()
    for k, img in enumerate(images):
        plt.subplot(n, n, k + 1)
        plt.axis('off')
        plt.contourf(img)
    plt.savefig(out_path)
    plt.close(fig)


def show_ponodcwcgan_data(*,
                          data_dir='coedatas',
                          weights_path='GANParameters/PONODCWCGAN/generator.pt',
                          out_dir='caches',
                          latent_dim=100,
                          n_classes=5,
                          img_shape=(1, 32, 32)):
    """Save grids of generated and real samples for visual comparison.

    Each file in ``data_dir`` is loaded with ``data_read`` and treated as one
    class. For ``n`` such files, two ``n x n`` grids of filled-contour plots
    are written:

    * ``<out_dir>/gen.jpg``: samples from the pretrained PONODCWCGAN
      generator. Column ``j`` of every row is conditioned on label ``j``.
    * ``<out_dir>/real.jpg``: column ``j`` of every row is a randomly chosen
      sample ``data[j][index]`` from the ``j``-th file.

    Args:
        data_dir: Directory holding one data file per class.
        weights_path: Generator ``state_dict`` checkpoint to load.
        out_dir: Directory the two images are written to.
        latent_dim: Length of the generator's noise vector.
        n_classes: Number of classes the generator was built with.
        img_shape: Image shape passed to the generator constructor.

    Raises:
        ValueError: If ``data_dir`` is empty, or if it holds more files than
            ``n_classes``. In the second case the labels would fall outside
            the generator's class range.
    """
    # NOTE(review): os.listdir order is filesystem-dependent, so which file
    # becomes "class j" is not guaranteed to match the label order used in
    # training -- confirm against how the training data is ordered.
    data = [data_read(os.path.join(data_dir, name))
            for name in os.listdir(data_dir)]
    n = len(data)
    if n == 0:
        raise ValueError(f'no data files found in {data_dir!r}')
    if n > n_classes:
        raise ValueError(
            f'{n} data files in {data_dir!r} but the generator only has '
            f'{n_classes} classes')

    generator = G_D_Module.GeneratorPONODCWCGAN(latent_dim, n_classes, img_shape)
    # The inputs below are CPU tensors. map_location='cpu' lets a checkpoint
    # that was saved from a CUDA run load on a CPU-only machine.
    generator.load_state_dict(torch.load(weights_path, map_location='cpu'))
    # NOTE(review): the generator stays in its default (train) mode, as
    # before. Call generator.eval() if it has BatchNorm/Dropout and
    # inference-mode behaviour is wanted.

    noise = torch.as_tensor(np.random.normal(0, 1, (n * n, latent_dim)),
                            dtype=torch.float32)
    # [0, 1, ..., n-1] repeated n times, so column j of every row is label j.
    labels = torch.arange(n).repeat(n)
    with torch.no_grad():  # sampling only; no autograd graph is needed
        gen_imgs = generator(noise, labels).cpu()
    # img[0] takes the first channel of each generated sample.
    _save_contour_grid([img[0].numpy() for img in gen_imgs], n,
                       os.path.join(out_dir, 'gen.jpg'))

    real_imgs = []
    for _ in range(n):
        for cls_data in data:
            index = random.randint(0, cls_data.shape[0] - 1)
            real_imgs.append(cls_data[index])
    _save_contour_grid(real_imgs, n, os.path.join(out_dir, 'real.jpg'))