@classmethod
def from_train_op(cls, train_op, loss, *, inputs=None, labels=None, metrics=None,
                  updates=None, sess=None, dataset=None, tensor_with_value=None,
                  session_config=None, model_dir=None):
    sess = TFOptimizer._get_or_create_session(sess)
    grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
    if dataset is None:
        dataset = TFOptimizer._get_dataset_from_loss(loss)
    _ = dataset.tensors  # trigger tensor creation if not yet available
    dataset_inputs = dataset._original_tensors
    if isinstance(dataset_inputs, tuple) and len(dataset_inputs) == 2:
        # The dataset yields (features, labels); use them as the defaults.
        if inputs is None:
            inputs = dataset_inputs[0]
        if labels is None:
            labels = dataset_inputs[1]
    else:
        # The dataset yields features only, so there are no labels.
        if inputs is None:
            inputs = dataset_inputs
        if labels is None:
            labels = []
    inputs = nest.flatten(inputs)
    labels = nest.flatten(labels)
    return TFOptimizer._from_grads(loss=loss, sess=sess, inputs=inputs, labels=labels,
                                   grads=grads, variables=variables, dataset=dataset,
                                   metrics=metrics, tensor_with_value=tensor_with_value,
                                   optim_method=FakeOptimMethod(),
                                   session_config=session_config, updates=updates,
                                   model_dir=model_dir, train_op=train_op)
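# A minimal usage sketch (illustrative, not part of this file): it assumes the
# analytics-zoo style TFDataset/TFOptimizer API and a hypothetical my_model()
# builder. Only train_op and loss are required; inputs and labels are inferred
# from the dataset attached to the graph.
#
#     import tensorflow as tf
#     from zoo.tfpark import TFDataset, TFOptimizer
#
#     dataset = TFDataset.from_rdd(...)  # elided: your data pipeline
#     features, labels = dataset.tensors
#     logits = my_model(features)        # hypothetical model-building function
#     loss = tf.reduce_mean(
#         tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
#     train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
#
#     optimizer = TFOptimizer.from_train_op(train_op, loss)
#     optimizer.optimize()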
@classmethod
def from_train_op(cls, train_op, loss, metrics=None, updates=None, sess=None,
                  dataset=None, tensor_with_value=None, session_config=None,
                  model_dir=None):
    sess = TFOptimizer._get_or_create_session(sess)
    grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
    if dataset is None:
        dataset = TFOptimizer._get_dataset_from_loss(loss)
    # Flatten whatever (possibly nested) structure the dataset yields into a
    # single flat list of input tensors.
    inputs = nest.flatten(dataset._original_tensors)
    return TFOptimizer._from_grads(loss=loss, sess=sess, inputs=inputs,
                                   grads=grads, variables=variables, dataset=dataset,
                                   metrics=metrics, tensor_with_value=tensor_with_value,
                                   optim_method=FakeOptimMethod(),
                                   session_config=session_config, updates=updates,
                                   model_dir=model_dir, train_op=train_op)
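# Quick sketch of what nest.flatten does to the dataset's tensor structure
# (plain TensorFlow utility behavior, shown with Python constants for clarity):
#
#     from tensorflow.python.util import nest
#
#     print(nest.flatten(((1, 2), 3, [4])))  # -> [1, 2, 3, 4]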
def train(self, input_fn, end_trigger):
    with tf.Graph().as_default() as g:
        dataset = input_fn()
        generator_inputs = dataset.tensors[0]
        real_data = dataset.tensors[1]

        # Alternate between discriminator and generator phases based on the
        # global step: the first _discriminator_steps of each period train the
        # discriminator, the rest train the generator.
        counter = tf.train.get_or_create_global_step()
        period = self._discriminator_steps + self._generator_steps
        is_discriminator_phase = tf.less(tf.mod(counter, period), self._discriminator_steps)

        with tf.variable_scope("Generator"):
            gen_data = self._call_fn_maybe_with_counter(self._generator_fn, counter,
                                                        generator_inputs)

        with tf.variable_scope("Discriminator"):
            fake_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn, counter,
                                                              gen_data, generator_inputs)

        with tf.variable_scope("Discriminator", reuse=True):
            real_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn, counter,
                                                              real_data, generator_inputs)

        with tf.name_scope("Generator_loss"):
            generator_loss = self._call_fn_maybe_with_counter(self._generator_loss_fn,
                                                              counter, fake_d_outputs)
            gen_reg_loss = tf.losses.get_regularization_loss("Generator")
            generator_loss = generator_loss + gen_reg_loss

        with tf.name_scope("Discriminator_loss"):
            discriminator_loss = self._call_fn_maybe_with_counter(self._discriminator_loss_fn,
                                                                  counter, real_d_outputs,
                                                                  fake_d_outputs)
            dis_reg_loss = tf.losses.get_regularization_loss("Discriminator")
            discriminator_loss = discriminator_loss + dis_reg_loss

        generator_variables = tf.trainable_variables("Generator")
        discriminator_variables = tf.trainable_variables("Discriminator")

        def run_gen_compute():
            # Generator phase: real gradients for the generator, zeros for the
            # discriminator so the variable list stays aligned across branches.
            gen_grads_vars = self._gen_opt.compute_gradients(generator_loss,
                                                             var_list=generator_variables)
            gen_grads = [grad for grad, var in gen_grads_vars]
            dis_grads = [tf.zeros_like(var) for var in discriminator_variables]
            return gen_grads + dis_grads

        def run_dis_compute():
            # Discriminator phase: zeros for the generator, real gradients for
            # the discriminator.
            dis_grads_vars = self._dis_opt.compute_gradients(discriminator_loss,
                                                             var_list=discriminator_variables)
            dis_grads = [grad for grad, var in dis_grads_vars]
            gen_grads = [tf.zeros_like(var) for var in generator_variables]
            return gen_grads + dis_grads

        grads = tf.cond(is_discriminator_phase, run_dis_compute, run_gen_compute)
        grads_vars = list(zip(grads, generator_variables + discriminator_variables))
        gen_grads_vars = grads_vars[:len(generator_variables)]
        dis_grads_vars = grads_vars[len(generator_variables):]
        grads = [grad for grad, var in grads_vars]

        _train_op = tf.cond(is_discriminator_phase,
                            lambda: self._dis_opt.apply_gradients(dis_grads_vars),
                            lambda: self._gen_opt.apply_gradients(gen_grads_vars))
        variables = generator_variables + discriminator_variables
        loss = tf.cond(is_discriminator_phase,
                       lambda: discriminator_loss,
                       lambda: generator_loss)

        # Advance the phase counter after every applied update.
        with tf.control_dependencies([_train_op]):
            increase_counter = tf.assign_add(counter, 1)
        with tf.control_dependencies([increase_counter]):
            train_op = tf.no_op()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            kpt = tf.train.latest_checkpoint(self.model_dir)
            if kpt is not None:
                saver.restore(sess, kpt)

            opt = TFOptimizer._from_grads(loss, sess,
                                          inputs=nest.flatten(dataset._original_tensors),
                                          labels=[],
                                          grads=grads,
                                          variables=variables,
                                          dataset=dataset,
                                          optim_method=FakeOptimMethod(),
                                          session_config=self._session_config,
                                          model_dir=os.path.join(self.model_dir, "tmp"),
                                          train_op=train_op)
            opt.optimize(end_trigger)
            saver = tf.train.Saver()
            saver.save(sess, self.checkpoint_path, global_step=counter)
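# Toy sketch (hypothetical, TF1-style, not part of this file) of the alternation
# trick used above: tf.cond selects which sub-network receives real gradients
# while the other gets zeros, so a single fetched train step can serve both GAN
# phases without rebuilding the graph.
#
#     import tensorflow as tf
#
#     step = tf.train.get_or_create_global_step()
#     phase_a = tf.equal(tf.mod(step, 2), 0)
#     a, b = tf.Variable(1.0), tf.Variable(1.0)
#     grads = tf.cond(phase_a,
#                     lambda: tf.gradients(a * a, [a]) + [tf.zeros_like(b)],
#                     lambda: [tf.zeros_like(a)] + tf.gradients(b * b, [b]))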