예제 #1
0
    def tag_recurrent_dropout(self,
                              variables,
                              recurrent_dropout,
                              rng=None,
                              **hyperparameters):
        """Tag the recurrent states of every transition with dropout.

        For each transition in ``self.rnn.transitions`` the state outputs
        among the ancestors of ``variables`` are collected, a dropout mask
        is built from the shape of the initial state, and a
        ``DropoutTransform`` using that mask is registered on all the
        other states.

        Args:
            variables: variables whose ancestors are searched for states.
            recurrent_dropout: forwarded to ``util.get_dropout_mask``.
            rng: random number generator forwarded to the mask helper.
            **hyperparameters: accepted and ignored.
        """
        from blocks.roles import OUTPUT, has_roles

        def is_state_output(candidate, transition):
            # An output variable annotated by this transition whose name
            # ends in "states".
            return (has_roles(candidate, [OUTPUT])
                    and transition in candidate.tag.annotations
                    and candidate.name.endswith("states"))

        ancestors = graph.deep_ancestors(variables)
        for lstm in self.rnn.transitions:
            initial_states = []
            later_states = []
            for candidate in ancestors:
                if not is_state_output(candidate, lstm):
                    continue
                if "initial_state" in candidate.name:
                    initial_states.append(candidate)
                else:
                    later_states.append(candidate)

            # A single mask is shared across all time steps.  Its shape is
            # taken from the very first state; using any later state would
            # introduce cycles into the graph.
            # NOTE(review): util.the presumably returns the sole element of
            # the list -- confirm against util.
            first_state = util.the(initial_states)
            shared_mask = util.get_dropout_mask(first_state.shape,
                                                recurrent_dropout,
                                                rng=rng)

            dropout = graph.DropoutTransform("recurrent_dropout",
                                             mask=shared_mask)
            graph.add_transform(later_states, dropout,
                                reason="regularization")
예제 #2
0
    def get_updates(variables):
        """Build moving-average updates for the population statistics.

        The batch statistics are looked up among the ancestors of
        ``variables`` rather than taken from the bricks, so that the ones
        found are those *actually used in the computation* after graph
        transforms have been applied.

        Args:
            variables: variables whose ancestors are searched for batch
                statistics.

        Returns:
            A list of ``(population_stat, new_value)`` pairs where
            ``new_value`` is
            ``(1 - brick.alpha) * population_stat + brick.alpha * batch_stat``.
        """
        from blocks.roles import has_roles

        updates = []
        ancestors = graph.deep_ancestors(variables)
        for stat, role in BatchNormalization.roles.items():
            candidates = [var for var in ancestors if has_roles(var, [role])]
            batch_stats = util.dedup(candidates, equal=util.equal_computations)

            # Group the batch statistics by the brick that produced them,
            # keeping first-seen order.
            batch_stats_by_brick = OrderedDict()
            for batch_stat in batch_stats:
                brick = batch_stat.tag.batch_normalization_brick
                batch_stats_by_brick.setdefault(brick, []).append(batch_stat)

            for brick, brick_batch_stats in batch_stats_by_brick.items():
                population_stat = brick.population_stats[stat]
                if len(brick_batch_stats) > 1:
                    # makes sense for recurrent structures
                    logger.warning(
                        "averaging multiple population statistic estimates to update %s: %s",
                        util.get_path(population_stat), brick_batch_stats)
                mean_batch_stat = T.stack(brick_batch_stats).mean(axis=0)
                updates.append((population_stat,
                                (1 - brick.alpha) * population_stat
                                + brick.alpha * mean_batch_stat))
        return updates
예제 #3
0
 def tag_recurrent_weight_noise(self,
                                variables,
                                rng=None,
                                **hyperparameters):
     """Register a white-noise transform on the marked weight variables.

     Every ancestor of ``variables`` whose name is exactly
     "weight_noise_goes_here" receives a ``WhiteNoiseTransform`` named
     "recurrent_weight_noise".

     Args:
         variables: variables whose ancestors are searched.
         rng: random number generator forwarded to the transform.
         **hyperparameters: accepted and ignored.
     """
     noise_targets = []
     for candidate in graph.deep_ancestors(variables):
         if candidate.name == "weight_noise_goes_here":
             noise_targets.append(candidate)

     weight_noise = graph.WhiteNoiseTransform("recurrent_weight_noise",
                                              rng=rng)
     graph.add_transform(noise_targets, weight_noise,
                         reason="regularization")
예제 #4
0
 def tag_attention_dropout(self, variables, rng=None, **hyperparameters):
     """Register dropout on the inputs of the patch-transform bricks.

     The linear and convolutional bricks found under
     ``self.patch_transform`` are collected; every ancestor of
     ``variables`` that has the INPUT role and is annotated by one of
     those bricks receives a ``DropoutTransform`` named
     "attention_dropout".

     Args:
         variables: variables whose ancestors are searched.
         rng: random number generator forwarded to the transform.
         **hyperparameters: accepted and ignored.
     """
     from blocks.roles import INPUT, has_roles

     def is_target_brick(candidate):
         return isinstance(candidate, (bricks.Linear, conv2d.Convolutional,
                                       conv3d.Convolutional))

     target_bricks = [candidate
                      for candidate in util.all_bricks([self.patch_transform])
                      if is_target_brick(candidate)]

     def is_target_input(var):
         # Check the role first, then look for any of the target bricks
         # among the variable's annotations.
         if not has_roles(var, [INPUT]):
             return False
         for brick in target_bricks:
             if brick in var.tag.annotations:
                 return True
         return False

     dropout_inputs = [var for var in graph.deep_ancestors(variables)
                       if is_target_input(var)]

     dropout = graph.DropoutTransform("attention_dropout", rng=rng)
     graph.add_transform(dropout_inputs, dropout, reason="regularization")
예제 #5
0
파일: main.py 프로젝트: ChunHungLiu/tsa-rnn
def construct_monitors(algorithm, task, model, graphs, outputs, updates,
                       monitor_options, n_spatial_dims, plot_url,
                       hyperparameters, patchmonitor_interval, **kwargs):
    """Assemble the monitoring extensions requested in `monitor_options`.

    Args:
        algorithm: training algorithm exposing `steps`, `total_step_norm`,
            `total_gradient_norm` and `step_rule.components`.
        task: provides `monitor_outputs()`, `plot_channels()` and
            `get_stream(...)`.
        model: provides `get_parameter_dict()`.
        graphs: mapping from data-set name to a graph with `parameters`
            and `outputs`.
        outputs: mapping from data-set name to a mapping of named
            variables.
        updates: mapping from data-set name to an iterable of
            `(population_stat, update)` pairs.
        monitor_options: container of option names.  The options consulted
            are "steps", "parameters", "theta", "activations",
            "batch_normalization" and "patches".
        n_spatial_dims: 2 selects `PatchMonitoring`, 3 selects
            `VideoPatchMonitoring`; any other value disables patch
            monitoring.
        plot_url: plot server URL; a falsy value disables plotting.
        hyperparameters: mapping read for "name", "location_std" and
            "scale_std".
        patchmonitor_interval: passed as `every_n_batches` to the patch
            monitors.
        **kwargs: accepted and ignored.

    Returns:
        The list of constructed extensions, in construction order.
    """
    from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring

    extensions = []

    if "steps" in monitor_options:
        step_channels = []
        step_channels.extend([
            algorithm.steps[param].norm(2).copy(
                name="step_norm:%s" % param_name)
            for param_name, param in model.get_parameter_dict().items()
        ])
        step_channels.append(
            algorithm.total_step_norm.copy(name="total_step_norm"))
        step_channels.append(
            algorithm.total_gradient_norm.copy(name="total_gradient_norm"))

        from extensions import Compressor
        for step_rule in algorithm.step_rule.components:
            if isinstance(step_rule, Compressor):
                step_channels.append(
                    step_rule.norm.copy(name="compressor.norm"))
                step_channels.append(
                    step_rule.newnorm.copy(name="compressor.newnorm"))
                step_channels.append(
                    step_rule.median.copy(name="compressor.median"))
                step_channels.append(
                    step_rule.ratio.copy(name="compressor.ratio"))

        step_channels.extend(
            outputs["train"][key] for key in
            "cost emitter_cost excursion_cost cross_entropy error_rate".split(
            ))

        # One channel per distinct batch-normalization statistic found
        # among the ancestors of the training cost.
        step_channels.extend(
            util.uniqueify_names_last_resort(
                util.dedup((
                    var.mean().copy(name="bn_stat:%s" % util.get_path(var))
                    for var in graph.deep_ancestors([outputs["train"]["cost"]])
                    if hasattr(var.tag, "batch_normalization_brick")),
                           equal=util.equal_computations)))

        logger.warning("constructing training data monitor")
        extensions.append(
            TrainingDataMonitoring(step_channels,
                                   prefix="iteration",
                                   after_batch=True))

    if "parameters" in monitor_options:
        data_independent_channels = []
        for parameter in graphs["train"].parameters:
            if parameter.name in "gamma beta W b".split():
                quantity = parameter.norm(2)
                quantity.name = "parameter.norm:%s" % util.get_path(parameter)
                data_independent_channels.append(quantity)
        for key in "location_std scale_std".split():
            data_independent_channels.append(
                hyperparameters[key].copy(name="parameter:%s" % key))
        extensions.append(
            DataStreamMonitoring(data_independent_channels,
                                 data_stream=None,
                                 after_epoch=True))

    for which_set in "train valid test".split():
        channels = []
        channels.extend(outputs[which_set][key]
                        for key in "cost emitter_cost excursion_cost".split())
        channels.extend(outputs[which_set][key]
                        for key in task.monitor_outputs())
        channels.append(
            outputs[which_set]["savings"].mean().copy(name="mean_savings"))

        if "theta" in monitor_options:
            for key in "true_scale raw_location raw_scale".split():
                for stat in "mean var".split():
                    channels.append(
                        getattr(outputs[which_set][key],
                                stat)(axis=1).copy(name="%s.%s" % (key, stat)))
        if which_set == "train":
            if "activations" in monitor_options:
                from blocks.roles import has_roles, OUTPUT
                cnn_outputs = OrderedDict()
                for var in theano.gof.graph.ancestors(
                        graphs[which_set].outputs):
                    if (has_roles(var, [OUTPUT]) and util.annotated_by_a(
                            util.get_convolution_classes(), var)):
                        cnn_outputs.setdefault(util.get_path(var),
                                               []).append(var)
                for path, path_vars in cnn_outputs.items():
                    distinct_vars = util.dedup(path_vars,
                                               equal=util.equal_computations)
                    for i, var in enumerate(distinct_vars):
                        channels.append(var.mean().copy(
                            name="activation[%i].mean:%s" % (i, path)))

        if "batch_normalization" in monitor_options:
            errors = []
            for population_stat, update in updates[which_set]:
                if population_stat.name.startswith("population"):
                    # this is a super robust way to get the
                    # corresponding batch statistic from the
                    # exponential moving average expression
                    batch_stat = update.owner.inputs[1].owner.inputs[1]
                    errors.append(((population_stat - batch_stat)**2).mean())
            if errors:
                channels.append(
                    T.stack(errors).mean().copy(
                        name="population_statistic_mse"))

        logger.warning("constructing %s monitor", which_set)
        extensions.append(
            DataStreamMonitoring(channels,
                                 prefix=which_set,
                                 after_epoch=True,
                                 data_stream=task.get_stream(which_set,
                                                             monitor=True)))

    if "patches" in monitor_options:
        from patchmonitor import PatchMonitoring, VideoPatchMonitoring

        # Default to None so an unsupported dimensionality skips patch
        # monitoring instead of leaving the name unbound.
        patchmonitor_klass = None
        if n_spatial_dims == 2:
            patchmonitor_klass = PatchMonitoring
        elif n_spatial_dims == 3:
            patchmonitor_klass = VideoPatchMonitoring

        if patchmonitor_klass:
            for which in "train valid".split():
                patch = outputs[which]["patch"]
                # Swap the two leading axes and keep the rest in order.
                patch = patch.dimshuffle(1, 0, *range(2, patch.ndim))
                patch_extractor = theano.function(
                    [outputs[which][key] for key in "x x_shape".split()], [
                        outputs[which][key]
                        for key in "raw_location raw_scale".split()
                    ] + [patch])

                patchmonitor = patchmonitor_klass(
                    save_to="%s_patches_%s" % (hyperparameters["name"], which),
                    data_stream=task.get_stream(which,
                                                shuffle=False,
                                                num_examples=10),
                    every_n_batches=patchmonitor_interval,
                    extractor=patch_extractor,
                    map_to_input_space=attention.static_map_to_input_space)
                patchmonitor.save_patches("patchmonitor_test.png")
                extensions.append(patchmonitor)

    if plot_url:
        plot_channels = []
        plot_channels.extend(task.plot_channels())
        plot_channels.append(["train_cost"])
        #plot_channels.append(["train_%s" % step_channel.name for step_channel in step_channels])

        from blocks.extras.extensions.plot import Plot
        # NOTE(review): the plot title previously referred to a bare `name`
        # that is not bound in this function; the experiment name from
        # `hyperparameters` is used instead, as for the patch monitors --
        # confirm this is the intended title.
        extensions.append(
            Plot(hyperparameters["name"],
                 channels=plot_channels,
                 after_epoch=True,
                 server_url=plot_url))

    return extensions
예제 #6
0
def construct_monitors(algorithm, task, model, graphs, outputs,
                       updates, monitor_options, n_spatial_dims,
                       plot_url, hyperparameters,
                       patchmonitor_interval, **kwargs):
    """Build the monitoring extensions selected by `monitor_options`.

    Args:
        algorithm: training algorithm exposing `steps`, `total_step_norm`,
            `total_gradient_norm` and `step_rule.components`.
        task: provides `monitor_outputs()`, `plot_channels()` and
            `get_stream(...)`.
        model: provides `get_parameter_dict()`.
        graphs: mapping from data-set name to a graph with `parameters`
            and `outputs`.
        outputs: mapping from data-set name to a mapping of named
            variables.
        updates: mapping from data-set name to an iterable of
            `(population_stat, update)` pairs.
        monitor_options: container of option names.  The options consulted
            are "steps", "parameters", "theta", "activations",
            "batch_normalization" and "patches".
        n_spatial_dims: 2 selects `PatchMonitoring`, 3 selects
            `VideoPatchMonitoring`; any other value disables patch
            monitoring.
        plot_url: plot server URL; a falsy value disables plotting.
        hyperparameters: mapping read for "name", "location_std" and
            "scale_std".
        patchmonitor_interval: passed as `every_n_batches` to the patch
            monitors.
        **kwargs: accepted and ignored.

    Returns:
        The list of constructed extensions, in construction order.
    """
    from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring

    extensions = []

    if "steps" in monitor_options:
        step_channels = []
        step_channels.extend([
            algorithm.steps[param].norm(2).copy(name="step_norm:%s" % param_name)
            for param_name, param in model.get_parameter_dict().items()])
        step_channels.append(algorithm.total_step_norm.copy(name="total_step_norm"))
        step_channels.append(algorithm.total_gradient_norm.copy(name="total_gradient_norm"))

        from extensions import Compressor
        for step_rule in algorithm.step_rule.components:
            if isinstance(step_rule, Compressor):
                step_channels.append(step_rule.norm.copy(name="compressor.norm"))
                step_channels.append(step_rule.newnorm.copy(name="compressor.newnorm"))
                step_channels.append(step_rule.median.copy(name="compressor.median"))
                step_channels.append(step_rule.ratio.copy(name="compressor.ratio"))

        step_channels.extend(outputs["train"][key] for key in
                             "cost emitter_cost excursion_cost cross_entropy error_rate".split())

        # One channel per distinct batch-normalization statistic found
        # among the ancestors of the training cost.
        step_channels.extend(util.uniqueify_names_last_resort(util.dedup(
            (var.mean().copy(name="bn_stat:%s" % util.get_path(var))
             for var in graph.deep_ancestors([outputs["train"]["cost"]])
             if hasattr(var.tag, "batch_normalization_brick")),
            equal=util.equal_computations)))

        logger.warning("constructing training data monitor")
        extensions.append(TrainingDataMonitoring(
            step_channels, prefix="iteration", after_batch=True))

    if "parameters" in monitor_options:
        data_independent_channels = []
        for parameter in graphs["train"].parameters:
            if parameter.name in "gamma beta W b".split():
                quantity = parameter.norm(2)
                quantity.name = "parameter.norm:%s" % util.get_path(parameter)
                data_independent_channels.append(quantity)
        for key in "location_std scale_std".split():
            data_independent_channels.append(hyperparameters[key].copy(name="parameter:%s" % key))
        extensions.append(DataStreamMonitoring(
            data_independent_channels, data_stream=None, after_epoch=True))

    for which_set in "train valid test".split():
        channels = []
        channels.extend(outputs[which_set][key] for key in
                        "cost emitter_cost excursion_cost".split())
        channels.extend(outputs[which_set][key] for key in
                        task.monitor_outputs())
        channels.append(outputs[which_set]["savings"]
                        .mean().copy(name="mean_savings"))

        if "theta" in monitor_options:
            for key in "true_scale raw_location raw_scale".split():
                for stat in "mean var".split():
                    channels.append(getattr(outputs[which_set][key], stat)(axis=1)
                                    .copy(name="%s.%s" % (key, stat)))
        if which_set == "train":
            if "activations" in monitor_options:
                from blocks.roles import has_roles, OUTPUT
                cnn_outputs = OrderedDict()
                for var in theano.gof.graph.ancestors(graphs[which_set].outputs):
                    if (has_roles(var, [OUTPUT]) and util.annotated_by_a(
                            util.get_convolution_classes(), var)):
                        cnn_outputs.setdefault(util.get_path(var), []).append(var)
                for path, path_vars in cnn_outputs.items():
                    distinct_vars = util.dedup(path_vars, equal=util.equal_computations)
                    for i, var in enumerate(distinct_vars):
                        channels.append(var.mean().copy(
                            name="activation[%i].mean:%s" % (i, path)))

        if "batch_normalization" in monitor_options:
            errors = []
            for population_stat, update in updates[which_set]:
                if population_stat.name.startswith("population"):
                    # this is a super robust way to get the
                    # corresponding batch statistic from the
                    # exponential moving average expression
                    batch_stat = update.owner.inputs[1].owner.inputs[1]
                    errors.append(((population_stat - batch_stat)**2).mean())
            if errors:
                channels.append(T.stack(errors).mean().copy(name="population_statistic_mse"))

        logger.warning("constructing %s monitor", which_set)
        extensions.append(DataStreamMonitoring(
            channels, prefix=which_set, after_epoch=True,
            data_stream=task.get_stream(which_set, monitor=True)))

    if "patches" in monitor_options:
        from patchmonitor import PatchMonitoring, VideoPatchMonitoring

        # Default to None so an unsupported dimensionality skips patch
        # monitoring instead of leaving the name unbound.
        patchmonitor_klass = None
        if n_spatial_dims == 2:
            patchmonitor_klass = PatchMonitoring
        elif n_spatial_dims == 3:
            patchmonitor_klass = VideoPatchMonitoring

        if patchmonitor_klass:
            for which in "train valid".split():
                patch = outputs[which]["patch"]
                # Swap the two leading axes and keep the rest in order.
                patch = patch.dimshuffle(1, 0, *range(2, patch.ndim))
                patch_extractor = theano.function(
                    [outputs[which][key] for key in "x x_shape".split()],
                    [outputs[which][key] for key in "raw_location raw_scale".split()] + [patch])

                patchmonitor = patchmonitor_klass(
                    save_to="%s_patches_%s" % (hyperparameters["name"], which),
                    data_stream=task.get_stream(which, shuffle=False, num_examples=10),
                    every_n_batches=patchmonitor_interval,
                    extractor=patch_extractor,
                    map_to_input_space=attention.static_map_to_input_space)
                patchmonitor.save_patches("patchmonitor_test.png")
                extensions.append(patchmonitor)

    if plot_url:
        plot_channels = []
        plot_channels.extend(task.plot_channels())
        plot_channels.append(["train_cost"])
        #plot_channels.append(["train_%s" % step_channel.name for step_channel in step_channels])

        from blocks.extras.extensions.plot import Plot
        # NOTE(review): the plot title previously referred to a bare `name`
        # that is not bound in this function; the experiment name from
        # `hyperparameters` is used instead, as for the patch monitors --
        # confirm this is the intended title.
        extensions.append(Plot(hyperparameters["name"], channels=plot_channels,
                               after_epoch=True, server_url=plot_url))

    return extensions