Example #1
def create_main_loop(save_path):
    # create_models() returns the plain graph, its batch-normalized
    # counterpart, and the corresponding batch-norm statistics updates.
    model, bn_model, bn_updates = create_models()
    ali, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs

    # The discriminator and the generator share the same Adam step rule.
    step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters,
                              step_rule, generator_loss,
                              ali.generator_parameters, step_rule)
    # Apply the batch-norm statistics updates alongside the training updates.
    algorithm.add_updates(bn_updates)
    streams = create_tiny_imagenet_data_streams(BATCH_SIZE,
                                                MONITORING_BATCH_SIZE)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams
    # Monitor all auxiliary variables whose name does not contain 'norm',
    # plus the model outputs (the discriminator and generator losses).
    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=NUM_EPOCHS),
        DataStreamMonitoring(
            bn_monitored_variables, train_monitor_stream, prefix="train",
            updates=bn_updates),
        DataStreamMonitoring(
            monitored_variables, valid_monitor_stream, prefix="valid"),
        Checkpoint(save_path, after_epoch=True, after_training=True,
                   use_cpickle=True),
        ProgressBar(),
        Printing(),
    ]
    main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    return main_loop
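These snippets omit the module-level setup they depend on. The sketch below is a plausible preamble for Example #1: the Blocks import paths are standard, but the ali.* module paths, the constant values, and the existence of a local create_models() helper are assumptions about how the surrounding experiment script is organized.

# Assumed preamble for Example #1 (not part of the original listing).
from blocks.algorithms import Adam
from blocks.extensions import FinishAfter, Printing, ProgressBar, Timing
from blocks.extensions.monitoring import DataStreamMonitoring
from blocks.extensions.saveload import Checkpoint
from blocks.main_loop import MainLoop

# Project-local imports; these module paths are assumptions.
from ali.algorithms import ali_algorithm
from ali.streams import create_tiny_imagenet_data_streams

# Placeholder hyperparameters; the real values live in the experiment script.
LEARNING_RATE = 1e-4
BETA1 = 0.5
BATCH_SIZE = 128
MONITORING_BATCH_SIZE = 500
NUM_EPOCHS = 100

# create_models() is defined elsewhere in the same script and returns
# (model, bn_model, bn_updates).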
Example #2
def create_main_loop(save_path):
    model, bn_model, bn_updates = create_models()
    ali, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs

    step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters,
                              step_rule, generator_loss,
                              ali.generator_parameters, step_rule)
    algorithm.add_updates(bn_updates)
    streams = create_celeba_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams
    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=NUM_EPOCHS),
        DataStreamMonitoring(
            bn_monitored_variables, train_monitor_stream, prefix="train",
            updates=bn_updates),
        DataStreamMonitoring(
            monitored_variables, valid_monitor_stream, prefix="valid"),
        Checkpoint(save_path, after_epoch=True, after_training=True,
                   use_cpickle=True),
        ProgressBar(),
        Printing(),
    ]
    main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    return main_loop
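With a preamble like the one above in place, driving Example #2 amounts to building the loop and calling run(). A minimal, hypothetical driver follows; the checkpoint path is a placeholder.

# Hypothetical driver for Example #2.
if __name__ == "__main__":
    main_loop = create_main_loop(save_path="ali_celeba_main_loop.tar")
    main_loop.run()  # trains for NUM_EPOCHS, checkpointing after every epoch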
Example #3
        # Nested inside a class method; self._work_dir below is captured
        # from the enclosing scope.
        def create_main_loop():
            model, bn_model, bn_updates = create_models()
            ali, = bn_model.top_bricks
            discriminator_loss, generator_loss = bn_model.outputs
            step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
            algorithm = ali_algorithm(discriminator_loss,
                                      ali.discriminator_parameters, step_rule,
                                      generator_loss, ali.generator_parameters,
                                      step_rule)
            algorithm.add_updates(bn_updates)
            streams = create_gaussian_mixture_data_streams(
                batch_size=BATCH_SIZE,
                monitoring_batch_size=MONITORING_BATCH_SIZE,
                means=MEANS,
                variances=VARIANCES,
                priors=PRIORS)
            main_loop_stream, train_monitor_stream, valid_monitor_stream = streams
            bn_monitored_variables = (
                [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
                bn_model.outputs)
            monitored_variables = (
                [v for v in model.auxiliary_variables if 'norm' not in v.name] +
                model.outputs)
            extensions = [
                Timing(),
                FinishAfter(after_n_epochs=NUM_EPOCHS),
                DataStreamMonitoring(bn_monitored_variables,
                                     train_monitor_stream,
                                     prefix="train",
                                     updates=bn_updates),
                DataStreamMonitoring(monitored_variables,
                                     valid_monitor_stream,
                                     prefix="valid"),
                Checkpoint(os.path.join(self._work_dir, "main_loop.tar"),
                           after_epoch=True,
                           after_training=True,
                           use_cpickle=True),
                ProgressBar(),
                Printing(),

                #ModelLogger(folder=self._work_dir, after_epoch=True),
                GraphLogger(num_modes=1,
                            num_samples=2500,
                            dimension=2,
                            r=0,
                            std=1,
                            folder=self._work_dir,
                            after_epoch=True,
                            after_training=True),
                MetricLogger(means=MEANS,
                             variances=VARIANCES,
                             folder=self._work_dir,
                             after_epoch=True)
            ]
            main_loop = MainLoop(model=bn_model,
                                 data_stream=main_loop_stream,
                                 algorithm=algorithm,
                                 extensions=extensions)
            return main_loop
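Example #3 passes module-level MEANS, VARIANCES and PRIORS constants to create_gaussian_mixture_data_streams, but their definitions are not shown. The sketch below is a hypothetical two-dimensional, five-mode configuration of the kind such a mixture stream usually expects (lists of component means, covariance matrices and mixing weights); the actual values come from the original module.

# Placeholder mixture parameters (hypothetical values, not from the source).
import numpy as np

MEANS = [np.array([x, y], dtype='float32')
         for x, y in [(-1, -1), (-1, 1), (1, -1), (1, 1), (0, 0)]]
VARIANCES = [(0.05 ** 2) * np.eye(2, dtype='float32') for _ in MEANS]
PRIORS = [1.0 / len(MEANS)] * len(MEANS)   # uniform mixing weights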
Example #4
def create_main_loop(save_path):

    model, bn_model, bn_updates = create_models()
    ali, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs

    step_rule = Adam(learning_rate=LEARNING_RATE, beta1=BETA1)
    algorithm = ali_algorithm(discriminator_loss, ali.discriminator_parameters,
                              step_rule, generator_loss,
                              ali.generator_parameters, step_rule)
    algorithm.add_updates(bn_updates)
    streams = create_cifar10_data_streams(BATCH_SIZE, MONITORING_BATCH_SIZE)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams

    # Peek at the first CIFAR-10 batch (this stream is replaced just below).
    for d in main_loop_stream.get_epoch_iterator(as_dict=True):
        print(list(d.keys()))
        print(d['features'].shape, d['features'].dtype)
        break

    # Swap the training stream for a synthetic ShapesDataset stream.
    main_loop_stream = ShapesDataset(
        num_examples=600, img_size=32, min_diameter=3,
        seed=1234).create_stream(batch_size=BATCH_SIZE, is_train=True)

    # Peek at the first ShapesDataset batch as well.
    for d in main_loop_stream.get_epoch_iterator(as_dict=True):
        print(list(d.keys()))
        print(d['features'].shape, d['features'].dtype)
        break

    train_monitor_stream = ShapesDataset(
        num_examples=100, img_size=32, min_diameter=3,
        seed=1234).create_stream(batch_size=BATCH_SIZE, is_train=False)
    valid_monitor_stream = ShapesDataset(
        num_examples=100, img_size=32, min_diameter=3,
        seed=5678).create_stream(batch_size=BATCH_SIZE, is_train=False)
    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=NUM_EPOCHS),
        DataStreamMonitoring(
            bn_monitored_variables, train_monitor_stream, prefix="train",
            updates=bn_updates),
        DataStreamMonitoring(
            monitored_variables, valid_monitor_stream, prefix="valid"),
        Checkpoint(save_path, after_epoch=True, after_training=True,
                   use_cpickle=True),
        ProgressBar(),
        Printing(),
    ]
    main_loop = MainLoop(model=bn_model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)
    return main_loop
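The two inspection loops in Example #4 print the source names, shapes and dtypes of a stream's first batch. A small helper, not in the original code, that generalizes the same pattern to any Fuel-style data stream:

# Hypothetical debugging helper generalizing the inspection loops above.
def peek_first_batch(stream):
    """Print source names, shapes and dtypes of the stream's first batch."""
    for batch in stream.get_epoch_iterator(as_dict=True):
        for name, value in batch.items():
            print(name, getattr(value, 'shape', None),
                  getattr(value, 'dtype', None))
        return batch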
Example #5
def create_main_loop(self):
    model, bn_model, bn_updates = self.create_models()
    gan, = bn_model.top_bricks
    discriminator_loss, generator_loss = bn_model.outputs
    step_rule = Adam(learning_rate=self._config["learning_rate"],
                     beta1=self._config["beta1"])
    algorithm = ali_algorithm(discriminator_loss,
                              gan.discriminator_parameters, step_rule,
                              generator_loss, gan.generator_parameters,
                              step_rule)
    algorithm.add_updates(bn_updates)
    streams = create_packing_gaussian_mixture_data_streams(
        num_packings=self._config["num_packing"],
        batch_size=self._config["batch_size"],
        monitoring_batch_size=self._config["monitoring_batch_size"],
        means=self._config["x_mode_means"],
        variances=self._config["x_mode_variances"],
        priors=self._config["x_mode_priors"],
        num_examples=self._config["num_sample"])
    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams
    bn_monitored_variables = (
        [v for v in bn_model.auxiliary_variables if 'norm' not in v.name] +
        bn_model.outputs)
    monitored_variables = (
        [v for v in model.auxiliary_variables if 'norm' not in v.name] +
        model.outputs)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=self._config["num_epoch"]),
        DataStreamMonitoring(bn_monitored_variables,
                             train_monitor_stream,
                             prefix="train",
                             updates=bn_updates),
        DataStreamMonitoring(monitored_variables,
                             valid_monitor_stream,
                             prefix="valid"),
        Checkpoint(os.path.join(self._work_dir,
                                self._config["main_loop_file"]),
                   after_epoch=True,
                   after_training=True,
                   use_cpickle=True),
        ProgressBar(),
        Printing(),
    ]
    if self._config["log_models"]:
        extensions.append(
            ModelLogger(folder=self._work_dir, after_epoch=True))
    if self._config["log_figures"]:
        extensions.append(
            GraphLogger(num_modes=self._config["num_zmode"],
                        num_samples=self._config["num_log_figure_sample"],
                        dimension=self._config["num_zdim"],
                        r=self._config["z_mode_r"],
                        std=self._config["z_mode_std"],
                        folder=self._work_dir,
                        after_epoch=True,
                        after_training=True))
    if self._config["log_metrics"]:
        extensions.append(
            MetricLogger(means=self._config["x_mode_means"],
                         variances=self._config["x_mode_variances"],
                         folder=self._work_dir,
                         after_epoch=True))
    main_loop = MainLoop(model=bn_model,
                         data_stream=main_loop_stream,
                         algorithm=algorithm,
                         extensions=extensions)
    return main_loop
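Example #5 reads every setting from self._config. The dictionary below lists all the keys the method touches, with placeholder values for orientation only; the real settings come from the experiment's configuration, and the MEANS/VARIANCES/PRIORS placeholders from the Example #3 sketch are reused here.

# Every configuration key read by Example #5, with placeholder values.
config = {
    "learning_rate": 1e-4,
    "beta1": 0.5,
    "num_packing": 1,
    "batch_size": 100,
    "monitoring_batch_size": 500,
    "x_mode_means": MEANS,          # as in the Example #3 sketch
    "x_mode_variances": VARIANCES,  # as in the Example #3 sketch
    "x_mode_priors": PRIORS,        # as in the Example #3 sketch
    "num_sample": 100000,
    "num_epoch": 400,
    "main_loop_file": "main_loop.tar",
    "num_zmode": 1,
    "num_log_figure_sample": 2500,
    "num_zdim": 2,
    "z_mode_r": 0,
    "z_mode_std": 1,
    "log_models": False,
    "log_figures": True,
    "log_metrics": True,
}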