コード例 #1
0
ファイル: pacgan_task.py プロジェクト: ml-lab/PacGAN
        def create_main_loop():
            """Assemble and return the ``MainLoop`` for this training run.

            The function takes no arguments; it reads ``self._work_dir`` from
            the enclosing scope (not visible in this block) and uses it as the
            output folder for the checkpoint and the logger extensions.

            Steps performed, in order:

            1. ``create_models()`` supplies ``model``, ``bn_model`` and
               ``bn_updates``.
            2. The single top brick of ``bn_model`` provides the discriminator
               and generator parameter lists handed to ``ali_algorithm``; one
               ``Adam`` step rule instance is passed for both.
            3. ``bn_updates`` are registered on the algorithm.
            4. Three Gaussian-mixture data streams are created (main loop,
               train monitoring, validation monitoring).
            5. The extension list is built and the ``MainLoop`` is returned.

            Returns:
                The configured ``MainLoop`` (built with ``model=bn_model``).
            """

            def _non_norm_variables(graph):
                # Auxiliary variables whose name contains 'norm' are left out
                # of monitoring; the graph outputs are always appended.
                kept = [var for var in graph.auxiliary_variables
                        if 'norm' not in var.name]
                return kept + graph.outputs

            model, bn_model, bn_updates = create_models()

            # Exactly one top brick is expected; the unpacking fails otherwise.
            (gan_brick,) = bn_model.top_bricks
            discriminator_loss, generator_loss = bn_model.outputs

            # The same step rule object is used for both parameter groups.
            shared_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
            algorithm = ali_algorithm(
                discriminator_loss, gan_brick.discriminator_parameters,
                shared_rule,
                generator_loss, gan_brick.generator_parameters,
                shared_rule)
            algorithm.add_updates(bn_updates)

            train_stream, train_monitor, valid_monitor = (
                create_gaussian_mixture_data_streams(
                    batch_size=BATCH_SIZE,
                    monitoring_batch_size=MONITORING_BATCH_SIZE,
                    means=MEANS,
                    variances=VARIANCES,
                    priors=PRIORS))

            train_watch = _non_norm_variables(bn_model)
            valid_watch = _non_norm_variables(model)

            # Extensions are instantiated one by one, in the order in which
            # they end up in the list.
            extensions = [Timing(), FinishAfter(after_n_epochs=NUM_EPOCHS)]
            extensions.append(
                DataStreamMonitoring(train_watch, train_monitor,
                                     prefix="train", updates=bn_updates))
            extensions.append(
                DataStreamMonitoring(valid_watch, valid_monitor,
                                     prefix="valid"))
            extensions.append(
                Checkpoint(os.path.join(self._work_dir, "main_loop.tar"),
                           after_epoch=True, after_training=True,
                           use_cpickle=True))
            extensions.append(ProgressBar())
            extensions.append(Printing())
            # Disabled in the original listing:
            # ModelLogger(folder=self._work_dir, after_epoch=True)
            extensions.append(
                GraphLogger(num_modes=1, num_samples=2500, dimension=2,
                            r=0, std=1, folder=self._work_dir,
                            after_epoch=True, after_training=True))
            extensions.append(
                MetricLogger(means=MEANS, variances=VARIANCES,
                             folder=self._work_dir, after_epoch=True))

            return MainLoop(model=bn_model,
                            data_stream=train_stream,
                            algorithm=algorithm,
                            extensions=extensions)
コード例 #2
0
def create_main_loop(save_path):
    """Build the training ``MainLoop`` and return it.

    Args:
        save_path: Passed as the first argument to the ``Checkpoint``
            extension, which is configured with ``after_epoch=True``,
            ``after_training=True`` and ``use_cpickle=True``.

    Returns:
        A ``MainLoop`` constructed with ``model=bn_model``, the main-loop
        data stream, the algorithm returned by ``ali_algorithm`` (with
        ``bn_updates`` added), and the extension list assembled below.
    """

    def _drop_norms(variables):
        # Keep only the variables whose name does not contain 'norm'.
        return [var for var in variables if 'norm' not in var.name]

    model, bn_model, bn_updates = create_models()

    # ``top_bricks`` must hold exactly one brick, otherwise unpacking fails.
    [gan] = bn_model.top_bricks
    d_loss, g_loss = bn_model.outputs

    # One Adam instance is handed over for both parameter groups.
    adam = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(
        d_loss, gan.discriminator_parameters, adam,
        g_loss, gan.generator_parameters, adam)
    algorithm.add_updates(bn_updates)

    stream_triple = create_gaussian_mixture_data_streams(
        batch_size=BATCH_SIZE,
        monitoring_batch_size=MONITORING_BATCH_SIZE,
        means=MEANS,
        variances=VARIANCES,
        priors=PRIORS)
    loop_stream, train_monitor, valid_monitor = stream_triple

    train_vars = _drop_norms(bn_model.auxiliary_variables) + bn_model.outputs
    valid_vars = _drop_norms(model.auxiliary_variables) + model.outputs

    # Instantiated in the same order in which they are listed.
    timing = Timing()
    stopper = FinishAfter(after_n_epochs=NUM_EPOCHS)
    train_monitoring = DataStreamMonitoring(
        train_vars, train_monitor, prefix="train", updates=bn_updates)
    valid_monitoring = DataStreamMonitoring(
        valid_vars, valid_monitor, prefix="valid")
    checkpoint = Checkpoint(
        save_path, after_epoch=True, after_training=True, use_cpickle=True)
    extensions = [
        timing,
        stopper,
        train_monitoring,
        valid_monitoring,
        checkpoint,
        ProgressBar(),
        Printing(),
    ]

    return MainLoop(
        model=bn_model,
        data_stream=loop_stream,
        algorithm=algorithm,
        extensions=extensions)
コード例 #3
0
ファイル: mixture_viz.py プロジェクト: zueigung1419/ALI
def get_data(main_loop, n_points=1000):
    """Fetch one batch of points and labels from a fresh mixture stream.

    The mixture parameters (``means``, ``variances``, ``priors``) are read
    from ``main_loop.data_stream.dataset`` and passed to
    ``create_gaussian_mixture_data_streams`` together with ``n_points`` as
    its first two positional arguments and ``sources=('features', 'label')``.
    Only the third of the three returned streams is used.

    Args:
        main_loop: Object exposing ``data_stream.dataset`` with ``means``,
            ``variances`` and ``priors`` attributes.
        n_points: Value given for both leading positional arguments of the
            stream factory (default 1000).

    Returns:
        A dict with keys ``'originals'`` and ``'labels'`` holding the two
        elements of the first batch of that stream's epoch iterator.
    """
    dataset = main_loop.data_stream.dataset
    mixture_kwargs = {
        'means': dataset.means,
        'variances': dataset.variances,
        'priors': dataset.priors,
    }

    # The first two streams returned by the factory are not needed here.
    _, _, sample_stream = create_gaussian_mixture_data_streams(
        n_points, n_points, sources=('features', 'label'), **mixture_kwargs)

    first_batch = next(sample_stream.get_epoch_iterator())
    points, point_labels = first_batch
    return dict(originals=points, labels=point_labels)