def _feature_transformer(input_data):
  """Feature transformer at the end of training."""
  initial_variables = base_model.get_variables()
  replacement = collections.OrderedDict(
      utils.eqzip(initial_variables, new_variables))
  with variable_replace.variable_replace(replacement):
    features = base_model(input_data)
  return features
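# A minimal sketch of the strict-zip helper assumed throughout this file.
# The real `utils.eqzip` is defined elsewhere in the repo; the hypothetical
# version below only illustrates the assumed contract: zip equal-length
# sequences and fail loudly instead of silently truncating when the variable
# lists drift out of sync.
def _eqzip_sketch(*sequences):
  """Zips sequences, raising if their lengths differ (assumed contract)."""
  lengths = {len(s) for s in sequences}
  if len(lengths) > 1:
    raise ValueError('eqzip got sequences of unequal length: %s' %
                     sorted(lengths))
  return zip(*sequences)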
def assign_state(self, base_model, next_state):
  var_ups = [
      v.assign(nv)
      for v, nv in utils.eqzip(base_model.get_variables(),
                               next_state.variables)
  ]
  opt_ups = self.opt.assign_state(next_state.opt_state)
  return tf.group(opt_ups, *var_ups)
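# Hypothetical usage sketch, not part of the repo: `assign_state` pairs each
# model variable with its counterpart in `next_state.variables`, builds one
# assign op per pair, and groups them with the optimizer-state assignments so
# a single session run restores the whole training state. The standalone
# helper below shows just the grouped-assign part of that pattern.
import tensorflow as tf  # mirrors the TF1-style import assumed by this file

def grouped_assign_sketch(variables, new_values):
  """Returns a single op that assigns each of `new_values` to `variables`."""
  if len(variables) != len(new_values):
    raise ValueError('variables and new_values must have the same length')
  updates = [v.assign(nv) for v, nv in zip(variables, new_values)]
  return tf.group(*updates)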
def compute_updates(self, xs, gs, learning_rates, state):
  new_vars = []
  for x, g, lr in utils.eqzip(xs, gs, learning_rates):
    if lr is None:
      lr = self.learning_rate
    if g is not None:
      new_vars.append(x * (1 - lr) - g * lr)
    else:
      new_vars.append(x)
  return new_vars, state
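# Worked sketch of the update rule in `compute_updates` above, in plain
# Python so the arithmetic can be checked by hand: the parameter is shrunk
# toward zero by a factor of (1 - lr) and then moved along the negative
# gradient, i.e. a gradient step with an implicit weight-decay term tied to
# the same lr. The helper name and the numbers below are illustrative only.
def sgd_with_shrinkage_sketch(x, g, lr):
  """Applies x <- x * (1 - lr) - lr * g, mirroring the rule above."""
  return x * (1.0 - lr) - lr * g

# Example: x = 1.0, g = 0.5, lr = 0.1 gives 1.0 * 0.9 - 0.1 * 0.5 = 0.85.
assert abs(sgd_with_shrinkage_sketch(1.0, 0.5, 0.1) - 0.85) < 1e-9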
def compute_next_state(self, outputs, meta_opt, previous_state):
  zs = outputs.zs
  xs = outputs.xs
  batch = outputs.batch
  mods = outputs.mods
  backward_mods = outputs.backward_mods
  variables = self.get_variables()

  rev_mods = mods[::-1]
  rev_backward_mods = backward_mods[::-1]
  rev_xs = xs[::-1]
  rev_zs = zs[::-1] + [None]

  to_top = xs[-1]

  # variables that change in the loop
  hs = []
  d = meta_opt.compute_top_delta(to_top)  # [bs x 32 x delta_channels]

  iterator = utils.eqzip(rev_backward_mods + [None], rev_mods + [None],
                         [None] + rev_mods, rev_xs, rev_zs)

  for (backward_mod, lower_mod, upper_mod, x, z) in iterator:
    w_bot = None
    if lower_mod is not None:
      w_bot = previous_state.variables[variables.index(lower_mod.w)]

    w_top = None
    if upper_mod is not None:
      w_top = previous_state.variables[variables.index(upper_mod.w)]

    backward_w = None
    if backward_mod is not None:
      backward_w = previous_state.variables[variables.index(backward_mod.w)]

    if lower_mod is not None:
      bias = previous_state.variables[variables.index(lower_mod.b)]
    else:
      bias = tf.zeros([x.shape[1]])

    h, d = self.compute_next_h_d(
        meta_opt=meta_opt,
        w_bot=w_bot,
        w_top=w_top,
        bias=bias,
        backward_w=backward_w,
        x=x,
        z=z,
        d=d)
    hs.append(h)

  w_forward_var_idx = [variables.index(mod.w) for mod in rev_mods]
  w_backward_var_idx = [variables.index(mod.w) for mod in rev_backward_mods]
  b_var_idx = [variables.index(mod.b) for mod in rev_mods]

  # storage location for outputs of the loop below
  grads = [None for _ in previous_state.variables]
  # override learning rate for perturbation variables
  learning_rate = [None for _ in previous_state.variables]

  # This is a map -- no state is shared across loop iterations.
  for (l_idx, w_forward_idx, w_backward_idx, b_idx, upper_h, lower_h, lower_x,
       upper_x) in utils.eqzip(
           range(len(w_forward_var_idx)), w_forward_var_idx,
           w_backward_var_idx, b_var_idx, hs[:-1], hs[1:], xs[::-1][1:],
           xs[::-1][:-1]):

    b_base = previous_state.variables[b_idx]
    change_w_forward, change_b = self.weight_change_for_layer(
        meta_opt=meta_opt,
        l_idx=l_idx,
        w_base=previous_state.variables[w_forward_idx],
        b_base=b_base,
        upper_h=upper_h,
        lower_h=lower_h,
        upper_x=upper_x,
        lower_x=lower_x,
        prefix='forward',
        include_bias=True)

    if self.identical_updates:
      change_w_backward = change_w_forward
    else:
      change_w_backward = self.weight_change_for_layer(
          meta_opt=meta_opt,
          l_idx=l_idx,
          w_base=previous_state.variables[w_backward_idx],
          b_base=b_base,
          upper_h=upper_h,
          lower_h=lower_h,
          upper_x=upper_x,
          lower_x=lower_x,
          prefix='backward',
          include_bias=False)

    grads[w_forward_idx] = change_w_forward
    grads[w_backward_idx] = change_w_backward
    grads[b_idx] = change_b

  cur_transformer = common.transformer_at_state(self,
                                                previous_state.variables)
  next_state = meta_opt.compute_next_state(
      grads,
      learning_rate=learning_rate,
      cur_state=previous_state,
      cur_transformer=lambda x: cur_transformer(x)[0])
  return next_state
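# A small sketch (all names below are hypothetical, not from the repo) of the
# indexing convention used in `compute_next_state` above: proposed weight
# changes are scattered into a dense `grads` list aligned with
# `self.get_variables()`, and entries left as None mean "no learned update
# for this variable"; the `learning_rate` list uses None to mean "fall back
# to the default". The helper below combines those conventions with the
# update rule from `compute_updates` above; the real
# `meta_opt.compute_next_state` lives elsewhere in the repo.
def apply_indexed_updates_sketch(values, updates, learning_rates, default_lr):
  """Applies each non-None update to the value at the same index."""
  if not len(values) == len(updates) == len(learning_rates):
    raise ValueError('values, updates, and learning_rates must be aligned')
  new_values = []
  for value, update, lr in zip(values, updates, learning_rates):
    if update is None:
      new_values.append(value)  # no change proposed for this variable
    else:
      lr = default_lr if lr is None else lr  # None means default learning rate
      new_values.append(value * (1.0 - lr) - lr * update)
  return new_values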