def loss_and_next_state(self, current_state, loss_state=None):
        params = current_state.phi_var_dict
        loss = self.inner_loss(current_state, loss_state)

        # Gradients of the inner loss with respect to each inner parameter.
        grads_dict = self.loss_module.gradients(loss, params)
        grads = list(grads_dict.values())

        # Update the rolling gradient statistics (first/second moment features).
        next_rolling_features = self.rolling_features.next_state(
            current_state.rolling_features, grads)

        next_phi_vars = self.theta_mod.compute_update_and_next_state(
            next_rolling_features,
            py_utils.eqzip(grads, params.values()),
            current_state.training_step)

        new_phi_var_dict = collections.OrderedDict(
            py_utils.eqzip(current_state.phi_var_dict.keys(), next_phi_vars))

        next_state = LearnerState(
            phi_var_dict=new_phi_var_dict,
            rolling_features=next_rolling_features,
            training_step=current_state.training_step + 1,
            initial_loss=current_state.initial_loss)

        next_state = nest.map_structure(tf.identity, next_state)

        return loss, next_state
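Every snippet here leans on py_utils.eqzip, which by its name and usage appears to be a strict zip that checks all arguments have the same length. The real helper may differ; a minimal sketch of the idea:

def eqzip(*sequences):
    """zip() that requires all sequences to have the same length (sketch)."""
    sequences = [list(s) for s in sequences]
    if len({len(s) for s in sequences}) > 1:
        raise ValueError("eqzip arguments have unequal lengths: %s" %
                         [len(s) for s in sequences])
    return list(zip(*sequences))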
 def body(i, tas):
     batch = state_fn()
     out_tas = []
     for ta, b in py_utils.eqzip(nest.flatten(tas),
                                 nest.flatten(batch)):
         out_tas.append(ta.write(i, b))
     return (i + 1, nest.pack_sequence_as(dummy_batch, out_tas))
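This body is written for tf.while_loop: each iteration pulls a fresh batch from state_fn() and writes every leaf tensor into its TensorArray at index i. A hedged sketch of how it might be driven (n_steps and the TensorArray setup are assumptions, not part of the original code):

# Sketch only: TF1-style graph code.
dummy_batch = state_fn()
tas = nest.map_structure(
    lambda t: tf.TensorArray(dtype=t.dtype, size=n_steps), dummy_batch)
_, final_tas = tf.while_loop(
    cond=lambda i, _: i < n_steps,
    body=body,
    loop_vars=(tf.constant(0), tas))
# Stack each TensorArray into a dense [n_steps, ...] tensor.
stacked = nest.map_structure(lambda ta: ta.stack(), final_tas)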
Example #3
    def next_state(self, state, grads):
        pad_decay = tf.expand_dims(self.decays, 0)
        new_ms_list, new_rms_list = [], []
        for ms, rms, g, var_shape in py_utils.eqzip(state.ms, state.rms, grads,
                                                    self.var_shapes):

            def single_update(grad, ms, rms):
                grad = tf.reshape(grad, [-1, 1])
                new_ms = ms * pad_decay + grad * (1 - pad_decay)
                if self.include_rms:
                    new_rms = (rms * pad_decay +
                               tf.square(grad) * (1 - pad_decay))
                    return new_ms, new_rms
                else:
                    return new_ms, rms

            if isinstance(g, tf.IndexedSlices):
                # pylint: disable=unbalanced-tuple-unpacking
                new_ms, new_rms = indexed_slices_apply_dense2(
                    single_update, var_shape, g, [ms, rms], 2)
            else:
                new_ms, new_rms = single_update(g, ms, rms)

            new_ms_list.append(new_ms)
            new_rms_list.append(new_rms)
        return RollingFeaturesState(ms=new_ms_list, rms=new_rms_list)
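next_state keeps exponential moving averages of each flattened gradient (ms) and, optionally, of its square (rms) at several decay rates at once: the [n, 1] gradient column broadcasts against a [1, n_decays] row of decays, much like Adam's first and second moments but at multiple timescales. A small NumPy sketch of the same accumulator arithmetic (the decay values are made up):

import numpy as np

decays = np.array([[0.5, 0.9, 0.99]])        # [1, n_decays]
grad = np.random.randn(4, 1)                 # flattened gradient, [n, 1]
ms = np.zeros((4, 3))                        # one EMA column per decay
rms = np.zeros((4, 3))

new_ms = ms * decays + grad * (1 - decays)            # first moment
new_rms = rms * decays + grad ** 2 * (1 - decays)     # second moment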
 def assign_state(self, state):
     # This also assigns the loss module's state.
     current = self.current_state()
     nest.assert_same_structure(current, state)
     current_flat = nest.flatten(current)
     state_flat = nest.flatten(state)
     assign_ops = [
         v.assign(s) for v, s in py_utils.eqzip(current_flat, state_flat)
     ]
     assign_ops.append(self.training_step.assign(state.training_step))
     return tf.group(assign_ops, name="assign_state")
def assign_variables(targets, values):
    """Creates an Op that assigns a list of values to a target variables.

  Args:
    targets: list or structure of tf.Variable
    values: list or structure of tf.Tensor

  Returns:
    tf.Operation that performs the assignment
  """
    return tf.group(*itertools.starmap(
        tf.assign, py_utils.eqzip(nest.flatten(targets),
                                  nest.flatten(values))),
                    name="assign_variables")
    def compute_update_and_next_state(self, rolling_features, grads_and_vars,
                                      training_step):
        new_vars = []

        normalizer = utils.SecondMomentNormalizer(name="Normalizer")
        mod = snt.nets.MLP([self.hidden_size] * self.hidden_layer + [2],
                           name="MLP")

        for (g, v), m, rms in py_utils.eqzip(grads_and_vars,
                                             rolling_features.ms,
                                             rolling_features.rms):

            def do_update(g, flat_v, m, rms):
                """Do a single tensor's update."""
                flat_g = tf.reshape(g, [-1, 1])

                rsqrt = tf.rsqrt(rms + 1e-6)
                norm_g = m * rsqrt

                inp = tf.concat([flat_g, norm_g, flat_v, m, rms, rsqrt], 1)

                inp = normalizer(inp, is_training=True)

                step = utils.tanh_embedding(training_step)
                stack_step = tf.tile(tf.reshape(step, [1, -1]),
                                     tf.stack([tf.shape(flat_g)[0], 1]))

                inp = tf.concat([inp, stack_step], axis=1)

                output = mod(inp)

                direction = output[:, 0:1]
                magnitude = output[:, 1:2]

                step = direction * tf.exp(
                    magnitude * self.magnitude_rate) * self.step_multiplier

                new_flat_v = flat_v - step
                # Return a 1-tuple so both call sites below unpack identically.
                return new_flat_v,

            flat_v = tf.reshape(v, [-1, 1])
            if isinstance(g, tf.IndexedSlices):
                new_flat_v, = utils.indexed_slices_apply_dense2(
                    do_update, v.shape.as_list(), g, [flat_v, m, rms], 1)
            else:
                new_flat_v, = do_update(g, flat_v, m, rms)

            new_vars.append(tf.reshape(new_flat_v, v.shape))

        return new_vars
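The tail of do_update is the learned update rule itself: the MLP emits a per-parameter direction and log-magnitude, and the step is their product scaled by two fixed hyperparameters. A NumPy sketch of just that final arithmetic (the hyperparameter values are illustrative, not from the original):

import numpy as np

magnitude_rate = 1e-3
step_multiplier = 1e-3

output = np.random.randn(5, 2)      # MLP output, one row per parameter entry
direction = output[:, 0:1]
magnitude = output[:, 1:2]

step = direction * np.exp(magnitude * magnitude_rate) * step_multiplier
flat_v = np.random.randn(5, 1)
new_flat_v = flat_v - step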
 def _nest_bimap(self, fn, data1, data2):
     data = py_utils.eqzip(nest.flatten(data1), nest.flatten(data2))
     out = [fn(*a) for a in data]
     return nest.pack_sequence_as(data1, out)
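A hedged usage example: applying a two-argument function leafwise across two structures with the same layout (the structures and the learner object are hypothetical):

a = {"w": tf.constant([1.0, 2.0]), "b": tf.constant(3.0)}
b = {"w": tf.constant([0.1, 0.2]), "b": tf.constant(0.3)}
summed = learner._nest_bimap(tf.add, a, b)
# summed has a's structure, with each leaf equal to fn(a_leaf, b_leaf).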
Example #8
def indexed_slices_apply_dense2(fn, var_shape, g_inp, dense_var_inp, n_outs):
    """Helper function to work with sparse tensors.

  dense_var_inp has the leading 2 dimensions collapsed, forming shape
  [n_words * n_word_feat, n_feat].
  g_inp, on the other hand, is [n_words, n_word_feat].
  var_shape is static and is [n_words, n_word_feat].

  Args:
    fn: (gradient: tf.Tensor, *var_args: tf.Tensor) -> list of tf.Tensor
    var_shape: list, the static shape [n_words, n_word_feat]
    g_inp: tf.IndexedSlices
    dense_var_inp: list of tf.Tensor (possibly nested)
    n_outs: int, number of outputs returned by fn
  Returns:
    dense outputs
  """
    grad_idx, grad_value = accumulate_sparse_gradients(g_inp)

    n_words, n_word_feat = var_shape

    args = []
    for a_possibly_nest in dense_var_inp:

        def do_on_tensor(a):
            n_feat = a.shape.as_list()[1]
            n_active = tf.size(grad_idx)
            reshaped = tf.reshape(a, [n_words, n_word_feat, n_feat])
            sub_reshaped = tf.gather(reshaped, grad_idx)
            return tf.reshape(sub_reshaped, [n_active * n_word_feat, n_feat])

        args.append(nest.map_structure(do_on_tensor, a_possibly_nest))

    returns = fn(grad_value, *args)

    def undo(full_and_sub):
        """Undo the slices."""
        full_val, sub_val = full_and_sub
        if tf.shape(full_val).shape.as_list()[0] != 2:
            raise NotImplementedError(
                "TODO(lmetz) other than this is not implemented.")
        n_words, n_word_feat = var_shape
        _, n_feat = sub_val.shape.as_list()
        n_active = tf.size(grad_idx)

        shape = [n_active, n_word_feat * n_feat]
        in_shape_form = tf.reshape(sub_val, shape)

        new_shape = [n_words, n_word_feat * n_feat]
        mask_shape = [n_words, n_word_feat * n_feat]

        scattered = tf.scatter_nd(tf.reshape(tf.to_int32(grad_idx), [-1, 1]),
                                  in_shape_form,
                                  shape=new_shape)
        mask = tf.scatter_nd(tf.reshape(tf.to_int32(grad_idx), [-1, 1]),
                             tf.ones_like(in_shape_form),
                             shape=mask_shape)

        # put back into the flat format
        scattered = tf.reshape(scattered, [n_words * n_word_feat, n_feat])
        mask = tf.reshape(mask, [n_words * n_word_feat, n_feat])

        # this is the update part / fake scatter_update but with gradients.
        return full_val * (1 - mask) + scattered * mask

    dense_outs = []
    for ret, dense_v in list(py_utils.eqzip(returns, dense_var_inp[0:n_outs])):
        flat_out = [
            undo(pair) for pair in
            py_utils.eqzip(nest.flatten(dense_v), nest.flatten(ret))
        ]
        dense_outs.append(nest.pack_sequence_as(dense_v, flat_out))
    return dense_outs
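The core trick: gather only the rows touched by the IndexedSlices, apply fn densely to that small slice, then scatter the result back with a 0/1 mask so untouched rows keep their old values (a differentiable stand-in for scatter_update). A NumPy sketch of that gather/update/masked-scatter pattern (shapes and the 0.9 "update" are made up):

import numpy as np

n_words, n_word_feat, n_feat = 5, 3, 1
full = np.arange(n_words * n_word_feat * n_feat, dtype=np.float32)
full = full.reshape(n_words * n_word_feat, n_feat)      # dense accumulator

grad_idx = np.array([1, 3])                             # touched word rows
sub = full.reshape(n_words, n_word_feat, n_feat)[grad_idx].reshape(-1, n_feat)
new_sub = sub * 0.9                                     # "fn" applied densely

mask = np.zeros((n_words, n_word_feat * n_feat), dtype=np.float32)
scattered = np.zeros_like(mask)
mask[grad_idx] = 1.0
scattered[grad_idx] = new_sub.reshape(len(grad_idx), -1)
mask = mask.reshape(n_words * n_word_feat, n_feat)
scattered = scattered.reshape(n_words * n_word_feat, n_feat)

result = full * (1 - mask) + scattered * mask           # untouched rows unchanged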
Example #9
    def train_op(self, ds_state):
        """Train with ES + Grads."""

        perturbs = ds_state.perturbation
        rp_grads = ds_state.grads
        meta_loss = ds_state.meta_loss
        antith_meta_loss = ds_state.antith_meta_loss

        # convert the [bs] shaped tensors to something like [bs, 1, 1, ...].
        broadcast_loss = [
            tf.reshape(meta_loss, [-1] + [1] * (len(p.shape.as_list()) - 1))
            for p in perturbs
        ]
        broadcast_antith_loss = [
            tf.reshape(antith_meta_loss,
                       [-1] + [1] * (len(p.shape.as_list()) - 1))
            for p in perturbs
        ]

        # ES gradient:
        # f(x+s) * d/ds(log(p(s))) = f(x+s) * s / (std**2)
        # for antith:
        # (f(x+s) - f(x-s))*s/(2 * std**2)
        es_grads = []
        for pos_loss, neg_loss, perturb in py_utils.eqzip(
                broadcast_loss, broadcast_antith_loss, perturbs):
            # this is the same as having 2 samples.
            es_grads.append(
                (pos_loss - neg_loss) * perturb / (self.custom_getter.std**2))

        def mean_and_var(g):
            mean = tf.reduce_mean(g, axis=0, keep_dims=True)
            square_sum = tf.reduce_sum(tf.square((g - mean)), axis=0)
            var = square_sum / (g.shape.as_list()[0] - 1)
            return tf.squeeze(mean, 0), var + 1e-8

        def combine(es, rp):
            """Do inverse variance rescaling."""
            mean_es, var_es = mean_and_var(es)
            mean_rp, var_rp = mean_and_var(rp)

            es_var_inv = 1. / var_es
            rp_var_inv = 1. / var_rp

            den = es_var_inv + rp_var_inv
            combine_g = (mean_es * es_var_inv + mean_rp * rp_var_inv) / den

            weight_es = es_var_inv / den

            return combine_g, weight_es

        combine_grads, _ = zip(*[
            combine(es, rp)
            for es, rp in py_utils.eqzip(es_grads, rp_grads)
        ])

        grads_vars = py_utils.eqzip(combine_grads,
                                    self.learner.theta_mod.get_variables())

        grads_vars = common.clip_grads_vars(grads_vars,
                                            self.gradient_clip_by_value)
        grads_vars = common.assert_grads_vars_not_nan(grads_vars)

        self._did_use_getter_on_all_variables()

        with tf.device(self.remote_device):
            train_op = self.meta_opt.apply_gradients(grads_vars)

        with tf.control_dependencies([train_op]):
            op = common.assert_post_update_not_nan(grads_vars)
            return tf.group(train_op, op, name="train_op")
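combine above is inverse-variance weighting: each element of the two gradient estimates (ES and reparameterization) is weighted by the reciprocal of its sample variance across the batch, so whichever estimator is less noisy dominates. A NumPy sketch of the same computation for one tensor (sample count and shapes are illustrative):

import numpy as np

def mean_and_var(g):
    mean = g.mean(axis=0)
    var = ((g - mean) ** 2).sum(axis=0) / (g.shape[0] - 1)
    return mean, var + 1e-8

es = np.random.randn(8, 4)    # 8 ES gradient samples for a length-4 parameter
rp = np.random.randn(8, 4)    # 8 reparameterization-gradient samples

mean_es, var_es = mean_and_var(es)
mean_rp, var_rp = mean_and_var(rp)

den = 1.0 / var_es + 1.0 / var_rp
combined = (mean_es / var_es + mean_rp / var_rp) / den    # inverse-variance weighted mean
weight_es = (1.0 / var_es) / den                          # fraction coming from ES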