コード例 #1
0
ファイル: 2d-sequence.py プロジェクト: jmbjr/HyperGAN
import tensorflow as tf
import hypergan as hg
import hyperchamber as hc
from hypergan.generators import *
from hypergan.search.random_search import RandomSearch
from hypergan.viewer import GlobalViewer
from common import *

# Command-line interface: the project's ArgumentParser wrapper (its
# underlying parser is reachable as ``.parser``), extended with the
# options specific to this 2D sequence experiment.
arg_parser = ArgumentParser(
    "Test your gan vs a known distribution",
    require_directory=False,
)
arg_parser.parser.add_argument(
    '--distribution', '-t',
    type=str,
    default='circle',
    help='what distribution to test, options are circle, modes',
)
arg_parser.parser.add_argument(
    '--sequence_length', '-n',
    type=int,
    default=2,
    help='how many steps to look forward',
)
args = arg_parser.parse_args()

# Start from the configuration selected via the CLI; in 'search' mode it is
# replaced by a configuration sampled from an empty RandomSearch space.
config = lookup_config(args)
if args.action == 'search':
    config = RandomSearch({}).random_config()

class Sequence2DGenerator(BaseGenerator):
    """Custom generator used by the 2D sequence test script.

    Maps the encoder sample through a small stack of ``ops.linear`` layers
    whose output width depends on the ``--sequence_length`` CLI option.
    """
    def create(self):
        """Build the generator graph on top of ``gan.encoder.sample``.

        NOTE(review): as shown here the method ends after the final linear
        layer without returning or storing ``net``; the snippet looks
        truncated -- confirm against the full source file.
        """
        gan = self.gan
        config = self.config
        ops = self.ops
        # Output width: config.end_features when it is truthy, otherwise
        # 2 * sequence_length taken from the module-level CLI ``args``.
        # Presumably one (x, y) pair per sequence step -- TODO confirm.
        end_features = config.end_features or 2*args.sequence_length

        # NOTE(review): ops.describe is defined elsewhere; it presumably tags
        # the ops created below as 'custom_generator' -- confirm.
        ops.describe('custom_generator')

        net = gan.encoder.sample
        # Two hidden layers, each a 32-unit ops.linear followed by the
        # activation that ops.lookup resolves for the name 'bipolar'.
        for i in range(2):
            net = ops.linear(net, 32)
            net = ops.lookup('bipolar')(net)
        # Final projection to end_features outputs.
        net = ops.linear(net, end_features)
コード例 #2
0
            [input.reshape(self.gan.batch_size(), -1), context['digit']], 1)
        net = self.linear(net)
        net = self.relu(net)
        net = self.linear2(net)
        net = self.tanh(net)
        return net


# Configuration chosen on the command line. In 'search' mode it is replaced
# by a random configuration drawn from a search space whose generator and
# discriminator entries name the MNIST classes.
config = lookup_config(args)

if args.action == 'search':
    search = RandomSearch({
        'generator': {'class': MNISTGenerator, 'end_features': 10},
        'discriminator': {'class': MNISTDiscriminator},
    })
    config = search.random_config()

# MNIST input loader built with the command-line batch size.
inputs = MNISTInputLoader(args.batch_size)


def setup_gan(config, inputs, args):
    """Construct the MNIST GAN described by *config*, fed by *inputs*.

    ``args`` is accepted but not used by this implementation.
    """
    return MNISTGAN(config, inputs=inputs)

コード例 #3
0
 def test_random_config(self):
     """A config sampled from an empty search space must name a trainer class."""
     rs = RandomSearch({})
     # assertIsNotNone uses an identity check (instead of `!= None`) and
     # reports the offending value when the assertion fails.
     self.assertIsNotNone(rs.random_config()['trainer']["class"])
コード例 #4
0
 def test_trainers(self):
     """The trainer chosen by an empty RandomSearch must name a class."""
     rs = RandomSearch({})
     # Identity check against None with a descriptive failure message,
     # rather than assertTrue(... != None).
     self.assertIsNotNone(rs.trainer()["class"])
コード例 #5
0
 def test_range(self):
     """RandomSearch.range() must return a list."""
     rs = RandomSearch({})
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)) which only says "False is not true".
     self.assertIsInstance(rs.range(), list)
コード例 #6
0
        ynet = tf.concat(axis=1, values=[x,y])

        net = tf.concat(axis=0, values=[ynet, gnet])
        net = ops.linear(net, 128)
        net = tf.nn.tanh(net)
        self.sample = net

        return net



# Use the CLI-selected configuration unless a random search was requested.
config = lookup_config(args)

if args.action == 'search':
    search = RandomSearch({
        'generator': {
            'class': MNISTGenerator,
            'end_features': 10,
        },
        'discriminator': {
            'class': MNISTDiscriminator,
        },
    })

    config = search.random_config()

# MNIST input loader built with the command-line batch size.
mnist_loader = MNISTInputLoader(args.batch_size)

def setup_gan(config, inputs, args):
    """Construct an ``hg.GAN`` from *config*, fed by *inputs*.

    The batch size is taken from ``args.batch_size``.
    """
    batch_size = args.batch_size
    return hg.GAN(config, inputs=inputs, batch_size=batch_size)

def train(config, args):
    gan = setup_gan(config, mnist_loader, args)
    correct_prediction = tf.equal(tf.argmax(gan.generator.layer('fy'),1), tf.argmax(gan.inputs.y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) * 100
    metrics = [accuracy]