def train(config, inputs, args):
    """Train the GAN for ``args.steps`` steps, checkpointing and sampling periodically.

    Checkpoints go to the module-level ``save_file`` every ``args.save_every``
    steps (skipping step 0) when ``args.action == 'train'``. Every
    ``args.sample_every`` steps a sample image is written by
    ``TrainingVideoFrameSampler`` to ``samples/<args.config>/NNNNNN.png``,
    numbered sequentially from 0.

    Metric collection is disabled in this variant, so the return value is
    always an empty list.
    """
    gan = setup_gan(config, inputs, args)
    trainable_gan = hg.TrainableGAN(gan, save_file=save_file, devices=args.devices, backend_name=args.backend)
    sampler = TrainingVideoFrameSampler(gan)
    gan.selected_sampler = ""
    sample_index = 0

    for step in range(args.steps):
        trainable_gan.step()

        should_save = args.action == 'train' and step % args.save_every == 0 and step > 0
        if should_save:
            print("saving " + save_file)
            trainable_gan.save()

        if step % args.sample_every != 0:
            continue

        sample_path = "samples/" + args.config + "/%06d.png" % (sample_index)
        sample_dir = os.path.expanduser(os.path.dirname(sample_path))
        os.makedirs(sample_dir, exist_ok=True)
        sample_index += 1
        sampler.sample(sample_path, args.save_samples)

    return []
def _print_text_sample(gan, inputs):
    """Draw one fresh batch and print its query, real answer and generated answer."""
    x_val = gan.inputs.next()
    # Split the last axis into two equal halves: query (q) and answer (a).
    q, a = torch.split(x_val, x_val[0].shape[-1] // 2, dim=-1)
    # Inference only: the output is detached below, so skip building a graph.
    with torch.no_grad():
        g = gan.generator.forward(gan.latent.sample(), context={"q": q})
    print("Query:")
    print(inputs.textdata.sample_output(q[0].cpu().detach().numpy()))
    print("X answer:")
    print(inputs.textdata.sample_output(a[0].cpu().detach().numpy()))
    print("G answer:")
    g_output = inputs.textdata.sample_output(g[0].cpu().detach().numpy())
    # Replace C0/C1 control characters so they cannot garble the terminal.
    g_output = re.sub(r'[\x00-\x1f\x7f-\x9f]', '�', g_output)
    print(g_output)


def train(config, inputs, args):
    """Train a query/answer text GAN, printing sample completions periodically.

    Runs ``args.steps`` steps, or forever when ``args.steps == -1``.

    - Every ``args.save_every`` steps the model is saved (only when
      ``args.action == 'train'``).
    - Every ``args.sample_every`` steps a fresh batch is split into query and
      answer halves, and the generator's answer for the first query is printed
      next to the real one.

    When ``args.config`` is None, one summary row is appended to
    ``sequence-results-10k.csv`` after training. Relies on the module-level
    ``save_file`` and ``config_name``.
    """
    gan = setup_gan(config, inputs, args)
    trainable_gan = hg.TrainableGAN(gan, save_file=save_file, devices=args.devices, backend_name=args.backend)
    # The results of these draws were never used. The calls are kept so the
    # input and latent streams advance exactly as they did before.
    gan.inputs.next()
    gan.latent.sample()

    # NOTE(review): these summary values are never updated in this function,
    # so the CSV row always records the same constants. Confirm intent.
    ax_sum = 0
    ag_sum = 0
    diversity = 0.00001
    dlog = 0
    last_i = 0

    steps = 0
    while True:
        steps += 1
        if steps > args.steps and args.steps != -1:
            break
        trainable_gan.step()
        # `steps` is always >= 1 here, so no explicit `steps > 0` check is needed.
        if args.action == 'train' and steps % args.save_every == 0:
            print("saving " + save_file)
            trainable_gan.save()
        if steps % args.sample_every == 0:
            _print_text_sample(gan, inputs)

    if args.config is None:
        row = [
            config_name,
            ax_sum,
            ag_sum,
            ax_sum + ag_sum,
            ax_sum * ag_sum,
            dlog,
            diversity,
            ax_sum * ag_sum * (1 / diversity),
            last_i,
        ]
        with open("sequence-results-10k.csv", "a") as results_file:
            results_file.write(",".join(str(value) for value in row) + "\n")
def sample_forever(self):
    """Load a saved model and write samples repeatedly until stopped.

    Sampling continues until ``self.gan.destroy`` becomes true. When
    ``self.args.steps`` is not -1, it also stops after exactly
    ``self.args.steps`` samples. If no saved model can be loaded, the method
    prints a message and returns without sampling.
    """
    # One batch is pulled before the model is built (as in the original).
    # Presumably this lets lazy_create infer input shapes. TODO confirm.
    self.gan.inputs.next()
    self.lazy_create()
    self.trainable_gan = hg.TrainableGAN(self.gan, save_file=self.save_file, devices=self.devices, backend_name=self.args.backend)

    if not self.trainable_gan.load():
        print("Could not load save")
        return
    print("Model loaded")

    steps = 0
    # `<` rather than `<=`: the previous bound wrote args.steps + 1 samples.
    while not self.gan.destroy and (steps < self.args.steps or self.args.steps == -1):
        self.trainable_gan.sample(self.sampler, self.sample_path)
        steps += 1
def train(self):
    """Build (or load) the GAN and run the training loop until it finishes.

    Training stops when ``self.steps`` reaches ``self.total_steps`` (unless
    that is -1) or when ``self.gan.destroy`` is set. A checkpoint is saved
    every ``self.args.save_every`` steps when that value is a positive number,
    and once more after the loop ends.

    In ``--ipython`` mode, stdin's file-status flags get ``O_NONBLOCK`` so that
    ``check_stdin`` (presumably a non-blocking poll) can be called between
    steps. The original flags are restored on exit, including when an
    exception propagates. Otherwise the shared terminal would be left
    non-blocking.
    """
    stdin_fd = None
    saved_stdin_flags = None
    if self.args.ipython:
        import fcntl
        stdin_fd = sys.stdin.fileno()
        saved_stdin_flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
        fcntl.fcntl(stdin_fd, fcntl.F_SETFL, saved_stdin_flags | os.O_NONBLOCK)

    try:
        self.gan = hg.GAN(config=self.gan_config, inputs=self.create_input(), device=self.args.parameter_server_device)
        self.gan.cli = self  # TODO remove this link
        self.gan.inputs.next()
        self.lazy_create()
        self.trainable_gan = hg.TrainableGAN(self.gan, save_file=self.save_file, devices=self.devices, backend_name=self.args.backend)

        if self.trainable_gan.load():
            print("Model loaded")
        else:
            print("Initializing new model")
        self.trainable_gan.sample(self.sampler, self.sample_path)

        while (self.steps < self.total_steps or self.total_steps == -1) and not self.gan.destroy:
            self.step()
            if self.should_sample:
                self.should_sample = False
                self.sample(False)

            save_every = self.args.save_every
            # Check for None first: `None > 0` raises TypeError. `> 0` already
            # excludes the -1 "disabled" value.
            if save_every is not None and save_every > 0 and self.steps % save_every == 0:
                print(" |= Saving network")
                self.trainable_gan.save()
                adv_save_file = self.advSavePath + 'advSave.txt'
                self.create_path(adv_save_file)
                # NOTE(review): the file is rewritten only if it already exists.
                # create_path presumably creates only the directory. Confirm
                # whether this opt-in behaviour is intended.
                if os.path.isfile(adv_save_file):
                    with open(adv_save_file, 'w') as the_file:
                        the_file.write(str(self.samples) + "\n")

            if self.args.ipython:
                self.check_stdin()

        print("Done training model. Saving")
        self.trainable_gan.save()
        print("============================")
        print("HyperGAN model trained")
        print("============================")
    finally:
        if saved_stdin_flags is not None:
            fcntl.fcntl(stdin_fd, fcntl.F_SETFL, saved_stdin_flags)
def _evaluate_accuracy(gan):
    """Return the percentage of test examples where argmax(generator(x)) == argmax(y).

    Iterates ``gan.inputs.testdata()`` under ``torch.no_grad()``. Returns None
    when the test set yields no examples.
    """
    correct = 0
    total = 0
    # Evaluation only: no autograd graph is needed for the predictions.
    with torch.no_grad():
        for x, y in gan.inputs.testdata():
            prediction = gan.generator(x)
            correct += (torch.argmax(prediction, 1) == torch.argmax(y, 1)).sum().item()
            total += y.shape[0]
    if total == 0:
        return None
    return (float(correct) / total) * 100


def train(config, args):
    """Train for ``args.steps`` steps, periodically measuring test-set accuracy.

    Accuracy is measured every ``args.sample_every`` steps (skipping step 0)
    and on the final step, then printed together with ``config_name``.
    Relies on the module-level ``inputs``, ``save_file`` and ``config_name``.

    Returns the most recently measured accuracy (percent), or 0 if it was
    never measured. An empty test set leaves the previous value unchanged and
    prints a warning instead of raising ZeroDivisionError.
    """
    gan = setup_gan(config, inputs, args)
    trainable_gan = hg.TrainableGAN(gan, save_file=save_file, devices=args.devices, backend_name=args.backend)
    accuracy = 0
    for i in range(args.steps):
        trainable_gan.step()
        if i == args.steps - 1 or (i % args.sample_every == 0 and i > 0):
            measured = _evaluate_accuracy(gan)
            if measured is None:
                print("warning: test set is empty; accuracy not updated")
                continue
            accuracy = measured
            print(config_name)
            print("accuracy: ", accuracy)
    return accuracy
def train(config, inputs, args):
    """Train the GAN named ``config_name``, checkpointing and sampling periodically.

    When ``args.action == 'train'``, the model is saved to the module-level
    ``save_file`` every ``args.save_every`` steps (skipping step 0). Every
    ``args.sample_every`` steps the generic ``Sampler`` writes an image to
    ``samples/<config_name>/NNNNNN.png``, numbered sequentially from 0.
    Returns None.
    """
    gan = setup_gan(config, inputs, args)
    gan.name = config_name
    trainable_gan = hg.TrainableGAN(gan, save_file=save_file, devices=args.devices, backend_name=args.backend)
    # Cleared before the sampler is constructed (order kept from the original).
    gan.selected_sampler = ""
    sampler = Sampler(gan)
    written = 0

    for step in range(args.steps):
        trainable_gan.step()

        if args.action == 'train' and step % args.save_every == 0 and step > 0:
            print("saving " + save_file)
            trainable_gan.save()

        if step % args.sample_every == 0:
            target = "samples/" + config_name + "/%06d.png" % (written)
            os.makedirs(os.path.expanduser(os.path.dirname(target)), exist_ok=True)
            written += 1
            sampler.sample(target, args.save_samples)
def train(config, args):
    """Train a GAN on the custom 2D input distribution and return summed metrics.

    Two accuracy metrics (real-vs-generated and generated-vs-real, via
    ``distribution_accuracy``) are evaluated every 100 steps and accumulated.
    If any metric value is non-finite, training stops early.

    A sample image is written once before training. When ``args.viewer`` is
    set, further samples are written every ``args.sample_every`` steps to
    ``samples/<config_filename>/NNNNNN.png``.

    Returns the list ``[sum of x->g metric, sum of g->x metric]``.
    """
    title = "[hypergan] 2d-test " + config_filename
    GlobalViewer.set_options(enabled=args.viewer, title=title, viewer_size=1)
    print("ARGS", args)
    gan = hg.GAN(config, inputs=Custom2DInputDistribution({"batch_size": args.batch_size}))
    trainable_gan = hg.TrainableGAN(gan, devices=args.devices, backend_name=args.backend)
    gan.name = config_filename

    # Choose how a generated batch is produced: from the latent space, from an
    # encoded real batch (ALI), or by feeding a real batch straight in.
    if gan.config.use_latent:
        generate = lambda: gan.generator(gan.latent.next())
    elif gan.config.ali:
        generate = lambda: gan.generator(gan.encoder(gan.inputs.next()))
    else:
        generate = lambda: gan.generator(gan.inputs.next())

    def accuracy_x_to_g():
        return distribution_accuracy(gan.inputs.next(1), generate())

    def accuracy_g_to_x():
        return distribution_accuracy(generate(), gan.inputs.next(1))

    sampler = Custom2DSampler(gan)
    gan.selected_sampler = sampler
    sample_count = 0
    total_steps = args.steps

    first_sample = "samples/" + config_filename + "/000000.png"
    os.makedirs(os.path.expanduser(os.path.dirname(first_sample)), exist_ok=True)
    sampler.sample(first_sample, args.save_samples)

    metrics = [accuracy_x_to_g, accuracy_g_to_x]
    sum_metrics = [0] * len(metrics)

    for step in range(total_steps):
        trainable_gan.step()

        if args.viewer and step % args.sample_every == 0:
            sample_count += 1
            print("Sampling " + str(sample_count))
            frame = "samples/" + config_filename + "/%06d.png" % (sample_count)
            sampler.sample(frame, args.save_samples)

        if step % 100 != 0:
            continue

        diverged = False
        for index, metric in enumerate(metrics):
            value = metric().cpu().detach().numpy()
            sum_metrics[index] += value
            if not np.isfinite(value):
                diverged = True
                break
        # A non-finite metric ends training right away. Nothing else runs in
        # this iteration after the metric check, so breaking here is the same
        # as breaking at the top of the next iteration.
        if diverged:
            break

    return sum_metrics