def run(discriminative_regularization=True):
    """Train the CelebA VAE model for 75 epochs.

    The parameters of the ``encoder_convnet``, ``encoder_mlp``,
    ``decoder_convnet`` and ``decoder_mlp`` bricks, together with the
    variance parameters, are trained with Adam on the cost of the
    batch-normalized graph. Batch-normalization population statistics are
    tracked with an exponential moving average (decay 0.05). Train/valid
    quantities are monitored and the main loop is checkpointed every 5
    epochs.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Passed to ``create_training_computation_graphs``; it also selects
        the checkpoint name (``celeba_vae_regularization.zip`` or
        ``celeba_vae_no_regularization.zip``).
    """
    stream_list = create_celeba_streams(training_batch_size=100,
                                        monitoring_batch_size=500,
                                        include_targets=False)
    main_loop_stream, train_monitor_stream, valid_monitor_stream = \
        stream_list[:3]

    graph, bn_graph, variance_parameters = \
        create_training_computation_graphs(discriminative_regularization)

    # Exponential moving average of the batch-normalization population
    # statistics: population <- decay * minibatch + (1 - decay) * population.
    decay_rate = 0.05
    unique_updates = set(
        get_batch_normalization_updates(bn_graph, allow_duplicates=True))
    extra_updates = [
        (population, minibatch * decay_rate + population * (1 - decay_rate))
        for population, minibatch in list(unique_updates)
    ]

    trained_names = ('encoder_convnet', 'encoder_mlp',
                     'decoder_convnet', 'decoder_mlp')
    model = Model(bn_graph.outputs[0])
    trained_bricks = find_bricks(model.top_bricks,
                                 lambda brick: brick.name in trained_names)
    selector = Selector(trained_bricks)
    parameters = list(selector.get_parameters().values()) + variance_parameters

    algorithm = GradientDescent(cost=bn_graph.outputs[0],
                                parameters=parameters,
                                step_rule=Adam())
    algorithm.add_updates(extra_updates)

    def named_quantities(source_graph):
        """Name and return the three monitored quantities of a graph."""
        nll_bound, kl, reconstruction = source_graph.outputs
        nll_bound.name = 'nll_upper_bound'
        mean_kl = kl.mean(axis=0)
        mean_kl.name = 'avg_kl_term'
        mean_reconstruction = -reconstruction.mean(axis=0)
        mean_reconstruction.name = 'avg_reconstruction_term'
        return [nll_bound, mean_kl, mean_reconstruction]

    # Both graphs are named before any monitoring extension is built.
    train_quantities = named_quantities(bn_graph)
    valid_quantities = named_quantities(graph)

    train_monitoring = DataStreamMonitoring(
        train_quantities, train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=False,
        every_n_epochs=5)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=False, every_n_epochs=5)

    suffix = '' if discriminative_regularization else 'no_'
    checkpoint = Checkpoint('celeba_vae_{}regularization.zip'.format(suffix),
                            every_n_epochs=5, use_cpickle=True)

    main_loop = MainLoop(
        data_stream=main_loop_stream,
        algorithm=algorithm,
        extensions=[Timing(), FinishAfter(after_n_epochs=75),
                    train_monitoring, valid_monitoring, checkpoint,
                    Printing(), ProgressBar()])
    main_loop.run()
Exemple #2
0
 def test_find_first_level(self):
     """The MLP itself and every Sequence below it are found exactly."""
     mlp = self.mlp
     found = set(find_bricks([mlp], lambda brick: isinstance(brick, Sequence)))
     assert len(found) == 5
     assert mlp in found
     found.remove(mlp)
     nested = mlp.activations[3]
     expected = set(mlp.activations[0:2] + [nested, nested.children[0]])
     assert expected == found
Exemple #3
0
 def test_find_first_level(self):
     """All Sequence bricks, including the root MLP, are returned."""
     def is_sequence(brick):
         return isinstance(brick, Sequence)

     found = set(find_bricks([self.mlp], is_sequence))
     assert len(found) == 5
     assert self.mlp in found
     found.remove(self.mlp)
     activations = self.mlp.activations
     sequences = set(activations[0:2] +
                     [activations[3], activations[3].children[0]])
     assert sequences == found
Exemple #4
0
def training_noise(*bricks):
    r"""Context manager to run noise layers in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected (via ``find_bricks``)
        for descendant instances of ``NoiseLayer``.

    Notes
    -----
    This is a generator function meant to be used as a ``with`` context
    manager (NOTE(review): the ``contextlib.contextmanager`` decorator is
    not visible in this excerpt -- confirm it is applied).

    Each matching brick's ``__enter__`` is called in discovery order before
    the body runs, and ``__exit__`` is called in reverse order afterwards.
    Only bricks whose ``__enter__`` completed are exited, so a failure while
    entering does not exit bricks that were never entered.
    """
    noise_layers = find_bricks(bricks, lambda b: isinstance(b, NoiseLayer))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the bricks are entered and exited by hand.
    entered = []
    try:
        for brick in noise_layers:
            brick.__enter__()
            # Record the brick only once entering it succeeded.
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
Exemple #5
0
def training_noise(*bricks):
    r"""Context manager to run noise layers in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected (via ``find_bricks``)
        for descendant instances of ``NoiseLayer``.

    Notes
    -----
    This is a generator function meant to be used as a ``with`` context
    manager (NOTE(review): the ``contextlib.contextmanager`` decorator is
    not visible in this excerpt -- confirm it is applied).

    Each matching brick's ``__enter__`` is called in discovery order before
    the body runs, and ``__exit__`` is called in reverse order afterwards.
    Only bricks whose ``__enter__`` completed are exited, so a failure while
    entering does not exit bricks that were never entered.
    """
    noise_layers = find_bricks(bricks, lambda b: isinstance(b, NoiseLayer))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the bricks are entered and exited by hand.
    entered = []
    try:
        for brick in noise_layers:
            brick.__enter__()
            # Record the brick only once entering it succeeded.
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
Exemple #6
0
 def test_find_zeroth_level_repeated(self):
     """Passing the same root brick twice reports it only once."""
     roots = [self.mlp] * 2
     matches = find_bricks(roots, lambda brick: isinstance(brick, MLP))
     assert len(matches) == 1
     assert matches[0] == self.mlp
Exemple #7
0
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Every matching brick is entered (``__enter__``) before the body of the
    ``with`` statement runs and exited (``__exit__``) in reverse order
    afterwards, even if the body raises. If entering one of the bricks
    fails, only the bricks that were already entered are exited.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the bricks are entered and exited by hand.
    entered = []
    try:
        for brick in bn:
            brick.__enter__()
            # Record the brick only once entering it succeeded, so the
            # cleanup never calls __exit__ on a brick that was not entered.
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
Exemple #8
0
def run(discriminative_regularization=True):
    """Train the CelebA VAE with Adam for 75 epochs.

    Only the ``encoder_convnet``, ``encoder_mlp``, ``decoder_convnet`` and
    ``decoder_mlp`` bricks plus the variance parameters are trained, on the
    cost of the batch-normalized graph. Batch-normalization population
    statistics follow an exponential moving average (decay 0.05). Every 5
    epochs the train/valid quantities are monitored and a checkpoint is
    written.

    Parameters
    ----------
    discriminative_regularization : bool, optional
        Passed to ``create_training_computation_graphs``; it also chooses
        between ``celeba_vae_regularization.zip`` and
        ``celeba_vae_no_regularization.zip`` as the checkpoint file.
    """
    all_streams = create_celeba_streams(training_batch_size=100,
                                        monitoring_batch_size=500,
                                        include_targets=False)
    train_stream, train_monitor_stream, valid_monitor_stream = all_streams[:3]

    plain_cg, bn_cg, variance_params = create_training_computation_graphs(
        discriminative_regularization)

    # Population statistics track an exponential moving average of the
    # minibatch statistics.
    decay = 0.05
    extra_updates = []
    for population, minibatch in list(set(
            get_batch_normalization_updates(bn_cg, allow_duplicates=True))):
        extra_updates.append(
            (population, minibatch * decay + population * (1 - decay)))

    model = Model(bn_cg.outputs[0])
    wanted = ('encoder_convnet', 'encoder_mlp', 'decoder_convnet',
              'decoder_mlp')
    selector = Selector(
        find_bricks(model.top_bricks, lambda brick: brick.name in wanted))
    trainable = list(selector.get_parameters().values()) + variance_params

    algorithm = GradientDescent(cost=bn_cg.outputs[0], parameters=trainable,
                                step_rule=Adam())
    algorithm.add_updates(extra_updates)

    monitored = []
    for graph in (bn_cg, plain_cg):
        nll, kl, reconstruction = graph.outputs
        nll.name = 'nll_upper_bound'
        avg_kl = kl.mean(axis=0)
        avg_kl.name = 'avg_kl_term'
        avg_reconstruction = -reconstruction.mean(axis=0)
        avg_reconstruction.name = 'avg_reconstruction_term'
        monitored.append([nll, avg_kl, avg_reconstruction])
    train_quantities, valid_quantities = monitored

    schedule = dict(after_epoch=False, before_first_epoch=False,
                    every_n_epochs=5)
    train_monitoring = DataStreamMonitoring(
        train_quantities, train_monitor_stream, prefix="train",
        updates=extra_updates, **schedule)
    valid_monitoring = DataStreamMonitoring(
        valid_quantities, valid_monitor_stream, prefix="valid", **schedule)

    if discriminative_regularization:
        save_path = 'celeba_vae_{}regularization.zip'.format('')
    else:
        save_path = 'celeba_vae_{}regularization.zip'.format('no_')
    checkpoint = Checkpoint(save_path, every_n_epochs=5, use_cpickle=True)

    extensions = [Timing(), FinishAfter(after_n_epochs=75),
                  train_monitoring, valid_monitoring, checkpoint,
                  Printing(), ProgressBar()]
    MainLoop(data_stream=train_stream, algorithm=algorithm,
             extensions=extensions).run()
Exemple #9
0
 def test_find_first_and_second_and_third_level(self):
     """Logistic bricks are found at the first, second and third level."""
     found = set(find_bricks([self.mlp], lambda x: isinstance(x, Logistic)))
     assert self.mlp.activations[2] in found
     assert self.mlp.activations[1].children[0] in found
     # Check membership: asserting the brick object alone only tested its
     # truthiness and did not check that the brick was found.
     assert self.mlp.activations[3].children[0].children[0] in found
Exemple #10
0
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Every matching brick is entered (``__enter__``) before the body of the
    ``with`` statement runs and exited (``__exit__``) in reverse order
    afterwards, even if the body raises. If entering one of the bricks
    fails, only the bricks that were already entered are exited.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so the bricks are entered and exited by hand.
    entered = []
    try:
        for brick in bn:
            brick.__enter__()
            # Record the brick only once entering it succeeded, so the
            # cleanup never calls __exit__ on a brick that was not entered.
            entered.append(brick)
        yield
    finally:
        for brick in reversed(entered):
            brick.__exit__()
Exemple #11
0
 def test_find_zeroth_level_repeated(self):
     """Duplicate root bricks are de-duplicated in the result."""
     def is_mlp(brick):
         return isinstance(brick, MLP)

     found = find_bricks([self.mlp, self.mlp], is_mlp)
     assert len(found) == 1
     assert found[0] == self.mlp
Exemple #12
0
 def test_find_second_and_third_level(self):
     """Identity bricks nested below the MLP's activations are found."""
     def is_identity(brick):
         return isinstance(brick, Identity)

     found = set(find_bricks([self.mlp], is_identity))
     assert len(found) == 2
     first, second = self.mlp.activations[0], self.mlp.activations[1]
     assert first.children[0] in found
     assert second.children[1] in found
Exemple #13
0
def run(batch_size, save_path, z_dim, oldmodel, discriminative_regularization,
        classifier, vintage, monitor_every, monitor_before, checkpoint_every,
        dataset, color_convert, image_size, net_depth, subdir,
        reconstruction_factor, kl_factor, discriminative_factor, disc_weights,
        num_epochs):
    """Train a VAE, optionally with discriminative regularization.

    The parameters of the ``encoder_convnet``, ``encoder_mlp``,
    ``decoder_convnet`` and ``decoder_mlp`` bricks plus the variance
    parameters are trained with Adam on the cost of the batch-normalized
    graph. Batch-normalization population statistics follow an exponential
    moving average (decay 0.05).

    Parameters
    ----------
    batch_size : int
        Batch size of both the training and the monitoring streams.
    save_path : str
        Destination of the main-loop checkpoint.
    z_dim : int
        Forwarded to ``create_training_computation_graphs``; half of it is
        given to ``SampleCheckpoint``.
    oldmodel : str or None
        If not None, a saved main loop whose model parameter values
        initialize this model before training starts.
    monitor_every : int
        Train/valid monitoring frequency, in epochs.
    monitor_before : bool
        Whether to also monitor before the first epoch.
    checkpoint_every : int
        Checkpoint frequency in epochs; a checkpoint is also written
        before training.
    dataset
        Custom dataset file name; when falsy, the CelebA streams are used.
    color_convert
        Forwarded to ``create_custom_streams``.
    subdir
        Passed as ``save_subdir`` to ``SampleCheckpoint``.
    num_epochs : int
        Number of training epochs.

    The remaining arguments (``discriminative_regularization``,
    ``classifier``, ``vintage``, ``image_size``, ``net_depth``,
    ``reconstruction_factor``, ``kl_factor``, ``discriminative_factor``,
    ``disc_weights``) are forwarded to
    ``create_training_computation_graphs``; ``image_size`` is also used for
    ``SampleCheckpoint``.
    """
    if dataset:
        streams = create_custom_streams(filename=dataset,
                                        training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False,
                                        color_convert=color_convert)
    else:
        streams = create_celeba_streams(training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False)

    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    rval = create_training_computation_graphs(z_dim, image_size, net_depth,
                                              discriminative_regularization,
                                              classifier, vintage,
                                              reconstruction_factor, kl_factor,
                                              discriminative_factor,
                                              disc_weights)
    cg, bn_cg, variance_parameters = rval

    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = [(p, m * decay_rate + p * (1 - decay_rate))
                     for p, m in pop_updates]

    model = Model(bn_cg.outputs[0])

    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in ('encoder_convnet', 'encoder_mlp',
                                         'decoder_convnet', 'decoder_mlp')))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # NOTE(review): presumably raised so that pickling the deep Theano
    # graphs (e.g. by the cPickle-based Checkpoint) does not hit the
    # default recursion limit -- confirm.
    sys.setrecursionlimit(1000000)

    # Prepare monitoring. The first four outputs of each graph are the
    # cost and its KL, reconstruction and discriminative terms; any further
    # outputs are per-layer discriminative terms.
    monitored_quantities_list = []
    for graph in [bn_cg, cg]:
        cost, kl_term, reconstruction_term, discriminative_term = \
            graph.outputs[:4]
        discriminative_layer_terms = graph.outputs[4:]

        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        avg_discriminative_term = discriminative_term.mean(axis=0)
        avg_discriminative_term.name = 'avg_discriminative_term'

        avg_discriminative_layer_terms = []
        for i, term in enumerate(discriminative_layer_terms):
            avg_term = term.mean(axis=0)
            avg_term.name = "avg_discriminative_term_layer_{:02d}".format(i)
            avg_discriminative_layer_terms.append(avg_term)

        monitored_quantities_list.append([
            cost, avg_kl_term, avg_reconstruction_term, avg_discriminative_term
        ] + avg_discriminative_layer_terms)

    train_monitoring = DataStreamMonitoring(monitored_quantities_list[0],
                                            train_monitor_stream,
                                            prefix="train",
                                            updates=extra_updates,
                                            after_epoch=False,
                                            before_first_epoch=monitor_before,
                                            every_n_epochs=monitor_every)
    valid_monitoring = DataStreamMonitoring(monitored_quantities_list[1],
                                            valid_monitor_stream,
                                            prefix="valid",
                                            after_epoch=False,
                                            before_first_epoch=monitor_before,
                                            every_n_epochs=monitor_every)

    # Prepare checkpoint
    checkpoint = Checkpoint(save_path,
                            every_n_epochs=checkpoint_every,
                            before_training=True,
                            use_cpickle=True)

    # TODO: why does z_dim=foo become foo/2?
    # Floor division keeps the value an int on Python 3 as well (on
    # Python 2 `/` on ints already floored).
    sample_checkpoint = SampleCheckpoint(interface=DiscGenModel,
                                         z_dim=z_dim // 2,
                                         image_size=(image_size, image_size),
                                         channels=3,
                                         dataset=dataset,
                                         split="valid",
                                         save_subdir=subdir,
                                         before_training=True,
                                         after_epoch=True)
    extensions = [
        Timing(),
        FinishAfter(after_n_epochs=num_epochs), checkpoint, sample_checkpoint,
        train_monitoring, valid_monitoring,
        Printing(),
        ProgressBar()
    ]
    main_loop = MainLoop(model=model,
                         data_stream=main_loop_stream,
                         algorithm=algorithm,
                         extensions=extensions)

    if oldmodel is not None:
        print("Initializing parameters with old model {}".format(oldmodel))
        try:
            saved_model = load(oldmodel)
        except AttributeError:
            # newer version of blocks
            with open(oldmodel, 'rb') as src:
                saved_model = load(src)
        main_loop.model.set_parameter_values(
            saved_model.model.get_parameter_values())
        del saved_model

    main_loop.run()
Exemple #14
0
 def test_find_none(self):
     """A predicate that rejects every brick finds nothing."""
     def reject_all(_):
         return False

     matches = find_bricks([self.mlp], reject_all)
     assert len(matches) == 0
Exemple #15
0
 def test_find_all_unique(self):
     """Each brick is reported once even when reachable from several roots."""
     roots = [self.mlp, self.mlp]
     roots.extend(self.mlp.children)
     found = find_bricks(roots, lambda brick: True)
     # 12 activations plus 4 linear transformations
     assert len(found) == 16
def train_extractive_qa(new_training_job, config, save_path, params,
                        fast_start, fuel_server, seed):
    """Build and run the main loop for extractive question-answering training.

    Parameters
    ----------
    new_training_job : bool
        When false, the training state (iteration state and log) is loaded
        from ``<save_path>/training_state.tar`` before training.
    config : dict-like
        Training configuration. Keys read here include ``dict_path``,
        ``def_word_gating``, ``embedding_path``, ``train_only_def_part``,
        ``dropout``, ``dropout_type``, ``emb_dropout``, ``emb_dropout_type``,
        ``grad_clip_threshold``, ``learning_rate``, ``momentum``,
        ``monitor_parameters``, ``batch_size``, ``batch_size_valid``,
        ``max_length``, ``mon_freq_train``, ``save_freq_epochs``,
        ``save_freq_batches``, ``n_batches``, ``n_epochs``,
        ``annealing_learning_rate`` and ``annealing_start_epoch``.
    save_path : str
        Directory for checkpoints, predictions, summaries and the pickled
        Fuel stream.
    params : str or None
        Optional path of a parameter file used to initialize the model.
    fast_start : bool
        When true, skip the before-training validation, prediction dump,
        retrieval statistics and checkpoint, and the after-training
        checkpoint.
    fuel_server : bool
        When true, training data is read through a ``ServerDataStream``
        fed by the ``StartFuelServer`` extension.
    seed : int or None
        When truthy, overrides the Fuel and Blocks default random seeds.

    Raises
    ------
    ValueError
        If dropout is enabled and ``dropout_type`` or ``emb_dropout_type``
        is neither ``'same_mask'`` nor ``'regular'``.
    """
    import functools

    if seed:
        logger.debug("Changing Fuel random seed to {}".format(seed))
        fuel.config.default_seed = seed
        logger.debug("Changing Blocks random seed to {}".format(seed))
        blocks.config.config.default_seed = seed

    root_path = os.path.join(save_path, 'training_state')
    extension = '.tar'
    tar_path = root_path + extension
    best_tar_path = root_path + '_best' + extension

    c = config
    data, qam = initialize_data_and_model(c)

    if theano.config.compute_test_value != 'off':
        test_value_data = next(
            data.get_stream('train', shuffle=True, batch_size=4,
                            max_length=5).get_epoch_iterator(as_dict=True))
        for var in qam.input_vars.values():
            var.tag.test_value = test_value_data[var.name]

    costs = qam.apply_with_default_vars()
    cost = rename(costs.mean(), 'mean_cost')

    cg = Model(cost)
    if params:
        logger.debug("Load parameters from {}".format(params))
        # Open in binary mode: load_parameters presumably reads a binary
        # (tar) archive, which text mode would break on Python 3.
        with open(params, 'rb') as src:
            cg.set_parameter_values(load_parameters(src))

    length = rename(qam.contexts.shape[1], 'length')
    batch_size = rename(qam.contexts.shape[0], 'batch_size')
    predicted_begins, = VariableFilter(name='predicted_begins')(cg)
    predicted_ends, = VariableFilter(name='predicted_ends')(cg)
    exact_match, = VariableFilter(name='exact_match')(cg)
    exact_match_ratio = rename(exact_match.mean(), 'exact_match_ratio')
    context_unk_ratio, = VariableFilter(name='context_unk_ratio')(cg)
    monitored_vars = [
        length, batch_size, cost, exact_match_ratio, context_unk_ratio
    ]
    if c['dict_path']:
        def_unk_ratio, = VariableFilter(name='def_unk_ratio')(cg)
        num_definitions = rename(qam.input_vars['defs'].shape[0],
                                 'num_definitions')
        max_definition_length = rename(qam.input_vars['defs'].shape[1],
                                       'max_definition_length')
        monitored_vars.extend(
            [def_unk_ratio, num_definitions, max_definition_length])
        if c['def_word_gating'] == 'self_attention':
            def_gates = VariableFilter(name='def_gates')(cg)
            # tensor.minimum / tensor.maximum are binary element-wise ops,
            # so fold them over however many gate variables were found.
            def_gates_min = functools.reduce(
                tensor.minimum, [x.min() for x in def_gates])
            def_gates_max = functools.reduce(
                tensor.maximum, [x.max() for x in def_gates])
            monitored_vars.extend([
                rename(def_gates_min, 'def_gates_min'),
                rename(def_gates_max, 'def_gates_max')
            ])
    text_match_ratio = TextMatchRatio(
        data_path=os.path.join(fuel.config.data_path[0],
                               'squad/dev-v1.1.json'),
        requires=[
            predicted_begins, predicted_ends,
            tensor.ltensor3('contexts_text'),
            tensor.lmatrix('q_ids')
        ],
        name='text_match_ratio')

    parameters = cg.get_parameter_dict()
    # A real list (not a dict view on Python 3), since it is handed to the
    # algorithm and the checkpoint below.
    trained_parameters = list(parameters.values())
    if c['embedding_path']:
        logger.debug("Exclude  word embeddings from the trained parameters")
        embeddings = qam.embeddings_var()
        trained_parameters = [
            p for p in trained_parameters if not p == embeddings
        ]
    if c['train_only_def_part']:
        def_reading_parameters = qam.def_reading_parameters()
        trained_parameters = [
            p for p in trained_parameters if p in def_reading_parameters
        ]

    bricks = find_bricks([qam], lambda brick: isinstance(brick, Initializable))
    init_schemes = {}
    for brick in bricks:
        brick_name = "/".join([b.name for b in brick.get_unique_path()])
        init_schemes[brick_name] = {}
        for arg, value in brick.__dict__.items():
            if arg.endswith('_init'):
                init_schemes[brick_name][arg] = repr(value)
    logger.info("Initialization schemes:")
    logger.info(json.dumps(init_schemes, indent=2))

    parameter_summaries = []
    for key in sorted(parameters.keys()):
        shape = parameters[key].get_value().shape
        status = 'trained' if parameters[key] in trained_parameters else 'frozen'
        dim_sum = sum(shape)
        # sqrt(6 / sum(shape)), presumably a Glorot-style init bound shown
        # for reference. Float division so Python 2 does not truncate it
        # to 0; scalar parameters (empty shape) have no such bound.
        scale = str((6.0 / dim_sum) ** 0.5) if dim_sum else 'n/a'
        parameter_summaries.append(" ".join((key, str(shape), status, scale)))
    logger.info("Cost parameters" + "\n" +
                pprint.pformat(parameter_summaries, width=120))

    # apply dropout to the training cost and to all the variables
    # that we monitor during training
    train_cost = cost
    train_monitored_vars = list(monitored_vars)
    if c['dropout']:
        regularized_cg = ComputationGraph([cost] + train_monitored_vars)
        # TODO: don't access private attributes
        bidir_outputs, = VariableFilter(bricks=[qam._bidir],
                                        roles=[OUTPUT])(cg)
        readout_layers = VariableFilter(bricks=[Rectifier], roles=[OUTPUT])(cg)
        dropout_vars = [bidir_outputs] + readout_layers
        logger.debug("applying dropout to {}".format(", ".join(
            [v.name for v in dropout_vars])))
        if c['dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg, dropout_vars,
                                            c['dropout'])
        elif c['dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, dropout_vars,
                                           c['dropout'])
        else:
            raise ValueError(
                "unknown dropout_type: {!r}".format(c['dropout_type']))
        # a new dropout with exactly same mask at different steps
        emb_vars = VariableFilter(roles=[EMBEDDINGS])(regularized_cg)
        # Built unconditionally (as before); only 'same_mask' uses it.
        emb_dropout_mask = get_dropout_mask(emb_vars[0], c['emb_dropout'])
        if c['emb_dropout_type'] == 'same_mask':
            regularized_cg = apply_dropout2(regularized_cg,
                                            emb_vars,
                                            c['emb_dropout'],
                                            dropout_mask=emb_dropout_mask)
        elif c['emb_dropout_type'] == 'regular':
            regularized_cg = apply_dropout(regularized_cg, emb_vars,
                                           c['emb_dropout'])
        else:
            raise ValueError(
                "unknown emb_dropout_type: {!r}".format(c['emb_dropout_type']))
        train_cost = regularized_cg.outputs[0]
        train_monitored_vars = regularized_cg.outputs[1:]

    rules = []
    if c['grad_clip_threshold']:
        rules.append(StepClipping(c['grad_clip_threshold']))
    rules.append(Adam(learning_rate=c['learning_rate'], beta1=c['momentum']))
    algorithm = GradientDescent(cost=train_cost,
                                parameters=trained_parameters,
                                step_rule=CompositeRule(rules))

    if c['grad_clip_threshold']:
        train_monitored_vars.append(algorithm.total_gradient_norm)
    if c['monitor_parameters']:
        train_monitored_vars.extend(parameter_stats(parameters, algorithm))

    training_stream = data.get_stream('train',
                                      batch_size=c['batch_size'],
                                      shuffle=True,
                                      max_length=c['max_length'])
    original_training_stream = training_stream
    if fuel_server:
        # the port will be configured by the StartFuelServer extension
        training_stream = ServerDataStream(
            sources=training_stream.sources,
            produces_examples=training_stream.produces_examples)

    extensions = [
        LoadNoUnpickling(tar_path, load_iteration_state=True,
                         load_log=True).set_conditions(
                             before_training=not new_training_job),
        StartFuelServer(original_training_stream,
                        os.path.join(save_path, 'stream.pkl'),
                        before_training=fuel_server),
        Timing(every_n_batches=c['mon_freq_train']),
        TrainingDataMonitoring(train_monitored_vars,
                               prefix="train",
                               every_n_batches=c['mon_freq_train']),
    ]
    validation = DataStreamMonitoring(
        [text_match_ratio] + monitored_vars,
        data.get_stream('dev',
                        batch_size=c['batch_size_valid'],
                        raw_text=True,
                        q_ids=True),
        prefix="dev").set_conditions(before_training=not fast_start,
                                     after_epoch=True)
    dump_predictions = DumpPredictions(save_path,
                                       text_match_ratio,
                                       before_training=not fast_start,
                                       after_epoch=True)
    track_the_best_exact = TrackTheBest(
        validation.record_name(exact_match_ratio),
        choose_best=max).set_conditions(before_training=True, after_epoch=True)
    track_the_best_text = TrackTheBest(
        validation.record_name(text_match_ratio),
        choose_best=max).set_conditions(before_training=True, after_epoch=True)
    extensions.extend([
        validation, dump_predictions, track_the_best_exact, track_the_best_text
    ])
    # We often use pretrained word embeddings and we don't want
    # to load and save them every time. To avoid that, we use
    # save_main_loop=False, we only save the trained parameters,
    # and we save the log and the iterations state separately
    # in the tar file.
    extensions.extend([
        Checkpoint(tar_path,
                   parameters=trained_parameters,
                   save_main_loop=False,
                   save_separately=['log', 'iteration_state'],
                   before_training=not fast_start,
                   every_n_epochs=c['save_freq_epochs'],
                   every_n_batches=c['save_freq_batches'],
                   after_training=not fast_start).add_condition(
                       ['after_batch', 'after_epoch'],
                       OnLogRecord(track_the_best_text.notification_name),
                       (best_tar_path, )),
        DumpTensorflowSummaries(save_path,
                                after_epoch=True,
                                every_n_batches=c['mon_freq_train'],
                                after_training=True),
        RetrievalPrintStats(retrieval=data._retrieval,
                            every_n_batches=c['mon_freq_train'],
                            before_training=not fast_start),
        Printing(after_epoch=True, every_n_batches=c['mon_freq_train']),
        FinishAfter(after_n_batches=c['n_batches'],
                    after_n_epochs=c['n_epochs']),
        Annealing(c['annealing_learning_rate'],
                  after_n_epochs=c['annealing_start_epoch']),
        LoadNoUnpickling(best_tar_path,
                         after_n_epochs=c['annealing_start_epoch'])
    ])

    main_loop = MainLoop(algorithm,
                         training_stream,
                         model=Model(cost),
                         extensions=extensions)
    main_loop.run()
Exemple #17
0
 def test_find_first_and_second_and_third_level(self):
     """Logistic bricks are found at the first, second and third level."""
     found = set(find_bricks([self.mlp], lambda x: isinstance(x, Logistic)))
     assert self.mlp.activations[2] in found
     assert self.mlp.activations[1].children[0] in found
     # Check membership: asserting the brick object alone only tested its
     # truthiness and did not check that the brick was found.
     assert self.mlp.activations[3].children[0].children[0] in found
Exemple #18
0
 def test_find_second_and_third_level(self):
     """Exactly the two nested Identity bricks are found."""
     activations = self.mlp.activations
     found = set(find_bricks([self.mlp],
                             lambda brick: isinstance(brick, Identity)))
     assert len(found) == 2
     assert activations[0].children[0] in found
     assert activations[1].children[1] in found
Exemple #19
0
 def test_find_none(self):
     """No brick matches a predicate that always returns False."""
     result = find_bricks([self.mlp], lambda brick: False)
     assert len(result) == 0
Exemple #20
0
def run(batch_size, save_path, z_dim, oldmodel, discriminative_regularization,
        classifier, vintage, monitor_every, monitor_before, checkpoint_every, dataset, color_convert,
        image_size, net_depth, subdir,
        reconstruction_factor, kl_factor, discriminative_factor, disc_weights,
        num_epochs):
    """Build the training graphs, monitoring and checkpointing, then train.

    Data comes from ``create_custom_streams`` when ``dataset`` is truthy,
    otherwise from ``create_celeba_streams``. The batch-normalized graph is
    optimized with Adam over the parameters of the encoder/decoder bricks plus
    ``variance_parameters``. Batch-norm population statistics are tracked with
    an exponential moving average (decay 0.05).

    Args:
        batch_size: Batch size for both the training and monitoring streams.
        save_path: Destination of the ``Checkpoint`` extension.
        z_dim: Passed to ``create_training_computation_graphs``; half of it
            (integer division) is passed to ``SampleCheckpoint``.
        oldmodel: Optional path of a saved main loop whose model parameter
            values initialize this model; ``None`` to start fresh.
        monitor_every: Epoch interval for train/valid monitoring.
        monitor_before: Whether to monitor before the first epoch.
        checkpoint_every: Epoch interval for checkpointing.
        num_epochs: Training stops after this many epochs.
        The remaining arguments are forwarded unchanged to the stream and
        graph construction helpers or to ``SampleCheckpoint``.
    """
    if dataset:
        streams = create_custom_streams(filename=dataset,
                                        training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False,
                                        color_convert=color_convert)
    else:
        streams = create_celeba_streams(training_batch_size=batch_size,
                                        monitoring_batch_size=batch_size,
                                        include_targets=False)

    main_loop_stream, train_monitor_stream, valid_monitor_stream = streams[:3]

    # Compute parameter updates for the batch normalization population
    # statistics. They are updated following an exponential moving average.
    cg, bn_cg, variance_parameters = create_training_computation_graphs(
        z_dim, image_size, net_depth, discriminative_regularization, classifier,
        vintage, reconstruction_factor, kl_factor, discriminative_factor,
        disc_weights)

    pop_updates = list(
        set(get_batch_normalization_updates(bn_cg, allow_duplicates=True)))
    decay_rate = 0.05
    extra_updates = [(p, m * decay_rate + p * (1 - decay_rate))
                     for p, m in pop_updates]

    model = Model(bn_cg.outputs[0])

    # Only the encoder/decoder bricks (plus the variance parameters) are
    # trained; any other bricks in the graph keep their current values.
    selector = Selector(
        find_bricks(
            model.top_bricks,
            lambda brick: brick.name in ('encoder_convnet', 'encoder_mlp',
                                         'decoder_convnet', 'decoder_mlp')))
    parameters = list(selector.get_parameters().values()) + variance_parameters

    # Prepare algorithm
    step_rule = Adam()
    algorithm = GradientDescent(cost=bn_cg.outputs[0],
                                parameters=parameters,
                                step_rule=step_rule)
    algorithm.add_updates(extra_updates)

    # Prepare monitoring. NOTE(review): the very high recursion limit is
    # presumably needed to pickle the deep graph when checkpointing — confirm.
    sys.setrecursionlimit(1000000)

    monitored_quantities_list = []
    for graph in [bn_cg, cg]:
        # The first four outputs are the aggregate terms; any further outputs
        # are per-layer discriminative terms.
        cost, kl_term, reconstruction_term, discriminative_term = graph.outputs[:4]
        discriminative_layer_terms = graph.outputs[4:]

        cost.name = 'nll_upper_bound'
        avg_kl_term = kl_term.mean(axis=0)
        avg_kl_term.name = 'avg_kl_term'
        avg_reconstruction_term = -reconstruction_term.mean(axis=0)
        avg_reconstruction_term.name = 'avg_reconstruction_term'
        avg_discriminative_term = discriminative_term.mean(axis=0)
        avg_discriminative_term.name = 'avg_discriminative_term'

        avg_discriminative_layer_terms = []
        for i, term in enumerate(discriminative_layer_terms):
            avg_term = term.mean(axis=0)
            avg_term.name = "avg_discriminative_term_layer_{:02d}".format(i)
            avg_discriminative_layer_terms.append(avg_term)

        monitored_quantities_list.append(
            [cost, avg_kl_term, avg_reconstruction_term,
             avg_discriminative_term] + avg_discriminative_layer_terms)

    # Training-set monitoring uses the batch-normalized graph and also applies
    # the population-statistic updates; validation uses the plain graph.
    train_monitoring = DataStreamMonitoring(
        monitored_quantities_list[0], train_monitor_stream, prefix="train",
        updates=extra_updates, after_epoch=False, before_first_epoch=monitor_before,
        every_n_epochs=monitor_every)
    valid_monitoring = DataStreamMonitoring(
        monitored_quantities_list[1], valid_monitor_stream, prefix="valid",
        after_epoch=False, before_first_epoch=monitor_before,
        every_n_epochs=monitor_every)

    # Prepare checkpoint
    checkpoint = Checkpoint(save_path, every_n_epochs=checkpoint_every,
                            before_training=True, use_cpickle=True)

    # Integer division keeps the size an int on Python 3 (plain ``/`` would
    # produce a float); it is identical to Python 2 ``/`` for ints.
    # TODO: why is z_dim halved here? Presumably z_dim counts both the mean
    # and log-sigma encoder outputs — confirm against the model definition.
    sample_checkpoint = SampleCheckpoint(interface=DiscGenModel, z_dim=z_dim // 2,
                            image_size=(image_size, image_size), channels=3,
                            dataset=dataset, split="valid", save_subdir=subdir,
                            before_training=True, after_epoch=True)
    extensions = [Timing(),
                  FinishAfter(after_n_epochs=num_epochs),
                  checkpoint,
                  sample_checkpoint,
                  train_monitoring, valid_monitoring,
                  Printing(),
                  ProgressBar()]
    main_loop = MainLoop(model=model, data_stream=main_loop_stream,
                         algorithm=algorithm, extensions=extensions)

    if oldmodel is not None:
        print("Initializing parameters with old model {}".format(oldmodel))
        try:
            saved_model = load(oldmodel)
        except AttributeError:
            # newer version of blocks: ``load`` expects an open file object
            with open(oldmodel, 'rb') as src:
                saved_model = load(src)
        main_loop.model.set_parameter_values(
            saved_model.model.get_parameter_values())
        # Drop the loaded main loop before training starts.
        del saved_model

    main_loop.run()
Exemple #21
0
 def test_find_all_unique(self):
     """Bricks reachable from several (even duplicated) roots appear once."""
     roots = [self.mlp, self.mlp] + list(self.mlp.children)
     found = find_bricks(roots, lambda _: True)
     # 12 activations plus 4 linear transformations
     assert len(found) == 16