def loss_and_next_state(self, current_state, loss_state=None):
  """Computes the inner loss and the learner state after one update."""
  params = current_state.phi_var_dict
  loss = self.inner_loss(current_state, loss_state)
  grads_dict = self.loss_module.gradients(loss, params)
  grads = grads_dict.values()

  # Update the rolling gradient features, then let the learned optimizer
  # (theta_mod) produce the new inner parameters from them.
  next_rolling_features = self.rolling_features.next_state(
      current_state.rolling_features, grads)
  next_phi_vars = self.theta_mod.compute_update_and_next_state(
      next_rolling_features, py_utils.eqzip(grads, params.values()),
      current_state.training_step)

  new_phi_var_dict = collections.OrderedDict(
      py_utils.eqzip(current_state.phi_var_dict.keys(), next_phi_vars))

  next_state = LearnerState(
      phi_var_dict=new_phi_var_dict,
      rolling_features=next_rolling_features,
      training_step=current_state.training_step + 1,
      initial_loss=current_state.initial_loss)
  next_state = nest.map_structure(tf.identity, next_state)
  return loss, next_state
def body(i, tas):
  # One while_loop step: draw a batch from state_fn and write it into slot i
  # of each TensorArray, preserving the nest structure of dummy_batch.
  batch = state_fn()
  out_tas = []
  for ta, b in py_utils.eqzip(nest.flatten(tas), nest.flatten(batch)):
    out_tas.append(ta.write(i, b))
  return (i + 1, nest.pack_sequence_as(dummy_batch, out_tas))
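
# The function above is a tf.while_loop body that stacks one batch per step
# into TensorArrays. Below is a self-contained sketch of the same pattern,
# with toy stand-ins (toy_unroll, the constant batch) for the real
# state_fn/dummy_batch, which are assumed from the enclosing scope.
import tensorflow as tf


def toy_unroll(n_steps=4):
  ta = tf.TensorArray(tf.float32, size=n_steps)

  def toy_body(i, ta):
    batch = tf.cast(i, tf.float32) * tf.ones([2])  # stand-in for state_fn()
    return i + 1, ta.write(i, batch)

  _, out_ta = tf.while_loop(lambda i, _: i < n_steps, toy_body, [0, ta])
  return out_ta.stack()  # [n_steps, 2]; one row written per iteration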
def next_state(self, state, grads):
  """Updates momentum-style rolling features at several decay timescales."""
  # self.decays has shape [n_decays]; pad to [1, n_decays] so it broadcasts
  # against the [-1, 1]-shaped flattened gradients below.
  pad_decay = tf.expand_dims(self.decays, 0)
  new_ms_list, new_rms_list = [], []
  for ms, rms, g, var_shape in py_utils.eqzip(state.ms, state.rms, grads,
                                              self.var_shapes):

    def single_update(grad, ms, rms):
      grad = tf.reshape(grad, [-1, 1])
      new_ms = ms * pad_decay + grad * (1 - pad_decay)
      if self.include_rms:
        new_rms = rms * pad_decay + tf.square(grad) * (1 - pad_decay)
        return new_ms, new_rms
      else:
        return new_ms, rms

    if isinstance(g, tf.IndexedSlices):
      # Only the rows touched by the sparse gradient are updated.
      # pylint: disable=unbalanced-tuple-unpacking
      new_ms, new_rms = indexed_slices_apply_dense2(single_update, var_shape,
                                                    g, [ms, rms], 2)
    else:
      new_ms, new_rms = single_update(g, ms, rms)
    new_ms_list.append(new_ms)
    new_rms_list.append(new_rms)
  return RollingFeaturesState(ms=new_ms_list, rms=new_rms_list)
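
# A minimal numpy sketch of the rolling-feature update above: the
# [-1, 1]-reshaped gradient broadcasts against a [1, n_decays] row of decay
# rates, giving each parameter one EMA column per timescale. All names and
# constants here are illustrative.
import numpy as np


def toy_next_state(ms, rms, grad, decays):
  """EMA update for one flat variable; ms/rms are [n_params, n_decays]."""
  grad = grad.reshape([-1, 1])          # [n_params, 1]
  pad_decay = decays.reshape([1, -1])   # [1, n_decays]
  new_ms = ms * pad_decay + grad * (1 - pad_decay)
  new_rms = rms * pad_decay + np.square(grad) * (1 - pad_decay)
  return new_ms, new_rms


decays = np.array([0.5, 0.9, 0.99])
ms, rms = toy_next_state(
    np.zeros([4, 3]), np.zeros([4, 3]), np.ones([4]), decays)
assert ms.shape == (4, 3)  # one feature column per decay timescale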
def assign_state(self, state):
  # This also assigns the loss module's state.
  current = self.current_state()
  nest.assert_same_structure(current, state)
  current_flat = nest.flatten(current)
  state_flat = nest.flatten(state)
  assign_ops = [
      v.assign(s) for v, s in py_utils.eqzip(current_flat, state_flat)
  ]
  assign_ops.append(self.training_step.assign(state.training_step))
  return tf.group(assign_ops, name="assign_state")
def assign_variables(targets, values):
  """Creates an op that assigns a list of values to the target variables.

  Args:
    targets: list or structure of tf.Variable
    values: list or structure of tf.Tensor

  Returns:
    tf.Operation that performs the assignment
  """
  return tf.group(
      *itertools.starmap(
          tf.assign,
          py_utils.eqzip(nest.flatten(targets), nest.flatten(values))),
      name="assign_variables")
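
# A hedged usage sketch for assign_variables in a TF1 graph/session setup.
# The variable names (v1, v2) are illustrative; this assumes the
# py_utils/nest imports used by assign_variables above are available.
import tensorflow as tf

v1 = tf.get_variable("v1", shape=[2], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[3], initializer=tf.zeros_initializer())
assign_op = assign_variables([v1, v2], [tf.ones([2]), tf.fill([3], 2.0)])
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_op)
  print(sess.run([v1, v2]))  # [[1., 1.], [2., 2., 2.]]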
def compute_update_and_next_state(self, rolling_features, grads_and_vars,
                                  training_step):
  """Applies the learned per-parameter update to every variable."""
  new_vars = []
  normalizer = utils.SecondMomentNormalizer(name="Normalizer")
  mod = snt.nets.MLP(
      [self.hidden_size] * self.hidden_layer + [2], name="MLP")
  for (g, v), m, rms in py_utils.eqzip(grads_and_vars, rolling_features.ms,
                                       rolling_features.rms):

    def do_update(g, flat_v, m, rms):
      """Do a single tensor's update."""
      flat_g = tf.reshape(g, [-1, 1])
      rsqrt = tf.rsqrt(rms + 1e-6)
      norm_g = m * rsqrt
      # Per-parameter features: raw grad, normalized momentum, current value,
      # and the rolling statistics themselves.
      inp = tf.concat([flat_g, norm_g, flat_v, m, rms, rsqrt], 1)
      inp = normalizer(inp, is_training=True)
      # Append a training-step embedding, tiled across all parameters.
      step = utils.tanh_embedding(training_step)
      stack_step = tf.tile(
          tf.reshape(step, [1, -1]), tf.stack([tf.shape(flat_g)[0], 1]))
      inp = tf.concat([inp, stack_step], axis=1)
      output = mod(inp)
      # The MLP emits a direction and a log-scale magnitude per parameter.
      direction = output[:, 0:1]
      magnitude = output[:, 1:2]
      step = direction * tf.exp(
          magnitude * self.magnitude_rate) * self.step_multiplier
      new_flat_v = flat_v - step
      return new_flat_v,

    flat_v = tf.reshape(v, [-1, 1])
    if isinstance(g, tf.IndexedSlices):
      new_flat_v, = utils.indexed_slices_apply_dense2(
          do_update, v.shape.as_list(), g, [flat_v, m, rms], 1)
    else:
      new_flat_v, = do_update(g, flat_v, m, rms)
    new_vars.append(tf.reshape(new_flat_v, v.shape))
  return new_vars
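
# The two MLP outputs above parameterize the update in sign/log-scale form:
#   step = direction * exp(magnitude * magnitude_rate) * step_multiplier
# A quick numeric sketch with illustrative constants (the real
# magnitude_rate/step_multiplier are hyperparameters of the class above):
import numpy as np

magnitude_rate, step_multiplier = 0.001, 0.001
direction, magnitude = -0.3, 2.0
step = direction * np.exp(magnitude * magnitude_rate) * step_multiplier
# direction sets the sign; exp(magnitude * rate) scales the step size
# multiplicatively, so the network can express step sizes spanning orders of
# magnitude with a roughly linear output.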
def _nest_bimap(self, fn, data1, data2):
  """Maps fn over two identically-structured nests and repacks the result."""
  data = py_utils.eqzip(nest.flatten(data1), nest.flatten(data2))
  out = [fn(*a) for a in data]
  return nest.pack_sequence_as(data1, out)
def indexed_slices_apply_dense2(fn, var_shape, g_inp, dense_var_inp, n_outs):
  """Helper function to work with sparse tensors.

  dense_var_inp has the leading 2 dimensions collapsed, forming shape
  [n_words * n_word_feat, n_feat]. g_inp, on the other hand, is
  [n_words, n_word_feat]. var_shape is static and is [n_words, n_word_feat].

  Args:
    fn: (gradient: tf.Tensor, *var_args: tf.Tensor) -> [tf.Tensor]
    var_shape: list
    g_inp: tf.IndexedSlices
    dense_var_inp: tf.Tensor list.
    n_outs: int

  Returns:
    dense outputs
  """
  grad_idx, grad_value = accumulate_sparse_gradients(g_inp)
  n_words, n_word_feat = var_shape

  # Gather only the rows touched by the sparse gradient, flattened so fn can
  # treat them like a dense update.
  args = []
  for a_possibly_nest in dense_var_inp:

    def do_on_tensor(a):
      n_feat = a.shape.as_list()[1]
      n_active = tf.size(grad_idx)
      reshaped = tf.reshape(a, [n_words, n_word_feat, n_feat])
      sub_reshaped = tf.gather(reshaped, grad_idx)
      return tf.reshape(sub_reshaped, [n_active * n_word_feat, n_feat])

    args.append(nest.map_structure(do_on_tensor, a_possibly_nest))

  returns = fn(grad_value, *args)

  def undo(full_val, sub_val):
    """Undo the slices."""
    if tf.shape(full_val).shape.as_list()[0] != 2:
      raise NotImplementedError(
          "TODO(lmetz) other than this is not implemented.")
    n_words, n_word_feat = var_shape
    _, n_feat = sub_val.shape.as_list()
    n_active = tf.size(grad_idx)

    shape = [n_active, n_word_feat * n_feat]
    in_shape_form = tf.reshape(sub_val, shape)

    new_shape = [n_words, n_word_feat * n_feat]
    mask_shape = [n_words, n_word_feat * n_feat]

    scattered = tf.scatter_nd(
        tf.reshape(tf.to_int32(grad_idx), [-1, 1]),
        in_shape_form,
        shape=new_shape)
    mask = tf.scatter_nd(
        tf.reshape(tf.to_int32(grad_idx), [-1, 1]),
        tf.ones_like(in_shape_form),
        shape=mask_shape)

    # Put back into the flat format.
    scattered = tf.reshape(scattered, [n_words * n_word_feat, n_feat])
    mask = tf.reshape(mask, [n_words * n_word_feat, n_feat])

    # This is the update part / a fake scatter_update, but with gradients.
    return full_val * (1 - mask) + scattered * mask

  dense_outs = []
  for ret, dense_v in py_utils.eqzip(returns, dense_var_inp[0:n_outs]):
    flat_out = [
        undo(full_val, sub_val) for full_val, sub_val in py_utils.eqzip(
            nest.flatten(dense_v), nest.flatten(ret))
    ]
    dense_outs.append(nest.pack_sequence_as(dense_v, flat_out))
  return dense_outs
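
# A numpy sketch of the mask-based merge performed in undo(): rows touched by
# the sparse gradient are overwritten, untouched rows pass through, and since
# the merge is plain arithmetic (not a scatter_update) it remains
# differentiable. Shapes and names are illustrative.
import numpy as np

n_words, n_feat = 5, 2
full_val = np.zeros([n_words, n_feat])
grad_idx = np.array([1, 3])                # rows with sparse updates
sub_val = np.ones([len(grad_idx), n_feat])
scattered = np.zeros([n_words, n_feat])
scattered[grad_idx] = sub_val
mask = np.zeros([n_words, n_feat])
mask[grad_idx] = 1.0
merged = full_val * (1 - mask) + scattered * mask
assert (merged[grad_idx] == 1.0).all() and merged[[0, 2, 4]].sum() == 0.0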
def train_op(self, ds_state):
  """Train with ES + Grads."""
  perturbs = ds_state.perturbation
  rp_grads = ds_state.grads
  meta_loss = ds_state.meta_loss
  antith_meta_loss = ds_state.antith_meta_loss

  # Convert the [bs] shaped loss tensors to something like [bs, 1, 1, ...]
  # so they broadcast against each perturbation.
  broadcast_loss = [
      tf.reshape(meta_loss, [-1] + [1] * (len(p.shape.as_list()) - 1))
      for p in perturbs
  ]
  broadcast_antith_loss = [
      tf.reshape(antith_meta_loss, [-1] + [1] * (len(p.shape.as_list()) - 1))
      for p in perturbs
  ]

  # ES gradient:
  #   f(x+s) * d/ds(log(p(s))) = f(x+s) * s / (std**2)
  # and for antithetic samples:
  #   (f(x+s) - f(x-s)) * s / (2 * std**2)
  es_grads = []
  for pos_loss, neg_loss, perturb in py_utils.eqzip(broadcast_loss,
                                                    broadcast_antith_loss,
                                                    perturbs):
    # This is the same as having 2 samples.
    es_grads.append(
        (pos_loss - neg_loss) * perturb / (self.custom_getter.std**2))

  def mean_and_var(g):
    mean = tf.reduce_mean(g, axis=0, keep_dims=True)
    square_sum = tf.reduce_sum(tf.square(g - mean), axis=0)
    var = square_sum / (g.shape.as_list()[0] - 1)
    return tf.squeeze(mean, 0), var + 1e-8

  def combine(es, rp):
    """Do inverse variance rescaling."""
    mean_es, var_es = mean_and_var(es)
    mean_rp, var_rp = mean_and_var(rp)
    es_var_inv = 1. / var_es
    rp_var_inv = 1. / var_rp
    den = es_var_inv + rp_var_inv
    combine_g = (mean_es * es_var_inv + mean_rp * rp_var_inv) / den
    weight_es = es_var_inv / den
    return combine_g, weight_es

  combine_grads, _ = zip(
      *[combine(es, rp) for es, rp in py_utils.eqzip(es_grads, rp_grads)])

  grads_vars = py_utils.eqzip(combine_grads,
                              self.learner.theta_mod.get_variables())
  grads_vars = common.clip_grads_vars(grads_vars, self.gradient_clip_by_value)
  grads_vars = common.assert_grads_vars_not_nan(grads_vars)

  self._did_use_getter_on_all_variables()

  with tf.device(self.remote_device):
    train_op = self.meta_opt.apply_gradients(grads_vars)

  with tf.control_dependencies([train_op]):
    op = common.assert_post_update_not_nan(grads_vars)
  return tf.group(train_op, op, name="train_op")
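
# A small numpy check of the antithetic ES estimator in the comment above:
# for f(x) = x**2 the estimate (f(x+s) - f(x-s)) * s / (2 * std**2), averaged
# over samples, approaches f'(x) = 2x. (train_op folds the factor of 2 into
# treating each antithetic pair as two samples before the batch mean.)
# Constants here are illustrative.
import numpy as np

rng = np.random.RandomState(0)
x0, std, n = 1.5, 0.1, 200000
s = rng.randn(n) * std
f = lambda x: x * x
est = np.mean((f(x0 + s) - f(x0 - s)) * s / (2 * std**2))
assert abs(est - 2 * x0) < 0.05  # true gradient is 2 * x0 = 3.0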