def metaoptimizer(self, dictloss):
    hidden_size = self._config['hidden_size']
    num_layer = self._config['num_layer']
    unroll_nn = self._config['unroll_nn']
    lr = self._config['lr']

    with tf.device('/device:GPU:0'):
        # Grab the optimizee variables for each loss and flatten them into one
        # list; num_var records how many variables belong to each loss.
        input_var = [util._get_variables(a)[0] for a in dictloss.values()]
        num_var = nest.flatten([0, [len(a) for a in input_var]])
        opt_var = nest.flatten(input_var)
        shapes = [K.get_variable_shape(p) for p in opt_var]

        with tf.variable_scope("softmax", reuse=tf.AUTO_REUSE):
            softmax_w = tf.get_variable("softmax_w",
                                        shape=[hidden_size, 1],
                                        dtype=tf.float32)
            softmax_b = tf.get_variable("softmax_b",
                                        shape=[1],
                                        dtype=tf.float32)

        with tf.name_scope('states'):
            # Per-parameter LSTM state (c, h) for every optimizee variable.
            state_c = [[] for _ in range(len(opt_var))]
            state_h = [[] for _ in range(len(opt_var))]
            for i in range(len(opt_var)):
                n_param = int(np.prod(shapes[i]))
                state_c[i] = [
                    tf.Variable(tf.zeros([n_param, hidden_size]),
                                dtype=tf.float32,
                                name="c_in",
                                trainable=False) for _ in range(num_layer)
                ]
                state_h[i] = [
                    tf.Variable(tf.zeros([n_param, hidden_size]),
                                dtype=tf.float32,
                                name="h_in",
                                trainable=False) for _ in range(num_layer)
                ]

    def update_state(losstot, x, state_c, state_h):
        """Computes the coordinate-wise RNN updates and the new LSTM states."""
        with tf.name_scope("gradients"):
            shapes = [K.get_variable_shape(p) for p in x]
            grads = nest.flatten([
                K.gradients(losstot[a],
                            x[num_var[b]:num_var[b] + num_var[b + 1]])
                for a, b in zip(range(len(losstot)), range(len(num_var) - 1))
            ])
            # The meta-optimizer is trained with first derivatives only.
            grads = [tf.stop_gradient(g) for g in grads]

        with tf.variable_scope('MetaNetwork'):
            cell_count = 0
            delta = [[] for _ in range(len(grads))]
            S_C_out = [[] for _ in range(len(opt_var))]
            S_H_out = [[] for _ in range(len(opt_var))]

            for i in range(len(grads)):
                g = grads[i]
                n_param = int(np.prod(shapes[i]))
                flat_g = tf.reshape(g, [n_param, -1])
                flat_g_mod = tf.reshape(Preprocess.log_encode(flat_g),
                                        [n_param, -1])
                rnn_new_c = [[] for _ in range(num_layer)]
                rnn_new_h = [[] for _ in range(num_layer)]

                # Apply the RNN cell to each parameter individually.
                with tf.variable_scope("RNN"):
                    rnn_outputs = []
                    rnn_state_c = [[] for _ in range(num_layer)]
                    rnn_state_h = [[] for _ in range(num_layer)]
                    for ii in range(n_param):
                        state_in = [
                            tf.contrib.rnn.LSTMStateTuple(
                                state_c[i][j][ii:ii + 1],
                                state_h[i][j][ii:ii + 1])
                            for j in range(num_layer)
                        ]
                        rnn_cell = tf.contrib.rnn.MultiRNNCell([
                            tf.contrib.rnn.LSTMCell(num_units=hidden_size,
                                                    reuse=cell_count > 0)
                            for _ in range(num_layer)
                        ])
                        cell_count += 1

                        # Individual update with individual state but global
                        # (shared) cell parameters.
                        rnn_out_all, state_out = rnn_cell(
                            flat_g_mod[ii:ii + 1, :], state_in)
                        rnn_out = tf.add(tf.matmul(rnn_out_all, softmax_w),
                                         softmax_b)
                        rnn_outputs.append(rnn_out)
                        for j in range(num_layer):
                            rnn_state_c[j].append(state_out[j].c)
                            rnn_state_h[j].append(state_out[j].h)

                    # Form output as tensor.
                    rnn_outputs = tf.reshape(tf.stack(rnn_outputs, axis=1),
                                             g.get_shape())
                    for j in range(num_layer):
                        rnn_new_c[j] = tf.reshape(
                            tf.stack(rnn_state_c[j], axis=1),
                            (n_param, hidden_size))
                        rnn_new_h[j] = tf.reshape(
                            tf.stack(rnn_state_h[j], axis=1),
                            (n_param, hidden_size))

                # Dense output from state.
                delta[i] = rnn_outputs
                S_C_out[i] = rnn_new_c
                S_H_out[i] = rnn_new_h

        return delta, S_C_out, S_H_out

    def time_step(t, f_array, f_array_opt, x, state_c, state_h):
        """While loop body: one unrolled optimization step."""
        losstot = [
            util._make_with_custom_variables(
                a, x[num_var[b]:num_var[b] + num_var[b + 1]])
            for a, b in zip(dictloss.values(), range(len(num_var) - 1))
        ]
        with tf.name_scope('Unroll_Optimizee_loss'):
            fx_opt = losstot[0]
            f_array_opt = f_array_opt.write(t, fx_opt)
        with tf.name_scope('Unroll_loss_t'):
            fx = sum(losstot[a] for a in range(len(losstot)))
            f_array = f_array.write(t, fx)
        with tf.name_scope('Unroll_delta_state_update'):
            delta, s_c_out, s_h_out = update_state(losstot, x, state_c,
                                                   state_h)
        with tf.name_scope('Unroll_Optimizee_update'):
            x_new = [
                x_n + Preprocess.update_tanh(d) for x_n, d in zip(x, delta)
            ]
        t_new = t + 1
        return t_new, f_array, f_array_opt, x_new, s_c_out, s_h_out

    with tf.device('/device:GPU:0'):
        fx_array = tf.TensorArray(tf.float32,
                                  size=unroll_nn,
                                  clear_after_read=False)
        fx_array_opt = tf.TensorArray(tf.float32,
                                      size=unroll_nn,
                                      clear_after_read=False)
        _, fx_array, fx_array_opt, x_final, S_C, S_H = tf.while_loop(
            cond=lambda t, *_: t < unroll_nn - 1,
            body=time_step,
            loop_vars=(0, fx_array, fx_array_opt, opt_var, state_c, state_h),
            parallel_iterations=1,
            swap_memory=True,
            name="unroll")

        # Evaluate the losses once more at the final parameters to fill the
        # last slot of the unroll.
        finaltotloss = [
            util._make_with_custom_variables(
                a, x_final[num_var[b]:num_var[b] + num_var[b + 1]])
            for a, b in zip(dictloss.values(), range(len(num_var) - 1))
        ]
        with tf.name_scope('Unroll_loss_period'):
            fx_final = sum(finaltotloss[a] for a in range(len(finaltotloss)))
            fx_array = fx_array.write(unroll_nn - 1, fx_final)
        with tf.name_scope('Final_Optimizee_loss'):
            fx_final_opt = finaltotloss[0]
            fx_array_opt = fx_array_opt.write(unroll_nn - 1, fx_final_opt)
            arrayf = fx_array_opt.stack()
        with tf.name_scope('Metaloss'):
            loss_optimizer = tf.reduce_sum(fx_array.stack())
        with tf.name_scope('MetaOpt'):
            optimizer = tf.train.AdamOptimizer(lr)
        with tf.name_scope('Meta_update'):
            step = optimizer.minimize(loss_optimizer)
        with tf.name_scope('state_optimizee_var'):
            variables = (nest.flatten(state_c) + nest.flatten(state_h) +
                         opt_var)
        with tf.name_scope('state_reset'):
            reset = [
                tf.variables_initializer(variables),
                fx_array.close(),
                fx_array_opt.close()
            ]
        with tf.name_scope('Optimizee_update'):
            # Write the unrolled parameters and LSTM states back into the
            # persistent variables.
            update = (nest.flatten([
                tf.assign(r, v) for r, v in zip(opt_var, x_final)
            ]) + nest.flatten([
                tf.assign(r, v)
                for r, v in zip(nest.flatten(state_c), nest.flatten(S_C))
            ]) + nest.flatten([
                tf.assign(r, v)
                for r, v in zip(nest.flatten(state_h), nest.flatten(S_H))
            ]))

    return (step, loss_optimizer, update, reset, fx_final, fx_final_opt,
            arrayf, x_final)
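# A minimal driver sketch for the ops returned by metaoptimizer(). This is an
# assumption about how the graph is meant to be run, not part of this module:
# `optimizer` stands for an instance of this class, and `num_epochs` /
# `num_unrolls` are illustrative names supplied by the caller.
#
#   step, metaloss, update, reset, fx, fx_opt, arrayf, x_final = \
#       optimizer.metaoptimizer(dictloss)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for epoch in range(num_epochs):
#           sess.run(reset)                        # re-initialize states/vars
#           for _ in range(num_unrolls):
#               _, cost = sess.run([step, metaloss])  # one meta-training step
#               sess.run(update)                   # persist x_final and states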
def meta_loss(self,
              make_loss,
              len_unroll,
              net_assignments=None,
              second_derivatives=False):
    """Returns an operator computing the meta-loss.

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
        should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      net_assignments: Variable-to-optimizer mapping. If not None, it should
        be a list of (k, names) tuples, where k is a valid key in the kwargs
        passed at construction time and names is a list of variable names.
      second_derivatives: Use second derivatives (default is False).

    Returns:
      namedtuple containing (loss, update, reset, fx, x)
    """
    # Construct an instance of the problem only to grab the variables. This
    # loss will never be evaluated.
    x = []
    constants = []
    for a in make_loss.values():
        item1, item2 = util._get_variables(a)
        x.append(item1)
        constants.append(item2)

    num_var = nest.flatten([0, [len(a) for a in x]])
    var_num = np.cumsum(num_var)
    x = nest.flatten(x)
    constants = nest.flatten(constants)

    print("Optimizee variables")
    print([op.name for op in x])
    print("Problem variables")
    print([op.name for op in constants])

    # Create the optimizer networks and find the subsets of variables to
    # assign to each optimizer.
    nets, net_keys, subsets = util._make_nets(x, self._config,
                                              net_assignments)

    # Store the networks so we can save them later.
    self._nets = nets

    # Create hidden state for each subset of variables.
    state = []
    with tf.name_scope("states"):
        for i, (subset, key) in enumerate(zip(subsets, net_keys)):
            net = nets[key]
            with tf.name_scope("state_{}".format(i)):
                state.append(
                    util._nested_variable([
                        net.initial_state_for_inputs(x[j], dtype=tf.float32)
                        for j in subset
                    ],
                                          name="state",
                                          trainable=False))

    def update(net, fx, x, state, subset):
        """Parameter and RNN state update."""
        with tf.name_scope("gradients"):
            if len(subset) == sum(num_var):
                # This optimizer handles all variables: differentiate each
                # loss with respect to its own slice of variables.
                gradients = nest.flatten([
                    tf.gradients(fx[a],
                                 x[num_var[b]:num_var[b] + num_var[b + 1]])
                    for a, b in zip(range(len(fx)), range(len(num_var) - 1))
                ])
            else:
                # Map each variable in the subset back to the loss it
                # belongs to.
                bin_num = np.digitize(subset, var_num) - 1
                if np.std(bin_num) == 0 or len(bin_num) == 1:
                    gradients = nest.flatten(
                        [tf.gradients(fx[bin_num[0]], x)])
                else:
                    gradients = nest.flatten([
                        tf.gradients(fx[a], x[b])
                        for a, b in zip(bin_num, range(len(x)))
                    ])

            if not second_derivatives:
                gradients = [tf.stop_gradient(g) for g in gradients]

        with tf.name_scope("deltas"):
            deltas, state_next = zip(
                *[net(g, s) for g, s in zip(gradients, state)])
            state_next = list(state_next)
            # Mean delta-to-gradient ratio, tracked as an effective
            # learning-rate diagnostic.
            ratio = sum([
                tf.reduce_mean(tf.div(d, g))
                for d, g in zip(deltas, gradients)
            ]) / len(gradients)

        return deltas, state_next, ratio

    def time_step(t, fx_array, fx_array_opt, lr_optimizee, x, state):
        """While loop body."""
        x_next = list(x)
        state_next = []
        ratio = []

        with tf.name_scope("fx"):
            fx = [
                util._make_with_custom_variables(
                    a, x[num_var[b]:num_var[b] + num_var[b + 1]])
                for a, b in zip(make_loss.values(), range(len(num_var) - 1))
            ]
        with tf.name_scope("fx_sum"):
            fxsum = sum(fx[a] for a in range(len(fx)))
            fx_array = fx_array.write(t, fxsum)
        with tf.name_scope("fx_opt"):
            fxopt = fx[0]
            fx_array_opt = fx_array_opt.write(t, fxopt)

        with tf.name_scope("dx"):
            for subset, key, s_i in zip(subsets, net_keys, state):
                x_i = [x[j] for j in subset]
                deltas, s_i_next, ratio_i = update(nets[key], fx, x_i, s_i,
                                                   subset)
                for idx, j in enumerate(subset):
                    x_next[j] += deltas[idx]
                state_next.append(s_i_next)
                ratio.append(ratio_i)

        with tf.name_scope("lr_opt"):
            lr_optimizee = lr_optimizee.write(t, sum(ratio) / len(ratio))

        with tf.name_scope("t_next"):
            t_next = t + 1

        return (t_next, fx_array, fx_array_opt, lr_optimizee, x_next,
                state_next)

    # Define the while loop.
    fx_array = tf.TensorArray(tf.float32,
                              size=len_unroll,
                              clear_after_read=False)
    fx_array_opt = tf.TensorArray(tf.float32,
                                  size=len_unroll,
                                  clear_after_read=False)
    lr_optimizee = tf.TensorArray(tf.float32,
                                  size=len_unroll - 1,
                                  clear_after_read=False)
    _, fx_array, fx_array_opt, lr_optimizee, x_final, s_final = tf.while_loop(
        cond=lambda t, *_: t < len_unroll - 1,
        body=time_step,
        loop_vars=(0, fx_array, fx_array_opt, lr_optimizee, x, state),
        parallel_iterations=1,
        swap_memory=True,
        name="unroll")

    # Evaluate the losses one last time at the final parameters to fill the
    # last slot of the unroll.
    with tf.name_scope("fx"):
        fx_final = [
            util._make_with_custom_variables(
                a, x_final[num_var[b]:num_var[b] + num_var[b + 1]])
            for a, b in zip(make_loss.values(), range(len(num_var) - 1))
        ]
    with tf.name_scope("fx_sum"):
        fxsum = sum(fx_final[a] for a in range(len(fx_final)))
        fx_array = fx_array.write(len_unroll - 1, fxsum)
    with tf.name_scope("fx_opt"):
        fxopt = fx_final[0]
        fx_array_opt = fx_array_opt.write(len_unroll - 1, fxopt)
        farray = fx_array_opt.stack()
    with tf.name_scope("lr_opt"):
        lr_opt = lr_optimizee.stack()

    loss = tf.reduce_sum(fx_array.stack(), name="loss")

    # Reset the state; should be called at the beginning of an epoch.
    with tf.name_scope("reset"):
        variables = (nest.flatten(state) + x + constants)
        # Empty the arrays as part of the reset process.
        reset = [
            tf.variables_initializer(variables),
            fx_array.close(),
            fx_array_opt.close(),
            lr_optimizee.close()
        ]

    # Operator to update the parameters and the RNN state after our loop, but
    # during an epoch.
    with tf.name_scope("update"):
        update = (nest.flatten(util._nested_assign(x, x_final)) +
                  nest.flatten(util._nested_assign(state, s_final)))

    # Log internal variables.
    for k, net in nets.items():
        print("Optimizer '{}' variables".format(k))
        print([op.name for op in snt.get_variables_in_module(net)])

    return MetaLoss(loss, update, reset, fxopt, farray, lr_opt, x_final)
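# A minimal sketch of how meta_loss() might be consumed, assuming the returned
# MetaLoss namedtuple (7 fields, unpacked positionally below) and a separate
# meta-training step built on its loss. `optimizer`, `problem_dict`,
# `num_epochs`, and `num_unrolls` are illustrative names, not defined here.
#
#   loss, update, reset, fx, farray, lr_opt, x_final = \
#       optimizer.meta_loss(problem_dict, len_unroll=20)
#   train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for epoch in range(num_epochs):
#           sess.run(reset)
#           for _ in range(num_unrolls):
#               _, cost = sess.run([train_step, loss])
#               sess.run(update)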