import tensorflow as tf
import horovod.tensorflow as hvd

import optim       # local optimizer module providing adam/adamax/adam2
import tfops as Z  # local ops module (arg_scope, get_variable_ddi, actnorm)


def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
                      data_init, lr, f_loss):
    """Wrap graph construction, optimization, evaluation, checkpointing and
    data-dependent initialization behind a single model object."""

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        # Recompute activations during backprop to save memory.
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
                 'adam2': optim.adam2}[hps.optimizer]
    train_op, polyak_swap_op, ema = optimizer(
        all_params, gs, alpha=lr, hps=hps)

    if hps.direct_iterator:
        # Data is fed by the graph's own iterator; [1] keeps only the stats.
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train],
                            {feeds['x']: _x, feeds['y']: _y, lr: _lr})[1]
        m.train = _train

    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x, feeds['y']: _y})
        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        # Data-dependent init (DDI) of actnorm layers on an init batch.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    # sess.run(hvd.broadcast_global_variables(0))

    return m
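# === Usage sketch ===
# Minimal driver for the model object above; an illustrative sketch, not part
# of the original source. Only m.train, m.polyak_swap, m.test and m.save are
# defined by abstract_model_xy itself; `n_steps`, `lr_value` and `ckpt_path`
# are hypothetical stand-ins.
def example_training_loop(m, n_steps=1000, lr_value=1e-3,
                          ckpt_path='model.ckpt'):
    for step in range(n_steps):
        train_stats = m.train(lr_value)  # one optimizer step; returns stats
    # Evaluate with the Polyak/EMA-averaged weights, then swap back.
    m.polyak_swap()
    test_stats = m.test()
    m.polyak_swap()
    m.save(ckpt_path)
    return train_stats, test_stats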
def abstract_model_xy(sess, hps, feeds, train_iterators, test_iterators,
                      data_inits, lr, f_loss):
    """Two-model (A/B) variant: builds separate adamax optimizers, savers and
    EMA copies for the A- and B-scoped parameters, plus a joint train op."""

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    if hps.joint_train:
        (loss_train_A, stats_train_A, eps_flatten_A,
         loss_train_B, stats_train_B, eps_flatten_B) \
            = f_loss(train_iterators, is_training=True)
    else:
        (loss_train_A, stats_train_A, loss_train_B, stats_train_B) \
            = f_loss(train_iterators, is_training=True)
    all_params = tf.trainable_variables()

    # Get train data op
    def get_train_data():
        x_A, y_A = train_iterators['A']()
        x_B, y_B = train_iterators['B']()
        return x_A, y_A, x_B, y_B
    m.get_train_data = get_train_data

    # A
    with tf.variable_scope('optim_A'):
        params_A = [param for param in all_params if 'A/' in param.name]
        if hps.gradient_checkpointing == 1:
            # Recompute activations during backprop to save memory.
            from memory_saving_gradients import gradients
            gs_A = gradients(loss_train_A, params_A)
        else:
            gs_A = tf.gradients(loss_train_A, params_A)
        m.optimizer_A = optim.Optimizer()
        train_op_A, polyak_swap_op_A, ema_A = m.optimizer_A.adamax(
            params_A, gs_A, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_A = lambda _lr: sess.run(
                [train_op_A, stats_train_A], {lr: _lr})[1]
        else:
            def _train_A(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_A, stats_train_A],
                                {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                                 feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                                 lr: _lr})[1]
            m.train_A = _train_A
        m.polyak_swap_A = lambda: sess.run(polyak_swap_op_A)

    # B
    with tf.variable_scope('optim_B'):
        params_B = [param for param in all_params if 'B/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_B = gradients(loss_train_B, params_B)
        else:
            gs_B = tf.gradients(loss_train_B, params_B)
        m.optimizer_B = optim.Optimizer()
        train_op_B, polyak_swap_op_B, ema_B = m.optimizer_B.adamax(
            params_B, gs_B, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_B = lambda _lr: sess.run(
                [train_op_B, stats_train_B], {lr: _lr})[1]
        else:
            def _train_B(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_B, stats_train_B],
                                {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                                 feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                                 lr: _lr})[1]
            m.train_B = _train_B
        m.polyak_swap_B = lambda: sess.run(polyak_swap_op_B)

    # Joint step: run both train ops at once; [-2:] keeps the two stats.
    def _train(_lr, _x_A, _y_A, _x_B, _y_B):
        return sess.run([train_op_A, train_op_B,
                         stats_train_A, stats_train_B],
                        {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                         feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                         lr: _lr})[-2:]
    m.train = _train

    # === Testing
    loss_test_A, stats_test_A, loss_test_B, stats_test_B = f_loss(
        test_iterators, False, reuse=True)
    if hps.direct_iterator:
        m.test_A = lambda: sess.run(stats_test_A)
        m.test_B = lambda: sess.run(stats_test_B)
    else:
        # Get test data op
        def get_test_data():
            x_A, y_A = test_iterators['A']()
            x_B, y_B = test_iterators['B']()
            return x_A, y_A, x_B, y_B
        m.get_test_data = get_test_data

        def _test_A(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_A,
                            {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                             feeds['x_B']: _x_B, feeds['y_B']: _y_B})

        def _test_B(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_B,
                            {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                             feeds['x_B']: _x_B, feeds['y_B']: _y_B})
        m.test_A = _test_A
        m.test_B = _test_B

    # === Saving and restoring
    with tf.variable_scope('saver_A'):
        saver_A = tf.train.Saver()
        saver_ema_A = tf.train.Saver(ema_A.variables_to_restore())
        m.save_ema_A = lambda path_A: saver_ema_A.save(
            sess, path_A, write_meta_graph=False)
        m.save_A = lambda path_A: saver_A.save(
            sess, path_A, write_meta_graph=False)
        m.restore_A = lambda path_A: saver_A.restore(sess, path_A)

    with tf.variable_scope('saver_B'):
        saver_B = tf.train.Saver()
        saver_ema_B = tf.train.Saver(ema_B.variables_to_restore())
        m.save_ema_B = lambda path_B: saver_ema_B.save(
            sess, path_B, write_meta_graph=False)
        m.save_B = lambda path_B: saver_B.save(
            sess, path_B, write_meta_graph=False)
        m.restore_B = lambda path_B: saver_B.restore(sess, path_B)
    print("After saver")

    # === Initialize the parameters
    if hps.restore_path_A != '':
        m.restore_A(hps.restore_path_A)
    if hps.restore_path_B != '':
        m.restore_B(hps.restore_path_B)
    if hps.restore_path_A == '' and hps.restore_path_B == '':
        # Data-dependent init of both scopes on their init batches.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, False, reuse=True, init=True)
        all_params = tf.global_variables()
        params_A = [param for param in all_params if 'A/' in param.name]
        params_B = [param for param in all_params if 'B/' in param.name]
        sess.run(tf.variables_initializer(params_A))
        sess.run(tf.variables_initializer(params_B))
        feeds_dict = {feeds['x_A']: data_inits['A']['x'],
                      feeds['y_A']: data_inits['A']['y'],
                      feeds['x_B']: data_inits['B']['x'],
                      feeds['y_B']: data_inits['B']['y']}
        sess.run(results_init, feeds_dict)
    sess.run(hvd.broadcast_global_variables(0))

    return m
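# === Usage sketch (A/B variant) ===
# Illustrative joint-training driver; not part of the original source.
# Assumes hps.direct_iterator is False, so the feed-based m.get_train_data,
# m.get_test_data, m.train, m.test_A and m.test_B defined above all exist.
# `n_steps` and `lr_value` are hypothetical stand-ins.
def example_joint_training_loop(m, n_steps=1000, lr_value=1e-3):
    for step in range(n_steps):
        x_A, y_A, x_B, y_B = m.get_train_data()
        # One simultaneous step on both models; returns the two stats values.
        stats_A, stats_B = m.train(lr_value, x_A, y_A, x_B, y_B)
    x_A, y_A, x_B, y_B = m.get_test_data()
    return m.test_A(x_A, y_A, x_B, y_B), m.test_B(x_A, y_A, x_B, y_B)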