コード例 #1
0
    def plot(self, image, filename, save_sample, regularize=True):
        """Convert an image in [-1, 1] to uint8, optionally save it, and display it.

        Args:
            image: array-like image whose values are expected in [-1, 1].
                Singleton dimensions are squeezed away; a remaining third
                axis is treated as the channel axis (presumably HWC — TODO
                confirm with callers).
            filename: destination path used when ``save_sample`` is true.
            save_sample: whether to write the image to ``filename``.
            regularize: clip values to [-1, 1] before scaling. When false,
                out-of-range values are not clipped and may overflow the
                uint8 conversion.

        Returns:
            The squeezed image as a ``uint8`` array.
        """
        if regularize:
            image = np.clip(image, -1, 1)
        image = np.squeeze(image)
        # np.squeeze drops a size-1 channel axis, so a single-channel image
        # arrives 2-D and has no axis 2 to inspect; save it as grayscale.
        if np.ndim(image) == 2:
            fmt = "L"
        elif np.shape(image)[2] == 4:
            fmt = "RGBA"
        else:
            fmt = "RGB"
        # Map the fixed range [-1, 1] linearly onto [0, 255]; the +0.5 makes
        # the truncating astype() round to nearest.
        imin, imax = -1.0, 1.0
        image = (image - imin) * 255. / (imax - imin) + .5
        image = image.astype(np.uint8)
        if save_sample:
            try:
                Image.fromarray(image, fmt).save(filename)
            except Exception as e:
                # Saving is best-effort: report the failure but keep going.
                print(
                    "Warning: could not sample to ", filename,
                    ".  Please check permissions and make sure the path exists"
                )
                print(e)

        GlobalViewer.update(self.gan, image)
        return image
コード例 #2
0
ファイル: 2d-alignment.py プロジェクト: szad670401/HyperGAN
def train(config, args):
    """Train a GAN on the custom 2-D input distribution and sum its accuracy metrics.

    Both accuracy metrics are evaluated every 100 steps, starting at step 0,
    and added to running totals. Training stops as soon as a metric comes back
    non-finite. That value is still added first, so the caller sees it in the
    result.

    Args:
        config: GAN configuration handed to ``hg.GAN``.
        args: parsed CLI arguments; reads ``viewer``, ``batch_size``,
            ``steps``, ``sample_every`` and ``save_samples``.

    Returns:
        List ``[x->g accuracy total, g->x accuracy total]``.
    """
    window_title = "[hypergan] 2d-test " + config_filename
    GlobalViewer.set_options(enabled=args.viewer, title=window_title, viewer_size=1)
    print("ARGS", args)

    gan = hg.GAN(config,
                 inputs=Custom2DInputDistribution({"batch_size": args.batch_size}))
    gan.name = config_filename

    # The generator is fed either latent noise or draws from the input
    # distribution; which one is decided once, here.
    from_latent = gan.config.use_latent

    def generator_input():
        return gan.latent.next() if from_latent else gan.inputs.next()

    def accuracy_x_to_g():
        real = gan.inputs.next(1)
        return distribution_accuracy(real, gan.generator(generator_input()))

    def accuracy_g_to_x():
        fake = gan.generator(generator_input())
        return distribution_accuracy(fake, gan.inputs.next(1))

    sampler = Custom2DSampler(gan)
    gan.selected_sampler = sampler

    total_steps = args.steps
    sample_prefix = "samples/" + config_filename
    first_sample = sample_prefix + "/000000.png"
    os.makedirs(os.path.expanduser(os.path.dirname(first_sample)), exist_ok=True)
    sampler.sample(first_sample, args.save_samples)

    metrics = (accuracy_x_to_g, accuracy_g_to_x)
    totals = [0] * len(metrics)
    sample_count = 0
    for step in range(total_steps):
        gan.step()

        if args.viewer and step % args.sample_every == 0:
            sample_count += 1
            print("Sampling " + str(sample_count))
            sampler.sample(sample_prefix + "/%06d.png" % sample_count,
                           args.save_samples)

        # Metrics are only evaluated on every 100th step.
        if step % 100:
            continue

        diverged = False
        for index, metric in enumerate(metrics):
            value = metric().cpu().detach().numpy()
            totals[index] += value
            if not np.isfinite(value):
                diverged = True
                break
        if diverged:
            break

    return totals
コード例 #3
0
ファイル: autoencode.py プロジェクト: vg3095/HyperGAN
def setup_gan(config, inputs, args):
    """Create an AutoencoderGAN, restore a saved checkpoint if one exists, and open the viewer.

    The checkpoint at ``save_file`` is only restored when not running a
    ``search`` action and ``save_file + ".meta"`` exists on disk.

    Args:
        config: GAN configuration.
        inputs: input pipeline passed to the GAN.
        args: parsed CLI arguments; reads ``action`` and ``config``.

    Returns:
        The created (and possibly restored) GAN.
    """
    gan = AutoencoderGAN(config=config, inputs=inputs)
    gan.create()

    restore_checkpoint = args.action != 'search' and os.path.isfile(save_file + ".meta")
    if restore_checkpoint:
        gan.load(save_file)

    tf.train.start_queue_runners(sess=gan.session)

    GlobalViewer.enable()
    GlobalViewer.window.set_title("[hypergan] colorizer " + args.config)
    return gan
コード例 #4
0
 def plot(self, image, filename, save_sample):
     """Min-max normalize an image to uint8, optionally save it, and display it.

     Values are first clipped to [-1, 1], then the observed range is stretched
     onto [0, 255].

     Args:
         image: array-like image data; singleton dimensions are squeezed away.
         filename: destination path used when ``save_sample`` is true.
         save_sample: whether to write the image to ``filename``.
     """
     image = np.clip(image, -1, 1)
     image = np.squeeze(image)
     imin, imax = image.min(), image.max()
     # A constant image has a zero value range: scaling by it would divide by
     # zero and feed NaN into the uint8 cast. Use the nominal [-1, 1] range.
     if imax == imin:
         imin, imax = -1.0, 1.0
     # Map [imin, imax] onto [0, 255]; +0.5 makes the truncating astype() round.
     image = (image - imin) * 255. / (imax - imin) + .5
     image = image.astype(np.uint8)
     if save_sample:
         try:
             Image.fromarray(image).save(filename)
         except Exception as e:
             # Saving is best-effort: report the failure but keep going.
             print("Warning: could not sample to ", filename, ".  Please check permissions and make sure the path exists")
             print(e)
     GlobalViewer.update(self.gan, image)
コード例 #5
0
 def plot(self, image, filename, save_sample):
     """Min-max normalize an image to uint8, optionally save it, and display it.

     Values are first clipped to [-1, 1], then the observed range is stretched
     onto [0, 255].

     Args:
         image: array-like image data; singleton dimensions are squeezed away.
         filename: destination path used when ``save_sample`` is true.
         save_sample: whether to write the image to ``filename``.
     """
     image = np.clip(image, -1, 1)
     image = np.squeeze(image)
     imin, imax = image.min(), image.max()
     # A constant image has a zero value range: scaling by it would divide by
     # zero and feed NaN into the uint8 cast. Use the nominal [-1, 1] range.
     if imax == imin:
         imin, imax = -1.0, 1.0
     # Map [imin, imax] onto [0, 255]; +0.5 makes the truncating astype() round.
     image = (image - imin) * 255. / (imax - imin) + .5
     image = image.astype(np.uint8)
     if save_sample:
         try:
             Image.fromarray(image).save(filename)
         except Exception as e:
             # Saving is best-effort: report the failure but keep going.
             print("Warning: could not sample to ", filename, ".  Please check permissions and make sure the path exists")
             print(e)
     GlobalViewer.update(self.gan, image)
コード例 #6
0
ファイル: 2d-distribution.py プロジェクト: likeyu21/HyperGAN
def train(config, args):
    """Train a GAN on the custom 2-D distribution and return summed batch accuracies.

    The GAN is built under ``args.device`` and trained for ``args.steps``
    steps. An initial ``samples/000000.png`` is always sampled. Further
    samples are only taken while the viewer is enabled, every
    ``args.sample_every`` steps. The two batch-accuracy metrics are summed
    over the steps with ``i > 0.9 * steps``.

    The TensorFlow session is closed and the default graph reset on exit,
    including when training raises.

    Args:
        config: GAN configuration; ``config.generator['end_features']`` is
            overwritten with 2.
        args: parsed CLI arguments; reads ``viewer``, ``config``, ``device``,
            ``steps``, ``sample_every`` and ``save_samples``.

    Returns:
        List ``[x->g accuracy total, g->x accuracy total]``.
    """
    if args.viewer:
        GlobalViewer.enable()
        title = "[hypergan] 2d-test " + args.config
        GlobalViewer.window.set_title(title)

    with tf.device(args.device):
        config.generator['end_features'] = 2
        gan = hg.GAN(config, inputs=Custom2DInputDistribution(args))
        gan.discriminator = Custom2DDiscriminator(gan, config.discriminator)
        gan.generator = Custom2DGenerator(gan, config.generator)
        gan.encoder = gan.create_component(gan.config.encoder)
        gan.encoder.create()
        gan.generator.create()
        gan.discriminator.create()
        gan.create()

        try:
            accuracy_x_to_g = batch_accuracy(gan.inputs.x, gan.generator.sample)
            accuracy_g_to_x = batch_accuracy(gan.generator.sample, gan.inputs.x)

            sampler = Custom2DSampler(gan)

            tf.train.start_queue_runners(sess=gan.session)
            samples = 0
            steps = args.steps
            sampler.sample("samples/000000.png", args.save_samples)

            metrics = [accuracy_x_to_g, accuracy_g_to_x]
            sum_metrics = [0] * len(metrics)
            for i in range(steps):
                gan.step()

                if args.viewer and i % args.sample_every == 0:
                    samples += 1
                    print("Sampling " + str(samples), args.save_samples)
                    sample_file = "samples/%06d.png" % samples
                    sampler.sample(sample_file, args.save_samples)

                # Only the final ~10% of steps contribute to the totals.
                if i > steps * 9.0 / 10:
                    for k, metric in enumerate(gan.session.run(metrics)):
                        sum_metrics[k] += metric
        finally:
            # Release resources even on failure so repeated calls (e.g. a
            # search loop) don't accumulate sessions/graphs. Close the session
            # BEFORE resetting the graph: TensorFlow documents calling
            # reset_default_graph() with an active session as undefined.
            gan.session.close()
            tf.reset_default_graph()

    return sum_metrics
コード例 #7
0
    def plot_image(self, image, filename, save_sample, regularize=True):
        """Optionally save and then display an externally produced image as-is.

        No clipping, scaling or dtype conversion is applied. The array is
        passed straight to PIL and the viewer, so it presumably must already
        be uint8 (TODO confirm with callers).

        Args:
            image: image array. A 2-D array is saved as grayscale. Otherwise
                axis 2 is the channel axis: 4 channels means RGBA, anything
                else is saved as RGB.
            filename: destination path used when ``save_sample`` is true.
            save_sample: whether to write the image to ``filename``.
            regularize: accepted for signature compatibility; unused here.

        Returns:
            ``image`` unchanged.
        """
        # A 2-D array has no channel axis; indexing shape[2] would raise.
        if np.ndim(image) == 2:
            fmt = "L"
        elif np.shape(image)[2] == 4:
            fmt = "RGBA"
        else:
            fmt = "RGB"
        if save_sample:
            try:
                Image.fromarray(image, fmt).save(filename)
            except Exception as e:
                # Saving is best-effort: report the failure but keep going.
                print(
                    "Warning: could not sample to ", filename,
                    ".  Please check permissions and make sure the path exists"
                )
                print(e)

        GlobalViewer.update(self.gan, image)
        return image
コード例 #8
0
ファイル: next-frame.py プロジェクト: szad670401/HyperGAN
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import uuid

# Command-line interface: the shared image arguments plus --frames.
arg_parser = ArgumentParser("render next frame")
parser = arg_parser.add_image_arguments()
parser.add_argument(
    '--frames', type=int, default=4, help='Number of frames to embed.')
args = arg_parser.parse_args()

GlobalViewer.set_options(
    enabled=args.viewer,
    title="[hypergan] next-frame " + args.config,
    viewer_size=args.zoom,
)

# --size must consist of exactly three integers separated by 'x':
# width, height, channels.
width, height, channels = map(int, args.size.split('x'))

input_config = hc.Config(dict(
    batch_size=args.batch_size,
    directories=[args.directory],
    channels=channels,
    crop=args.crop,
    height=height,
    random_crop=False,
    resize=args.resize,
    # Inputs are shuffled only for the "train" action.
    shuffle=args.action == "train",
    width=width,
))
コード例 #9
0
            for k, metric in enumerate(metrics):
                _metric = metric().cpu().detach().numpy()
                sum_metrics[k] += _metric
                if not np.isfinite(_metric):
                    broken = True
                    break

    return sum_metrics


# Dispatch on the requested action. The viewer is closed in `finally` so it
# is shut down even when training raises.
try:
    if args.action == 'train':
        metrics = train(config, args)
        print("Resulting metrics:", metrics)
    elif args.action == 'search':
        metric_sum = train(config, args)
        # Use the default file when the option is absent OR present but unset.
        # For example, an argparse option without a default is stored as None,
        # and open(None) would fail.
        search_output = getattr(args, 'search_output', None) or "2d-test-results.csv"

        hc.Selector().save(config_filename, config)
        with open(search_output, "a") as myfile:
            total = sum(metric_sum)
            # One CSV row: config name, each metric total, then their sum.
            myfile.write(config_filename + "," +
                         ",".join(str(x) for x in metric_sum) +
                         "," + str(total) + "\n")
    else:
        print("Unknown action: " + args.action)
finally:
    GlobalViewer.close()