Code example #1
0
File: classify_mnist.py — Project: lza93/luchador
def _main():
    """Train an MNIST classifier and evaluate it on the test split."""
    args = _parase_command_line_args()
    _initialize_logger(args.debug)

    batch_size = 32
    conv_format = luchador.get_nn_conv_format()
    # Channel axis position depends on the active convolution format.
    if conv_format == 'NHWC':
        input_shape = [batch_size, 28, 28, 1]
    else:
        input_shape = [batch_size, 1, 28, 28]

    model = _build_model(args.model, input_shape, batch_size)
    dataset = _load_data(args.mnist, conv_format)

    sess = nn.Session()
    sess.initialize()

    writer = nn.SummaryWriter(output_dir='tmp')
    if sess.graph:
        writer.add_graph(sess.graph)

    try:
        _train(sess, model, dataset['train'], batch_size)
        _test(sess, model, dataset['test'], batch_size)
    except KeyboardInterrupt:
        # Let Ctrl-C stop the run quietly, without a traceback.
        pass
Code example #2
0
def _main():
    """Train an MNIST autoencoder, plotting reconstructions each epoch.

    Reads options from the command line (model path, dataset location,
    iteration counts, optional output directory for summaries/plots).
    """
    args = _parase_command_line_args()
    initialize_logger(args.debug)

    batch_size = 32
    data_format = luchador.get_nn_conv_format()
    autoencoder = _build_model(args.model, data_format, batch_size)
    dataset = load_mnist(args.dataset, data_format=data_format, mock=args.mock)

    sess = nn.Session()
    sess.initialize()

    # Summaries are only written when an output directory was given.
    if args.output:
        summary = nn.SummaryWriter(output_dir=args.output)
        if sess.graph is not None:
            summary.add_graph(sess.graph)

    def _train_ae():
        # One optimization step on a fresh training batch; returns the error.
        batch = dataset.train.next_batch(batch_size).data
        return sess.run(
            inputs={autoencoder.input: batch},
            outputs=autoencoder.output['error'],
            updates=autoencoder.get_update_operations(),
            name='train_autoencoder',
        )

    def _plot_reconstruction(epoch):
        # Save original/reconstructed test images for visual inspection.
        if not args.output:
            return
        orig = dataset.test.next_batch(batch_size).data
        recon = sess.run(
            inputs={autoencoder.input: orig},
            outputs=autoencoder.output['reconstruction'],
            name='reconstruct_images',
        )
        # Drop the channel axis; its position depends on the conv format.
        axis = 3 if data_format == 'NHWC' else 1
        orig = np.squeeze(orig, axis=axis)
        recon = np.squeeze(recon, axis=axis)

        base_path = os.path.join(args.output, '{:03}_'.format(epoch))
        # Fixed filename typo: was 'orign.png'.
        plot_images(orig, base_path + 'orig.png')
        plot_images(recon, base_path + 'recon.png')

    _train(_train_ae,
           _plot_reconstruction,
           n_iterations=args.n_iterations,
           n_epochs=args.n_epochs)
Code example #3
0
File: classify_mnist.py — Project: mot0/luchador
def _main():
    """Train an MNIST classifier and periodically measure test error."""
    args = _parase_command_line_args()
    initialize_logger(args.debug)

    batch_size = 32
    data_format = luchador.get_nn_conv_format()
    classifier = _build_model(args.model, data_format)
    dataset = load_mnist(args.dataset, data_format=data_format, mock=args.mock)

    sess = nn.Session()
    sess.initialize()

    if args.output:
        summary = nn.SummaryWriter(output_dir=args.output)
        if sess.graph is not None:
            summary.add_graph(sess.graph)

    def _feed(batch):
        # Map one mini batch onto the classifier's input placeholders.
        return {
            classifier.input['data']: batch.data,
            classifier.input['label']: batch.label,
        }

    def _train_classifier():
        # Single parameter-update step; returns the training error.
        return sess.run(
            inputs=_feed(dataset.train.next_batch(batch_size)),
            outputs=classifier.output['error'],
            updates=classifier.get_update_operations(),
            name='train_classifier',
        )

    def _test_classifier():
        # Forward pass only (no updates); returns the test error.
        return sess.run(
            inputs=_feed(dataset.test.next_batch(batch_size)),
            outputs=classifier.output['error'],
            name='test_classifier',
        )

    _train(_train_classifier, _test_classifier,
           n_iterations=args.n_iterations, n_epochs=args.n_epochs)
Code example #4
0
File: run_autoencoder.py — Project: lza93/luchador
def _main():
    """Train an MNIST autoencoder, then plot test-set reconstructions."""
    args = _parase_command_line_args()
    _initialize_logger(args.debug)

    batch_size = 32
    fmt = luchador.get_nn_conv_format()
    # Channel axis goes last for NHWC, right after batch for NCHW.
    if fmt == 'NHWC':
        input_shape = [batch_size, 28, 28, 1]
    else:
        input_shape = [batch_size, 1, 28, 28]

    autoencoder = _build_model(args.model, input_shape)
    images = _load_data(args.mnist, fmt)

    sess = nn.Session()
    sess.initialize()

    writer = nn.SummaryWriter(output_dir='tmp')
    if sess.graph:
        writer.add_graph(sess.graph)

    try:
        _train(sess, autoencoder, images['train'], batch_size)
    except KeyboardInterrupt:
        # Interrupted training is fine; still show reconstructions below.
        pass

    orig = images['test'][:batch_size, ...]
    recon = sess.run(
        outputs=autoencoder.output['reconstruction'],
        inputs={autoencoder.input: orig}
    )

    # Drop the channel axis and scale up for uint8 image plotting.
    channel_axis = 3 if fmt == 'NHWC' else 1
    original = 255 * np.squeeze(orig, axis=channel_axis)
    recon = 255 * np.squeeze(recon, axis=channel_axis)

    if not args.no_plot:
        _plot(original.astype('uint8'), recon.astype('uint8'))
Code example #5
0
File: train_gan.py — Project: mot0/luchador
def _main():
    """Train a GAN on flattened MNIST, saving sample grids per epoch."""
    args = _parse_command_line_args()
    initialize_logger(args.debug)

    batch_size = 32
    dataset = load_mnist(args.dataset, flatten=True, mock=args.mock)

    models = _build_models(args.model)
    discriminator = models['discriminator']
    generator = models['generator']

    input_gen = nn.Input(shape=(None, args.n_seeds), name='GeneratorInput')
    data_real = nn.Input(shape=dataset.train.shape, name='InputData')
    data_fake = generator(input_gen)

    # The discriminator scores both generated and real samples.
    logit_fake = discriminator(data_fake)
    logit_real = discriminator(data_real)

    gen_loss, disc_loss = _build_loss(logit_real, logit_fake)
    opt_gen, opt_disc = _build_optimization(
        generator, gen_loss, discriminator, disc_loss)

    sess = nn.Session()
    sess.initialize()

    if args.output:
        summary = nn.SummaryWriter(output_dir=args.output)
        if sess.graph is not None:
            summary.add_graph(sess.graph)

    def _train_disc():
        # One discriminator update on a seed batch plus a real data batch.
        seed = _sample_seed(batch_size, args.n_seeds)
        real = dataset.train.next_batch(batch_size).data
        return sess.run(
            inputs={input_gen: seed, data_real: real},
            outputs=disc_loss,
            updates=opt_disc,
            name='train_discriminator',
        )

    def _train_gen():
        # One generator update driven by a fresh seed batch.
        return sess.run(
            inputs={input_gen: _sample_seed(batch_size, args.n_seeds)},
            outputs=gen_loss,
            updates=opt_gen,
            name='train_generator',
        )

    def _plot_samples(epoch):
        # Generate 16 samples and save them as <output>/<epoch>.png.
        if not args.output:
            return
        samples = sess.run(
            inputs={input_gen: _sample_seed(16, args.n_seeds)},
            outputs=data_fake,
            name='generate_samples',
        )
        path = os.path.join(args.output, '{:03d}.png'.format(epoch))
        plot_images(samples.reshape(-1, 28, 28), path)

    _train(_train_disc, _train_gen, _plot_samples,
           args.n_iterations, args.n_epochs)
Code example #6
0
def _main():
    """Train a GAN on CelebA faces, logging losses and sample images.

    When ``args.output`` is set, scalar losses and generated-image
    summaries are written there each epoch via ``nn.SummaryWriter``.
    """
    args = _parse_command_line_args()
    initialize_logger(args.debug)

    batch_size = 32
    format_ = luchador.get_nn_conv_format()
    dataset = load_celeba_face(args.dataset,
                               data_format=format_,
                               mock=args.mock)

    model = _build_models(args.model)
    discriminator, generator = model['discriminator'], model['generator']

    input_gen = nn.Input(shape=(None, args.n_seeds), name='GeneratorInput')
    data_shape = (None, ) + dataset.train.shape[1:]
    data_real = nn.Input(shape=data_shape, name='InputData')
    _LG.info('Building Generator')
    data_fake = generator(input_gen)

    # The discriminator network is applied to both fake and real data.
    _LG.info('Building fake discriminator')
    logit_fake = discriminator(data_fake)
    _LG.info('Building real discriminator')
    logit_real = discriminator(data_real)

    gen_loss, disc_loss = _build_loss(logit_real, logit_fake)
    opt_gen, opt_disc = _build_optimization(generator, gen_loss, discriminator,
                                            disc_loss)

    sess = nn.Session()
    sess.initialize()

    _summary_writer = None
    if args.output:
        _summary_writer = nn.SummaryWriter(output_dir=args.output)
        if sess.graph is not None:
            _summary_writer.add_graph(sess.graph)

    def _train_disc():
        # One discriminator step. Runs the model's own update operations
        # alongside the optimizer op (unlike the MNIST GAN example).
        return sess.run(
            inputs={
                input_gen: _sample_seed(batch_size, args.n_seeds),
                data_real: dataset.train.next_batch(batch_size).data
            },
            outputs=disc_loss,
            updates=discriminator.get_update_operations() + [opt_disc],
            name='train_discriminator',
        )

    def _train_gen():
        # One generator step, likewise including the model's update ops.
        return sess.run(
            inputs={
                input_gen: _sample_seed(batch_size, args.n_seeds),
            },
            outputs=gen_loss,
            updates=generator.get_update_operations() + [opt_gen],
            name='train_generator',
        )

    # Fixed seed so the sampled images are comparable across epochs.
    random_seed = _sample_seed(batch_size, args.n_seeds)

    def _summarize(epoch, losses=None):
        # Write scalar losses and a grid of generated images for `epoch`.
        if not args.output:
            return

        if losses:
            # NOTE(review): assumes losses arrive ordered (gen, disc);
            # verify against _train's aggregation.
            _summary_writer.summarize(
                summary_type='scalar',
                global_step=epoch,
                dataset={
                    'Generator/Loss': losses[0],
                    'Discriminator/Loss': losses[1],
                },
            )

        images = sess.run(
            inputs={
                input_gen: random_seed,
            },
            outputs=data_fake,
            name='generate_samples',
        )
        if format_ == 'NCHW':
            # Image summaries expect channels-last; transpose from NCHW.
            images = images.transpose(0, 2, 3, 1)
        images = (255 * images).astype(np.uint8)
        _summary_writer.summarize(
            summary_type='image',
            global_step=epoch,
            # Fixed tag typo: was 'Genearated/...'.
            dataset={'Generated/epoch_{:02d}'.format(epoch): images},
            max_outputs=10,
        )

    _train(
        _train_disc,
        _train_gen,
        _summarize,
        args.n_iterations,
        args.n_epochs,
    )