Esempio n. 1
0
def train(config, inputs, args):
    """Run the training loop, checkpointing and writing samples periodically.

    Args:
        config: GAN configuration, forwarded to ``setup_gan``.
        inputs: Input source, forwarded to ``setup_gan``.
        args: Parsed options. This function reads ``devices``, ``backend``,
            ``steps``, ``action``, ``save_every``, ``sample_every``,
            ``config`` and ``save_samples``.

    Returns:
        Always an empty list; this loop does not collect any metrics.
    """
    gan = setup_gan(config, inputs, args)
    trainable_gan = hg.TrainableGAN(
        gan,
        save_file=save_file,
        devices=args.devices,
        backend_name=args.backend,
    )
    sampler = TrainingVideoFrameSampler(gan)
    gan.selected_sampler = ""
    written = 0  # number of sample images emitted so far

    # NOTE: per-step metric aggregation (batch accuracy / diversity) is
    # disabled, which is why an empty list is returned below.
    for step in range(args.steps):
        trainable_gan.step()

        # Checkpoint only in 'train' mode, and never on the very first step.
        if args.action == 'train':
            if step % args.save_every == 0 and step > 0:
                print("saving " + save_file)
                trainable_gan.save()

        if step % args.sample_every != 0:
            continue

        # Sample files are numbered sequentially, independent of the step.
        frame_name = "/%06d.png" % written
        sample_file = "samples/" + args.config + frame_name
        target_dir = os.path.expanduser(os.path.dirname(sample_file))
        os.makedirs(target_dir, exist_ok=True)
        written += 1
        sampler.sample(sample_file, args.save_samples)

    return []
Esempio n. 2
0
    def train(config, inputs, args):
        """Train the GAN, periodically printing a query and both answers.

        Every ``args.sample_every`` steps a batch is drawn, split into a
        query half and an answer half, and the query, the reference answer
        and the generator's answer are printed. When ``args.config`` is
        ``None`` a summary row is appended to ``sequence-results-10k.csv``.

        Args:
            config: GAN configuration, forwarded to ``setup_gan``.
            inputs: Input source; ``inputs.textdata.sample_output`` is used
                to turn arrays into printable text.
            args: Parsed options. ``args.steps == -1`` means train forever.
        """
        gan = setup_gan(config, inputs, args)
        trainable_gan = hg.TrainableGAN(
            gan,
            save_file=save_file,
            devices=args.devices,
            backend_name=args.backend,
        )

        # Draw one batch and one latent sample before training. The values
        # are not used here; the calls are kept in case they have side
        # effects (e.g. advancing the input stream).
        gan.inputs.next()
        gan.latent.sample()

        # Summary statistics for the CSV row. They are initialised here and
        # never updated within this function, so the row records these
        # initial values.
        ax_sum = 0
        ag_sum = 0
        diversity = 0.00001
        dlog = 0
        last_i = 0

        def decode_first(batch):
            # Render the first element of a batch as text.
            return inputs.textdata.sample_output(
                batch[0].cpu().detach().numpy())

        step = 0
        while args.steps == -1 or step < args.steps:
            step += 1
            trainable_gan.step()

            if args.action == 'train':
                if step % args.save_every == 0 and step > 0:
                    print("saving " + save_file)
                    trainable_gan.save()

            if step % args.sample_every != 0:
                continue

            x_val = gan.inputs.next()
            # Split along the last dimension into chunks of half its size:
            # the first is treated as the query, the second as the answer.
            half = x_val[0].shape[-1] // 2
            q, a = torch.split(x_val, half, dim=-1)
            g = gan.generator.forward(gan.latent.sample(), context={"q": q})

            print("Query:")
            print(decode_first(q))
            print("X answer:")
            print(decode_first(a))
            print("G answer:")
            # Replace control characters so the terminal output stays sane.
            g_output = re.sub(r'[\x00-\x1f\x7f-\x9f]', '�', decode_first(g))
            print(g_output)

        if args.config is None:
            with open("sequence-results-10k.csv", "a") as myfile:
                row = [
                    config_name,
                    str(ax_sum),
                    str(ag_sum),
                    str(ax_sum + ag_sum),
                    str(ax_sum * ag_sum),
                    str(dlog),
                    str(diversity),
                    str(ax_sum * ag_sum * (1 / diversity)),
                    str(last_i),
                ]
                myfile.write(",".join(row) + "\n")
Esempio n. 3
0
    def sample_forever(self):
        """Load a saved model and keep sampling from it.

        Returns early (after printing a message) when no save can be loaded.
        Otherwise samples repeatedly until ``self.gan.destroy`` is set or the
        sample counter exceeds ``self.args.steps``; ``self.args.steps == -1``
        disables the counter limit.
        """
        self.gan.inputs.next()
        self.lazy_create()
        self.trainable_gan = hg.TrainableGAN(
            self.gan,
            save_file=self.save_file,
            devices=self.devices,
            backend_name=self.args.backend,
        )

        # Sampling needs trained weights, so bail out if nothing was loaded.
        if not self.trainable_gan.load():
            print("Could not load save")
            return
        print("Model loaded")

        taken = 0
        while not self.gan.destroy and (self.args.steps == -1
                                        or taken <= self.args.steps):
            self.trainable_gan.sample(self.sampler, self.sample_path)
            taken += 1
Esempio n. 4
0
    def train(self):
        """Build the GAN, train it until done or destroyed, then save it.

        Steps, in order:
          1. When ``self.args.ipython`` is set, switch stdin to non-blocking
             mode so it can be polled during training.
          2. Create the GAN and its ``TrainableGAN`` wrapper, loading a
             previous save when one is available.
          3. Emit one sample, then loop calling ``self.step()`` while
             ``self.steps < self.total_steps`` (or forever when
             ``self.total_steps == -1``) and ``self.gan.destroy`` is unset.
          4. Save the network every ``self.args.save_every`` steps when that
             option is a positive number, and once more after the loop.
        """
        if self.args.ipython:
            import fcntl
            fd = sys.stdin.fileno()
            fl = fcntl.fcntl(fd, fcntl.F_GETFL)
            # Add O_NONBLOCK so reads from stdin never stall the loop.
            fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

        self.gan = hg.GAN(config=self.gan_config, inputs=self.create_input(), device=self.args.parameter_server_device)
        self.gan.cli = self #TODO remove this link
        self.gan.inputs.next()
        self.lazy_create()

        self.trainable_gan = hg.TrainableGAN(self.gan, save_file=self.save_file, devices=self.devices, backend_name=self.args.backend)

        if self.trainable_gan.load():
            print("Model loaded")
        else:
            print("Initializing new model")

        self.trainable_gan.sample(self.sampler, self.sample_path)

        while (self.steps < self.total_steps or self.total_steps == -1) and not self.gan.destroy:
            self.step()
            if self.should_sample:
                self.should_sample = False
                self.sample(False)

            # Periodic checkpoint. `> 0` already excludes the -1 sentinel,
            # and None is tested by identity before any numeric comparison.
            save_every = self.args.save_every
            if (save_every is not None and
                    save_every > 0 and
                    self.steps % save_every == 0):
                print(" |= Saving network")
                self.trainable_gan.save()
                adv_save_file = self.advSavePath + 'advSave.txt'
                self.create_path(adv_save_file)
                # The sample counter is only recorded when the marker file
                # already exists (presumably created by create_path --
                # TODO confirm).
                if os.path.isfile(adv_save_file):
                    with open(adv_save_file, 'w') as the_file:
                        the_file.write(str(self.samples) + "\n")
            if self.args.ipython:
                self.check_stdin()
        print("Done training model.  Saving")
        self.trainable_gan.save()
        print("============================")
        print("HyperGAN model trained")
        print("============================")
Esempio n. 5
0
def train(config, args):
    """Train a GAN-based classifier and report its test accuracy.

    The model is evaluated on ``gan.inputs.testdata()`` on the final step
    and on every positive step that is a multiple of ``args.sample_every``.

    Args:
        config: GAN configuration, forwarded to ``setup_gan``.
        args: Parsed options. Reads ``devices``, ``backend``, ``steps`` and
            ``sample_every``.

    Returns:
        The most recently computed accuracy as a percentage, or ``0`` when
        no evaluation ran (e.g. ``args.steps == 0``).
    """
    # NOTE(review): `inputs` is not a parameter of this function; it is
    # resolved from the enclosing scope -- confirm that is intended.
    gan = setup_gan(config, inputs, args)
    trainable_gan = hg.TrainableGAN(gan, save_file=save_file, devices=args.devices, backend_name=args.backend)
    accuracy = 0

    for i in range(args.steps):
        trainable_gan.step()

        # Parentheses make the precedence explicit: always evaluate on the
        # last step, otherwise only on positive multiples of sample_every.
        if i == args.steps - 1 or (i % args.sample_every == 0 and i > 0):
            correct_prediction = 0
            total = 0
            # Evaluation needs no gradients; skip autograd bookkeeping.
            with torch.no_grad():
                for (x, y) in gan.inputs.testdata():
                    prediction = gan.generator(x)
                    correct_prediction += (torch.argmax(prediction, 1) == torch.argmax(y, 1)).sum()
                    total += y.shape[0]
            if total:
                accuracy = (float(correct_prediction) / total) * 100
            else:
                # Empty test set: report 0% rather than dividing by zero.
                accuracy = 0.0
            print(config_name)
            print("accuracy: ", accuracy)

    return accuracy
Esempio n. 6
0
def train(config, inputs, args):
    """Train the GAN, saving checkpoints and sample images along the way.

    Args:
        config: GAN configuration, forwarded to ``setup_gan``.
        inputs: Input source, forwarded to ``setup_gan``.
        args: Parsed options. Reads ``devices``, ``backend``, ``steps``,
            ``action``, ``save_every``, ``sample_every`` and
            ``save_samples``.
    """
    gan = setup_gan(config, inputs, args)
    gan.name = config_name
    trainable_gan = hg.TrainableGAN(
        gan,
        save_file=save_file,
        devices=args.devices,
        backend_name=args.backend,
    )
    gan.selected_sampler = ""
    sampler = Sampler(gan)
    sample_index = 0  # sequential number used in the sample file name

    for step in range(args.steps):
        trainable_gan.step()

        # Checkpoint only in 'train' mode, skipping the very first step.
        if args.action == 'train':
            if step % args.save_every == 0 and step > 0:
                print("saving " + save_file)
                trainable_gan.save()

        if step % args.sample_every != 0:
            continue

        image_name = "/%06d.png" % sample_index
        sample_file = "samples/" + config_name + image_name
        out_dir = os.path.expanduser(os.path.dirname(sample_file))
        os.makedirs(out_dir, exist_ok=True)
        sample_index += 1
        sampler.sample(sample_file, args.save_samples)
Esempio n. 7
0
def train(config, args):
    """Train a GAN on the custom 2D distribution and accumulate accuracy.

    Two metrics are evaluated every 100 steps: ``distribution_accuracy`` of
    real inputs against generated samples, and the reverse. Training stops
    as soon as either metric is non-finite.

    Args:
        config: GAN configuration passed to ``hg.GAN``.
        args: Parsed options. Reads ``viewer``, ``batch_size``, ``devices``,
            ``backend``, ``steps``, ``sample_every`` and ``save_samples``.

    Returns:
        A two-element list with the running sum of each metric, in the
        order [x-to-g, g-to-x]. A non-finite value is included in the sum
        before training stops.
    """
    title = "[hypergan] 2d-test " + config_filename
    GlobalViewer.set_options(enabled=args.viewer, title=title, viewer_size=1)
    print("ARGS", args)

    input_distribution = Custom2DInputDistribution(
        {"batch_size": args.batch_size})
    gan = hg.GAN(config, inputs=input_distribution)
    trainable_gan = hg.TrainableGAN(
        gan,
        devices=args.devices,
        backend_name=args.backend,
    )
    gan.name = config_filename

    # Choose how generated samples are produced for the metrics: from the
    # latent distribution, from encoded inputs, or directly from inputs.
    if gan.config.use_latent:
        def generate():
            return gan.generator(gan.latent.next())
    elif gan.config.ali:
        def generate():
            return gan.generator(gan.encoder(gan.inputs.next()))
    else:
        def generate():
            return gan.generator(gan.inputs.next())

    def accuracy_x_to_g():
        return distribution_accuracy(gan.inputs.next(1), generate())

    def accuracy_g_to_x():
        return distribution_accuracy(generate(), gan.inputs.next(1))

    sampler = Custom2DSampler(gan)
    gan.selected_sampler = sampler

    sample_count = 0
    total_steps = args.steps

    # Write an initial sample before any training has happened.
    sample_file = "samples/" + config_filename + "/000000.png"
    sample_dir = os.path.expanduser(os.path.dirname(sample_file))
    os.makedirs(sample_dir, exist_ok=True)
    sampler.sample(sample_file, args.save_samples)

    metrics = [accuracy_x_to_g, accuracy_g_to_x]
    sum_metrics = [0] * len(metrics)
    for step in range(total_steps):
        trainable_gan.step()

        if args.viewer and step % args.sample_every == 0:
            sample_count += 1
            print("Sampling " + str(sample_count))
            frame_name = "/%06d.png" % sample_count
            sample_file = "samples/" + config_filename + frame_name
            sampler.sample(sample_file, args.save_samples)

        if step % 100 != 0:
            continue

        for idx, metric in enumerate(metrics):
            value = metric().cpu().detach().numpy()
            sum_metrics[idx] += value
            # A non-finite metric means training diverged: stop right away.
            if not np.isfinite(value):
                return sum_metrics

    return sum_metrics