Example 1
    def compute_output(self, network, in_vw):
        deterministic = network.find_hyperparameter(["deterministic"])

        moving_var_type = network.find_hyperparameter(["moving_var_type"],
                                                      DEFAULT_MOVING_VAR_TYPE)
        epsilon = network.find_hyperparameter(["epsilon"], 1e-8)

        if moving_var_type == "log_var":
            moving_var_init_value = 1.0

            def transform_var(v):
                return T.log(v + epsilon)

            def untransform_var(v):
                return T.exp(v)
        elif moving_var_type == "var":
            moving_var_init_value = 0.0

            def transform_var(v):
                return v

            def untransform_var(v):
                return v
        elif moving_var_type == "inv_std":
            moving_var_init_value = 0.0

            def transform_var(v):
                return T.inv(T.sqrt(v) + epsilon)

            def untransform_var(v):
                return T.sqr(T.inv(v))
        else:
            raise ValueError("unknown moving_var_type: %s" % moving_var_type)

        # -----------------------------------------------
        # calculate axes to have parameters/normalization
        # -----------------------------------------------

        # axes over which there are parameters for each element
        # ie. parameter_axes == [1, 2] means shape[1] * shape[2] total
        # parameters - one for each combination of shape[1] and shape[2]
        parameter_axes = treeano.utils.find_axes(
            network,
            in_vw.ndim,
            positive_keys=["parameter_axes"],
            negative_keys=["non_parameter_axes"])
        parameter_broadcastable = tuple(
            [idx not in parameter_axes for idx in range(in_vw.ndim)])
        parameter_shape = tuple([
            1 if b else s for b, s in zip(parameter_broadcastable, in_vw.shape)
        ])
        # axes to normalize over - ie. subtract the mean across these axes
        normalization_axes = treeano.utils.find_axes(
            network,
            in_vw.ndim,
            positive_keys=["normalization_axes"],
            negative_keys=["non_normalization_axes"])
        stats_shape = tuple([
            1 if idx in normalization_axes else s
            for idx, s in enumerate(in_vw.shape)
        ])
        stats_broadcastable = tuple(
            [idx in normalization_axes for idx in range(in_vw.ndim)])
        assert all([s is not None for s in stats_shape])

        # -----------------------
        # initialize shared state
        # -----------------------

        _gamma = network.create_vw(
            name="gamma",
            is_shared=True,
            shape=parameter_shape,
            tags={"parameter", "weight"},
            # TODO try uniform init between -0.05 and 0.05
            default_inits=[],
            default_inits_hyperparameters=["gamma_inits", "inits"],
        )
        _beta = network.create_vw(
            name="beta",
            is_shared=True,
            shape=parameter_shape,
            tags={"parameter", "bias"},
            default_inits=[],
            default_inits_hyperparameters=["beta_inits", "inits"],
        )
        gamma = T.patternbroadcast(_gamma.variable, parameter_broadcastable)
        beta = T.patternbroadcast(_beta.variable, parameter_broadcastable)

        moving_mean = network.create_vw(
            name="mean",
            is_shared=True,
            shape=stats_shape,
            tags={"state"},
            default_inits=[],
        )
        moving_var = network.create_vw(
            name="var",
            is_shared=True,
            shape=stats_shape,
            tags={"state"},
            default_inits=[treeano.inits.ConstantInit(moving_var_init_value)],
        )

        # ------------------------
        # calculate input mean/var
        # ------------------------

        in_mean = T.mean(in_vw.variable,
                         axis=normalization_axes,
                         keepdims=True)
        biased_in_var = T.var(in_vw.variable,
                              axis=normalization_axes,
                              keepdims=True)
        batch_axis = network.find_hyperparameter(["batch_axis"])
        if batch_axis is None:
            in_var = biased_in_var
        else:
            batch_size = in_vw.shape[batch_axis]
            if batch_size is None:
                batch_size = in_vw.variable.shape[batch_axis]
            else:
                batch_size = np.array(batch_size)
            batch_size = batch_size.astype(fX)
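            # Bessel's correction: T.var gives the biased (divide-by-n) variance,
            # so scale by n / (n - 1) to make the estimate unbiased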
            unbias_factor = treeano.utils.as_fX(batch_size / (batch_size - 1))
            in_var = unbias_factor * biased_in_var

        # save the mean/var for updating and debugging
        network.create_vw(
            name="in_mean",
            variable=in_mean,
            tags={},
            shape=stats_shape,
        )
        network.create_vw(
            name="in_var",
            variable=in_var,
            tags={},
            shape=stats_shape,
        )

        # ----------------
        # calculate output
        # ----------------

        bn_use_moving_stats = network.find_hyperparameter(
            ["bn_use_moving_stats"], False)
        if bn_use_moving_stats:
            effective_mean = T.patternbroadcast(moving_mean.variable,
                                                stats_broadcastable)
            effective_var = T.patternbroadcast(
                untransform_var(moving_var.variable), stats_broadcastable)
        else:
            if deterministic:
                msg = ("Batch normalization does not use `deterministic` flag"
                       " to control whether or not moving stats are used for"
                       " computation. In this case `bn_use_moving_stats` is "
                       "False, thus per-minibatch stats will be used and may"
                       "be stochastic (depending on how minibatches are"
                       "created), and not only a function of the input"
                       "observation")
                warnings.warn(msg)
            effective_mean = in_mean
            effective_var = in_var

        if network.find_hyperparameter(["consider_mean_constant"], False):
            effective_mean = T.consider_constant(effective_mean)
        if network.find_hyperparameter(["consider_var_constant"], False):
            effective_var = T.consider_constant(effective_var)

        epsilon = network.find_hyperparameter(["epsilon"], 1e-8)
        denom = T.sqrt(effective_var + epsilon)
        scaled = (in_vw.variable - effective_mean) / denom
        output = (1 + gamma) * scaled + beta
        network.create_vw(
            name="default",
            variable=output,
            shape=in_vw.shape,
            tags={"output"},
        )
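
Taken together, Example 1 computes the standard batch-normalization transform, with the scale applied as (1 + gamma) because gamma is zero-initialized. Below is a minimal NumPy sketch of the same arithmetic for a 2-D input normalized over the batch axis; the names are illustrative only and not part of treeano, and the real node operates on symbolic Theano variables:

    import numpy as np

    def batch_norm_reference(x, gamma, beta, eps=1e-8):
        # per-feature statistics over the batch axis
        mean = x.mean(axis=0, keepdims=True)
        var = x.var(axis=0, keepdims=True)
        # unbias the variance, mirroring the batch_axis branch above
        n = x.shape[0]
        var = var * n / (n - 1)
        scaled = (x - mean) / np.sqrt(var + eps)
        # gamma defaults to zero, so the effective scale is (1 + gamma)
        return (1 + gamma) * scaled + beta

    x = np.random.randn(32, 16).astype("float32")
    out = batch_norm_reference(x, gamma=np.zeros(16), beta=np.zeros(16))
    print(out.mean(axis=0)[:3], out.std(axis=0)[:3])  # roughly zero mean, unit std
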
Example 2
    def compute_output(self, network, in_vw):
        alpha = network.find_hyperparameter(["bttf_alpha", "alpha"], 0.95)
        epsilon = network.find_hyperparameter(["epsilon"], 1e-4)
        normalization_axes = network.find_hyperparameter(["normalization_axes"],
                                                         (1,))
        # HACK: using "deterministic" to mean test time
        deterministic = network.find_hyperparameter(["deterministic"], False)
        update_averages = network.find_hyperparameter(["update_averages"],
                                                      not deterministic)

        alpha = treeano.utils.as_fX(alpha)

        if update_averages:
            backprop_to_the_future_mean = bttf_mean.backprop_to_the_future_mean_with_updates
        else:
            backprop_to_the_future_mean = bttf_mean.backprop_to_the_future_mean_no_updates

        state_shape = tuple([in_vw.shape[axis] for axis in normalization_axes])
        state_pattern = ["x"] * in_vw.ndim
        for idx, axis in enumerate(normalization_axes):
            state_pattern[axis] = idx

        def make_state(name, tags, default_inits=None):
            if default_inits is None:
                default_inits = []
            return network.create_vw(
                name=name,
                is_shared=True,
                shape=state_shape,
                tags=tags,
                default_inits=default_inits,
            ).variable

        gamma = make_state("gamma", {"parameter", "weight"})
        beta = make_state("beta", {"parameter", "bias"})
        # mean of input
        mean = make_state("mean", {"state"})
        # gradient of mean of input
        mean_grad = make_state("mean_grad", {"state"})
        # mean of input^2
        squared_mean = make_state("squared_mean", {"state"},
                                  # initializing to 1, so that std = 1
                                  default_inits=[treeano.inits.ConstantInit(1.)])
        # gradient of mean of input^2
        squared_mean_grad = make_state("squared_mean_grad", {"state"})

        in_var = in_vw.variable
        mean_axes = tuple([axis for axis in range(in_var.ndim)
                           if axis not in normalization_axes])
        batch_mean = in_var.mean(axis=mean_axes)
        squared_batch_mean = T.sqr(in_var).mean(axis=mean_axes)

        # expectation of input (x)
        E_x = backprop_to_the_future_mean(batch_mean,
                                          mean,
                                          mean_grad,
                                          alpha)
        # TODO try mixing batch mean with E_x
        # expectation of input squared
        E_x_squared = backprop_to_the_future_mean(squared_batch_mean,
                                                  squared_mean,
                                                  squared_mean_grad,
                                                  alpha)

        # HACK mixing batch and rolling means
        # E_x = 0.5 * E_x + 0.5 * batch_mean
        # E_x_squared = 0.5 * E_x_squared + 0.5 * squared_batch_mean

        if 1:
            mu = E_x
            sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

            mu = mu.dimshuffle(state_pattern)
            sigma = sigma.dimshuffle(state_pattern)
            gamma = gamma.dimshuffle(state_pattern)
            beta = beta.dimshuffle(state_pattern)

        else:
            # HACK mixing current value
            E_x = E_x.dimshuffle(state_pattern)
            E_x_squared = E_x_squared.dimshuffle(state_pattern)
            gamma = gamma.dimshuffle(state_pattern)
            beta = beta.dimshuffle(state_pattern)

            E_x = 0.1 * in_var + 0.9 * E_x
            E_x_squared = 0.1 * T.sqr(in_var) + 0.9 * E_x_squared

            mu = E_x
            sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

        if 0:
            # HACK don't backprop through sigma
            sigma = T.consider_constant(sigma)

        if 1:
            # HACK using batch mean
            mu = batch_mean
            mu = mu.dimshuffle(state_pattern)

        if 0:
            # HACK using batch variance
            sigma = T.sqrt(in_var.var(axis=mean_axes) + epsilon)
            sigma = sigma.dimshuffle(state_pattern)

        out_var = (in_var - mu) * (T.exp(gamma) / sigma) + beta

        network.create_vw(
            name="default",
            variable=out_var,
            shape=in_vw.shape,
            tags={"output"},
        )

        if 1:
            # HACK monitoring state
            network.create_vw(
                name="mu_mean",
                variable=mu.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="sigma_mean",
                variable=sigma.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="gamma_mean",
                variable=gamma.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="beta_mean",
                variable=beta.mean(),
                shape=(),
                tags={"monitor"},
            )
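
Example 2 keeps running estimates of E[x] and E[x^2] through the bttf_mean helpers (which also backpropagate through the stored statistics) and normalizes with sigma = sqrt(E[x^2] - E[x]^2 + epsilon), scaling by exp(gamma) rather than gamma itself. The sketch below reproduces only the output formula, with plain exponential moving averages standing in for the helpers and the batch mean used as mu to match the `if 1:` hack above; all names are illustrative assumptions, not the actual helpers:

    import numpy as np

    alpha, eps = 0.95, 1e-4
    running_mean = np.zeros(16, dtype="float32")
    running_sq_mean = np.ones(16, dtype="float32")  # init to 1 so sigma starts near 1
    gamma = np.zeros(16, dtype="float32")           # exp(gamma) == 1 initially
    beta = np.zeros(16, dtype="float32")

    x = np.random.randn(32, 16).astype("float32")
    batch_mean = x.mean(axis=0)
    batch_sq_mean = (x ** 2).mean(axis=0)

    # plain moving averages; the real node's bttf_mean also carries gradients
    running_mean = alpha * running_mean + (1 - alpha) * batch_mean
    running_sq_mean = alpha * running_sq_mean + (1 - alpha) * batch_sq_mean

    mu = batch_mean                                 # the `if 1:` hack uses the batch mean
    sigma = np.sqrt(running_sq_mean - running_mean ** 2 + eps)
    out = (x - mu) * (np.exp(gamma) / sigma) + beta
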
Example 3
    def compute_output(self, network, in_vw):
        deterministic = network.find_hyperparameter(["deterministic"])

        use_log_moving_var = network.find_hyperparameter(
            ["use_log_moving_var"], DEFAULT_USE_LOG_MOVING_VAR)

        if use_log_moving_var:
            def transform_var(v):
                epsilon = network.find_hyperparameter(["epsilon"], 1e-8)
                return T.log(v + epsilon)

            def untransform_var(v):
                return T.exp(v)
        else:
            def transform_var(v):
                return v

            def untransform_var(v):
                return v

        # -----------------------------------------------
        # calculate axes to have parameters/normalization
        # -----------------------------------------------

        # axes over which there are parameters for each element
        # ie. parameter_axes == [1, 2] means shape[1] * shape[2] total
        # parameters - one for each combination of shape[1] and shape[2]
        parameter_axes = treeano.utils.find_axes(
            network,
            in_vw.ndim,
            positive_keys=["parameter_axes"],
            negative_keys=["non_parameter_axes"])
        parameter_broadcastable = tuple([idx not in parameter_axes
                                         for idx in range(in_vw.ndim)])
        parameter_shape = tuple([1 if b else s
                                 for b, s in zip(parameter_broadcastable,
                                                 in_vw.shape)])
        # axes to normalize over - ie. subtract the mean across these axes
        normalization_axes = treeano.utils.find_axes(
            network,
            in_vw.ndim,
            positive_keys=["normalization_axes"],
            negative_keys=["non_normalization_axes"])
        stats_shape = tuple([1 if idx in normalization_axes else s
                             for idx, s in enumerate(in_vw.shape)])
        stats_broadcastable = tuple([idx in normalization_axes
                                     for idx in range(in_vw.ndim)])
        assert all([s is not None for s in stats_shape])

        # -----------------------
        # initialize shared state
        # -----------------------

        gamma_inits = list(toolz.concat(network.find_hyperparameters(
            ["gamma_inits",
             "inits"],
            # TODO uniform init between 0.95 and 1.05
            [treeano.inits.ConstantInit(1.0)])))
        beta_inits = list(toolz.concat(network.find_hyperparameters(
            ["beta_inits",
             "inits"],
            [treeano.inits.ConstantInit(0.0)])))
        mean_inits = list(toolz.concat(network.find_hyperparameters(
            ["inits"],
            [])))
        var_inits = list(toolz.concat(network.find_hyperparameters(
            ["inits"],
            [treeano.inits.ConstantInit(1.0 if use_log_moving_var else 0.0)])))

        _gamma = network.create_vw(
            name="gamma",
            is_shared=True,
            shape=parameter_shape,
            tags={"parameter"},
            inits=gamma_inits,
        )
        _beta = network.create_vw(
            name="beta",
            is_shared=True,
            shape=parameter_shape,
            tags={"parameter"},
            inits=beta_inits,
        )
        gamma = T.patternbroadcast(_gamma.variable, parameter_broadcastable)
        beta = T.patternbroadcast(_beta.variable, parameter_broadcastable)

        moving_mean = network.create_vw(
            name="mean",
            is_shared=True,
            shape=stats_shape,
            tags={"state"},
            inits=mean_inits,
        )
        moving_var = network.create_vw(
            name="var",
            is_shared=True,
            shape=stats_shape,
            tags={"state"},
            inits=var_inits,
        )

        # ------------------------
        # calculate input mean/var
        # ------------------------

        in_mean = T.mean(in_vw.variable,
                         axis=normalization_axes,
                         keepdims=True)
        biased_in_var = T.var(in_vw.variable,
                              axis=normalization_axes,
                              keepdims=True)
        batch_axis = network.find_hyperparameter(["batch_axis"])
        if batch_axis is None:
            in_var = biased_in_var
        else:
            batch_size = in_vw.shape[batch_axis]
            if batch_size is None:
                batch_size = in_vw.variable.shape[batch_axis]
            else:
                batch_size = np.array(batch_size)
            batch_size = batch_size.astype(fX)
            unbias_factor = treeano.utils.as_fX(batch_size / (batch_size - 1))
            in_var = unbias_factor * biased_in_var

        assert in_mean.broadcastable == stats_broadcastable
        assert in_var.broadcastable == stats_broadcastable

        # save the mean/var for updating and debugging
        network.create_vw(
            name="in_mean",
            variable=in_mean,
            tags={},
            shape=stats_shape,
        )
        network.create_vw(
            name="in_var",
            variable=in_var,
            tags={},
            shape=stats_shape,
        )

        # ----------------
        # calculate output
        # ----------------

        bn_use_moving_stats = network.find_hyperparameter(
            ["bn_use_moving_stats"], False)
        if bn_use_moving_stats:
            effective_mean = T.patternbroadcast(moving_mean.variable,
                                                stats_broadcastable)
            effective_var = T.patternbroadcast(
                untransform_var(moving_var.variable),
                stats_broadcastable)
        else:
            if deterministic:
                msg = ("Batch normalization does not use `deterministic` flag"
                       " to control whether or not moving stats are used for"
                       " computation. In this case `bn_use_moving_stats` is "
                       "False, thus per-minibatch stats will be used and may"
                       "be stochastic (depending on how minibatches are"
                       "created), and not only a function of the input"
                       "observation")
                warnings.warn(msg)
            effective_mean = in_mean
            effective_var = in_var

        if network.find_hyperparameter(["consider_mean_constant"], False):
            effective_mean = T.consider_constant(effective_mean)
        if network.find_hyperparameter(["consider_var_constant"], False):
            effective_var = T.consider_constant(effective_var)

        epsilon = network.find_hyperparameter(["epsilon"], 1e-8)
        denom = T.sqrt(effective_var + epsilon)
        scaled = (in_vw.variable - effective_mean) / denom
        output = gamma * scaled + beta
        network.create_vw(
            name="default",
            variable=output,
            shape=in_vw.shape,
            tags={"output"},
        )
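
Example 3 is an earlier variant of Example 1: a single use_log_moving_var switch instead of moving_var_type, the older inits= API instead of default_inits, and an output of gamma * scaled + beta with gamma initialized to 1. When the flag is set, the moving variance is stored as log(var + epsilon) and recovered with exp before use; a short NumPy check of that round trip (illustrative names only), showing that it returns var + epsilon rather than var exactly:

    import numpy as np

    eps = 1e-8
    var = np.array([1e-6, 1.0, 1e4])

    stored = np.log(var + eps)     # transform_var
    recovered = np.exp(stored)     # untransform_var

    print(np.allclose(recovered, var + eps))  # True: the round trip returns var + eps
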
Example 4
    def compute_output(self, network, in_vw):
        alpha = network.find_hyperparameter(["bttf_alpha", "alpha"], 0.95)
        epsilon = network.find_hyperparameter(["epsilon"], 1e-4)
        normalization_axes = network.find_hyperparameter(
            ["normalization_axes"], (1, ))
        # HACK: using "deterministic" to mean test time
        deterministic = network.find_hyperparameter(["deterministic"], False)
        update_averages = network.find_hyperparameter(["update_averages"],
                                                      not deterministic)

        alpha = treeano.utils.as_fX(alpha)

        if update_averages:
            backprop_to_the_future_mean = bttf_mean.backprop_to_the_future_mean_with_updates
        else:
            backprop_to_the_future_mean = bttf_mean.backprop_to_the_future_mean_no_updates

        state_shape = tuple([in_vw.shape[axis] for axis in normalization_axes])
        state_pattern = ["x"] * in_vw.ndim
        for idx, axis in enumerate(normalization_axes):
            state_pattern[axis] = idx

        def make_state(name, tags, default_inits=None):
            if default_inits is None:
                default_inits = []
            return network.create_vw(
                name=name,
                is_shared=True,
                shape=state_shape,
                tags=tags,
                default_inits=default_inits,
            ).variable

        gamma = make_state("gamma", {"parameter", "weight"})
        beta = make_state("beta", {"parameter", "bias"})
        # mean of input
        mean = make_state("mean", {"state"})
        # gradient of mean of input
        mean_grad = make_state("mean_grad", {"state"})
        # mean of input^2
        squared_mean = make_state(
            "squared_mean",
            {"state"},
            # initializing to 1, so that std = 1
            default_inits=[treeano.inits.ConstantInit(1.)])
        # gradient of mean of input^2
        squared_mean_grad = make_state("squared_mean_grad", {"state"})

        in_var = in_vw.variable
        mean_axes = tuple([
            axis for axis in range(in_var.ndim)
            if axis not in normalization_axes
        ])
        batch_mean = in_var.mean(axis=mean_axes)
        squared_batch_mean = T.sqr(in_var).mean(axis=mean_axes)

        # expectation of input (x)
        E_x = backprop_to_the_future_mean(batch_mean, mean, mean_grad, alpha)
        # TODO try mixing batch mean with E_x
        # expectation of input squared
        E_x_squared = backprop_to_the_future_mean(squared_batch_mean,
                                                  squared_mean,
                                                  squared_mean_grad, alpha)

        # HACK mixing batch and rolling means
        # E_x = 0.5 * E_x + 0.5 * batch_mean
        # E_x_squared = 0.5 * E_x_squared + 0.5 * squared_batch_mean

        if 1:
            mu = E_x
            sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

            mu = mu.dimshuffle(state_pattern)
            sigma = sigma.dimshuffle(state_pattern)
            gamma = gamma.dimshuffle(state_pattern)
            beta = beta.dimshuffle(state_pattern)

        else:
            # HACK mixing current value
            E_x = E_x.dimshuffle(state_pattern)
            E_x_squared = E_x_squared.dimshuffle(state_pattern)
            gamma = gamma.dimshuffle(state_pattern)
            beta = beta.dimshuffle(state_pattern)

            E_x = 0.1 * in_var + 0.9 * E_x
            E_x_squared = 0.1 * T.sqr(in_var) + 0.9 * E_x_squared

            mu = E_x
            sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

        if 0:
            # HACK don't backprop through sigma
            sigma = T.consider_constant(sigma)

        if 1:
            # HACK using batch mean
            mu = batch_mean
            mu = mu.dimshuffle(state_pattern)

        if 0:
            # HACK using batch variance
            sigma = T.sqrt(in_var.var(axis=mean_axes) + epsilon)
            sigma = sigma.dimshuffle(state_pattern)

        out_var = (in_var - mu) * (T.exp(gamma) / sigma) + beta

        network.create_vw(
            name="default",
            variable=out_var,
            shape=in_vw.shape,
            tags={"output"},
        )

        if 1:
            # HACK monitoring state
            network.create_vw(
                name="mu_mean",
                variable=mu.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="sigma_mean",
                variable=sigma.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="gamma_mean",
                variable=gamma.mean(),
                shape=(),
                tags={"monitor"},
            )
            network.create_vw(
                name="beta_mean",
                variable=beta.mean(),
                shape=(),
                tags={"monitor"},
            )
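
Example 4 is the same node as Example 2, reformatted. One detail worth isolating is the state_pattern built near the top: with a 4-D input and normalization_axes=(1,), it becomes ['x', 0, 'x', 'x'], so dimshuffle turns each per-channel state vector into a tensor that broadcasts against the input. A NumPy analogue of that broadcasting, with illustrative names only:

    import numpy as np

    normalization_axes = (1,)
    ndim = 4                                   # e.g. (batch, channels, height, width)
    state_pattern = ["x"] * ndim
    for idx, axis in enumerate(normalization_axes):
        state_pattern[axis] = idx              # -> ['x', 0, 'x', 'x']

    gamma = np.zeros(16)                       # one value per channel
    # dimshuffle(['x', 0, 'x', 'x']) corresponds to reshaping (16,) -> (1, 16, 1, 1)
    shape = [1 if p == "x" else gamma.shape[p] for p in state_pattern]
    gamma_4d = gamma.reshape(shape)

    x = np.random.randn(8, 16, 5, 5)
    print((x * np.exp(gamma_4d)).shape)        # (8, 16, 5, 5)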