Example #1
# Imports assumed for this standalone snippet: tx is the tensorx library,
# GRUCell is the Keras cell the tensorx implementation is compared against.
import tensorflow as tf
import tensorx as tx
from tensorflow.keras.layers import GRUCell


def test_gru_cell():
    n_inputs = 3
    n_units = 4
    batch_size = 1
    inputs = tx.Input(n_units=n_inputs)

    gru0 = tx.GRUCell(inputs,
                      n_units,
                      activation=tf.tanh,
                      gate_activation=tf.sigmoid)

    # reset_after=True would apply the reset gate after the matrix
    # multiplication and use recurrent biases, which matches the cuDNN
    # implementation; the original GRU convention (reset_after=False)
    # is used here
    gru1 = GRUCell(n_units,
                   activation='tanh',
                   recurrent_activation='sigmoid',
                   reset_after=False,
                   implementation=1,
                   use_bias=True)

    # the Keras cell builds its weights lazily, so the fused kernel does not
    # exist until the cell has been called at least once
    assert not hasattr(gru1, "kernel")

    state0 = [s() for s in gru0.previous_state]
    # get_initial_state from Keras returns either a tuple or a single state
    # (see test_rnn_cell), but the __call__ API requires an iterable
    state1 = gru1.get_initial_state(inputs, batch_size=1)

    assert tx.tensor_equal(state1, state0[0])

    inputs.value = tf.ones([batch_size, n_inputs])

    res1 = gru1(inputs, state0)
    res1_ = gru1(inputs, state0)

    for r1, r2 in zip(res1, res1_):
        assert tx.tensor_equal(r1, r2)

    # the only difference is that the Keras kernels are fused into single
    # tensors, one block per gate concatenated along the last axis
    kernel = tf.concat([w.weights.value() for w in gru0.layer_state.w],
                       axis=-1)
    recurrent_kernel = tf.concat([u.weights for u in gru0.layer_state.u],
                                 axis=-1)
    bias = tf.concat([w.bias for w in gru0.layer_state.w], axis=-1)

    assert tx.same_shape(kernel, gru1.kernel)
    assert tx.same_shape(recurrent_kernel, gru1.recurrent_kernel)
    assert tx.same_shape(bias, gru1.bias)

    gru1.kernel = kernel
    gru1.recurrent_kernel = recurrent_kernel
    gru1.bias = bias

    res2 = gru1(inputs, state0)
    for i in range(len(res1)):
        assert not tx.tensor_equal(res1[i], res2[i])
    res0 = gru0()
    # res0_ = gru0.state[0]()
    assert tx.tensor_equal(res0, res2[0])
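
The comparison above relies on the fused layout of the Keras weights: the cell keeps a single kernel of shape [n_inputs, 3 * n_units] (with an analogous recurrent_kernel and bias), one block per gate concatenated along the last axis. As a minimal standalone sketch (not part of the test above; the names w_z, w_r, w_h are illustrative), the fused tensors of a built Keras cell can be split back into per-gate blocks:

import tensorflow as tf
from tensorflow.keras.layers import GRUCell

cell = GRUCell(4, reset_after=False)
# the first call builds the fused weights
cell(tf.ones([1, 3]), [tf.zeros([1, 4])])

# gate order follows the Keras convention: update (z), reset (r), candidate (h)
w_z, w_r, w_h = tf.split(cell.kernel, 3, axis=-1)            # each [3, 4]
u_z, u_r, u_h = tf.split(cell.recurrent_kernel, 3, axis=-1)  # each [4, 4]
b_z, b_r, b_h = tf.split(cell.bias, 3, axis=-1)              # each [4]
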
Example #2
import tensorflow as tf
from tensorflow.keras.layers import Dense, GRUCell

# Assumed for this standalone snippet: logit is taken to be
# scipy.special.logit; BaseHierarchicalPolicy and rms_momentum are defined
# elsewhere in the surrounding project.
from scipy.special import logit


class ScaleHierarchicalOptimizer(BaseHierarchicalPolicy):
    """Hierarchical optimizer.

    Described in
    "Learned Optimizers that Scale and Generalize" (Wichrowska et. al, 2017)

    Keyword Args
    ------------
    param_units : int
        Number of hidden units for parameter RNN.
    tensor_units : int
        Number of hidden units for tensor RNN.
    global_units : int
        Number of hidden units for global RNN.
    init_lr : float[2]
        Learning rate initialization range. Actual learning rate values are
        IID exp(unif(log(init_lr))).
    timescales : int
        Number of timescales to compute momentum for.
    epsilon : float
        Denominator epsilon for the normalization operation, in case the
        input is 0.
    momentum_decay_bias_init : float
        Constant initializer for EMA momentum decay rate logit beta_g. Should
        correspond to beta_1 in an Adam teacher.
    variance_decay_bias_init : float
        Constant initializer for EMA variance decay rate logit beta_lambda.
        Should correspond to beta_2 in an Adam teacher.
    use_gradient_shortcut : bool
        Use shortcut connection adding linear transformation of momentum at
        various timescales to direction output?
    name : str
        Name of the optimizer network.
    **kwargs : dict
        Passed on to tf.keras.layers.GRUCell.
    """

    default_name = "ScaleHierarchicalOptimizer"

    def init_layers(self,
                    param_units=10,
                    tensor_units=5,
                    global_units=5,
                    init_lr=(1e-6, 1e-2),
                    timescales=1,
                    epsilon=1e-10,
                    momentum_decay_bias_init=logit(0.9),
                    variance_decay_bias_init=logit(0.999),
                    use_gradient_shortcut=True,
                    **kwargs):
        """Initialize layers."""
        assert (init_lr[0] > 0 and init_lr[1] > 0 and epsilon > 0)
        self.timescales = timescales
        self.init_lr = init_lr
        self.epsilon = epsilon

        # Parameter, Tensor, & Global RNNs (may have different size)
        self.param_rnn = GRUCell(param_units, name="param_rnn", **kwargs)
        self.tensor_rnn = GRUCell(tensor_units, name="tensor_rnn", **kwargs)
        self.global_rnn = GRUCell(global_units, name="global_rnn", **kwargs)

        # Parameter change
        self.d_theta = Dense(1,
                             input_shape=(param_units, ),
                             name="d_theta",
                             kernel_initializer="zeros")
        # Learning rate change
        self.delta_nu = Dense(1,
                              input_shape=(param_units, ),
                              name="delta_nu",
                              kernel_initializer="zeros")
        # Momentum decay rate
        self.beta_g = Dense(1,
                            input_shape=(param_units, ),
                            kernel_initializer="zeros",
                            bias_initializer=tf.constant_initializer(
                                value=momentum_decay_bias_init),
                            activation="sigmoid",
                            name="beta_g")
        # Variance/scale decay rate
        self.beta_lambda = Dense(1,
                                 input_shape=(param_units, ),
                                 kernel_initializer="zeros",
                                 bias_initializer=tf.constant_initializer(
                                     value=variance_decay_bias_init),
                                 activation="sigmoid",
                                 name="beta_lambda")
        # Momentum shortcut
        if use_gradient_shortcut:
            self.gradient_shortcut = Dense(1,
                                           input_shape=(timescales, ),
                                           name="gradient_shortcut",
                                           kernel_initializer="zeros")
        else:
            self.gradient_shortcut = None

        # Gamma parameter
        # Stored as a logit - the actual gamma used will be sigmoid(gamma)
        self.gamma = tf.Variable(tf.zeros(()), trainable=True, name="gamma")

    def call_global(self, states, global_state, training=False):
        """Equation 12.

        Global RNN. Inputs are prepared (except for final mean) in ``call``.
        """
        # [1, units] -> [num tensors, 1, units] -> [1, units]
        inputs = tf.reduce_mean(
            tf.stack([state["tensor"] for state in states]), 0)
        global_state_new, _ = self.global_rnn(inputs, global_state)
        return global_state_new

    def _new_momentum_variance(self, grads, states, states_new):
        """Equation 1, 2, 3, 13.

        Helper function for scaled momentum update
        """
        # Base decay
        # Eq 13
        # [var size, 1] -> [*var shape]
        shape = tf.shape(grads)
        beta_g = tf.reshape(self.beta_g(states["param"]), shape)
        beta_lambda = tf.reshape(self.beta_lambda(states["param"]), shape)

        # New momentum, variance
        # Eq 1, 2
        states_new["scaling"] = [
            rms_momentum(grads,
                         g_bar,
                         lambda_,
                         beta_1=beta_g**(0.5**s),
                         beta_2=beta_lambda**(0.5**s))
            for s, (g_bar, lambda_) in enumerate(states["scaling"])
        ]

        # Scaled momentum
        _m = [
            g_bar / tf.sqrt(lambda_ + self.epsilon)
            for g_bar, lambda_ in states_new["scaling"]
        ]

        # m_t: [timescales, *var shape] -> [var size, timescales]
        return tf.transpose(tf.reshape(tf.stack(_m), [self.timescales, -1]))

    def _relative_log_gradient_magnitude(self, states, states_new):
        """Equation 4.

        Helper function for relative log gradient magnitudes
        """
        log_lambdas = tf.math.log(
            tf.stack([lambda_ for g_bar, lambda_ in states_new["scaling"]]) +
            self.epsilon)
        _gamma = log_lambdas - tf.reduce_mean(log_lambdas, axis=0)

        # gamma_t: [timescales, *var shape] -> [var size, timescales]
        return tf.transpose(tf.reshape(_gamma, [self.timescales, -1]))

    def _parameterized_change(self, param, states, states_new, m):
        """Equation 5, 7, 8.

        Helper function for the parameter change, explicitly parameterized
        into a direction and a learning rate.

        Notes
        -----
        (1) Direction is no longer explicitly parameterized, as specified by
            appendix D.3 in Wichrowska et al.
        (2) A shortcut connection is included as per appendix B.1.
        """
        # New learning rate
        # Eq 7, 8
        d_eta = tf.reshape(self.delta_nu(states_new["param"]), tf.shape(param))
        eta = d_eta + states["eta_bar"]
        sg = tf.nn.sigmoid(self.gamma)
        states_new["eta_bar"] = (sg * states["eta_bar"] + (1 - sg) * eta)

        # Relative log learning rate
        # Eq Unnamed, end of sec 3.2.4
        states_new["eta_rel"] = tf.reshape(eta - tf.math.reduce_mean(eta),
                                           [-1, 1])

        # Direction
        # Eq 5, using the update given in Appendix D.3
        d_theta = self.d_theta(states_new["param"])

        if self.gradient_shortcut:
            d_theta += self.gradient_shortcut(m)

        return tf.exp(eta) * tf.reshape(d_theta, tf.shape(param))

    def call(self, param, grads, states, global_state, training=False):
        """Optimizer Update.

        Notes
        -----
        The state indices in Wichrowska et al. are incorrect, and should be:
        (1) g_bar^n, lambda^n = EMA(g_bar^n-1, g^n), EMA(lambda^n-1, g^n)
            instead of EMA(..., g^n-1), etc
        (2) h^n = RNN(x^n, h^n-1) instead of h^n+1 = RNN(x^n, h^n)
        Then, the g^n -> g_bar^n, lambda^n -> m^n -> h^n -> d^n data flow
        occurs within the same step instead of across 2 steps. This fix is
        reflected in the original Scale code.

        In order to reduce state size, the state update computation is split:
        (1) Compute beta_g, beta_lambda, m.
        (2) Update Parameter & Tensor RNN.
        (3) Compute eta, d. This step only depends on the parameter RNN,
            so the Global RNN being updated after this does not matter.
        (4) Update Global RNN.
        eta_rel is the only "transient" product (i.e. not an RNN hidden state,
        momentum, variance, or learning rate) stored in the optimizer state.
        """
        states_new = {}

        # Prerequisites ("Momentum and variance at various timescales")
        # Eq 1, 2, 3, 13
        m = self._new_momentum_variance(grads, states, states_new)

        # Eq 4
        gamma = self._relative_log_gradient_magnitude(states, states_new)

        # Param RNN
        # inputs = [var size, features]
        param_in = tf.concat(
            [
                # x^n:
                m,
                gamma,
                states["eta_rel"],
                # h_tensor: [1, hidden size] -> [var size, hidden size]
                tf.tile(states["tensor"], [tf.size(param), 1]),
                # h_global: [1, hidden size] -> [var size, hidden size]
                tf.tile(global_state, [tf.size(param), 1]),
            ],
            1)

        # RNN Update
        # Eq 10
        states_new["param"], _ = self.param_rnn(param_in, states["param"])
        # Eq 11
        tensor_in = tf.concat([
            tf.math.reduce_mean(states_new["param"], 0, keepdims=True),
            global_state
        ], 1)
        states_new["tensor"], _ = self.tensor_rnn(tensor_in, states["tensor"])

        # Eq 5, 7, 8
        delta_theta = self._parameterized_change(param, states, states_new, m)

        return delta_theta, states_new

    def get_initial_state(self, var):
        """Get initial model state as a dictionary."""
        batch_size = tf.size(var)

        return {
            "scaling": [(tf.zeros(tf.shape(var)), tf.zeros(tf.shape(var)))
                        for s in range(self.timescales)],
            "param":
            self.param_rnn.get_initial_state(batch_size=batch_size,
                                             dtype=tf.float32),
            "tensor":
            self.tensor_rnn.get_initial_state(batch_size=1, dtype=tf.float32),
            "eta_bar":
            tf.random.uniform(shape=tf.shape(var),
                              minval=tf.math.log(self.init_lr[0]),
                              maxval=tf.math.log(self.init_lr[1])),
            "eta_rel":
            tf.zeros([batch_size, 1]),
        }

    def get_initial_state_global(self):
        """Initialize global hidden state."""
        return self.global_rnn.get_initial_state(batch_size=1,
                                                 dtype=tf.float32)
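
rms_momentum is not defined in this snippet. Given that the docstring maps beta_g to beta_1 and beta_lambda to beta_2 of an Adam teacher, and that Eq. 1-2 are EMAs of the momentum and variance, a plausible sketch of it is shown below; this is an assumption, not the project's actual implementation.

import tensorflow as tf


def rms_momentum(grads, g_bar, lambda_, beta_1, beta_2):
    """Sketch of Eq. 1-2: EMAs of the gradient and of the squared gradient.

    Assumed implementation; the real rms_momentum lives in the surrounding
    project and may differ (e.g. it could apply bias correction).
    """
    g_bar_new = beta_1 * g_bar + (1.0 - beta_1) * grads                # Eq. 1
    lambda_new = beta_2 * lambda_ + (1.0 - beta_2) * tf.square(grads)  # Eq. 2
    return g_bar_new, lambda_new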