Example #1
    @classmethod
    def from_train_op(cls, train_op, loss, *, inputs=None, labels=None, metrics=None, updates=None,
                      sess=None, dataset=None, tensor_with_value=None, session_config=None,
                      model_dir=None):
        sess = TFOptimizer._get_or_create_session(sess)
        grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
        if dataset is None:
            dataset = TFOptimizer._get_dataset_from_loss(loss)
        _ = dataset.tensors  # trigger tensor creation if it has not happened yet
        dataset_inputs = dataset._original_tensors
        if isinstance(dataset_inputs, tuple) and len(dataset_inputs) == 2:
            if inputs is None:
                inputs = dataset_inputs[0]

            if labels is None:
                labels = dataset_inputs[1]
        else:
            if inputs is None:
                inputs = dataset_inputs

            if labels is None:
                labels = []

        inputs = nest.flatten(inputs)
        labels = nest.flatten(labels)
        return TFOptimizer._from_grads(loss=loss, sess=sess, inputs=inputs, labels=labels,
                                       grads=grads,
                                       variables=variables, dataset=dataset, metrics=metrics,
                                       tensor_with_value=tensor_with_value,
                                       optim_method=FakeOptimMethod(),
                                       session_config=session_config, updates=updates,
                                       model_dir=model_dir, train_op=train_op)
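A quick usage sketch for this classmethod, assuming Analytics Zoo's tfpark API; the dataset construction, toy model, and import paths below are illustrative assumptions, not part of the snippet above:

import numpy as np
import tensorflow as tf
from zoo.common.nncontext import init_nncontext
from zoo.tfpark import TFOptimizer, TFDataset
from bigdl.optim.optimizer import MaxEpoch

sc = init_nncontext()  # assumed Spark setup required by Analytics Zoo

features = np.random.rand(256, 20).astype(np.float32)
targets = np.random.randint(0, 10, size=(256,))

# Assumed helper: build a distributed TFDataset from in-memory ndarrays.
dataset = TFDataset.from_ndarrays((features, targets), batch_size=32)
inputs, labels = dataset.tensors

logits = tf.layers.dense(inputs, 10)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# from_train_op recovers (grads, variables) from train_op and drives
# distributed training until the end trigger fires.
optimizer = TFOptimizer.from_train_op(train_op, loss)
optimizer.optimize(end_trigger=MaxEpoch(5))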
Example #2
    @classmethod
    def from_train_op(cls, train_op, loss, metrics=None, updates=None, sess=None, dataset=None,
                      tensor_with_value=None, session_config=None, model_dir=None):
        sess = TFOptimizer._get_or_create_session(sess)
        grads, variables = TFOptimizer._get_vars_grads_from_train_op(train_op)
        if dataset is None:
            dataset = TFOptimizer._get_dataset_from_loss(loss)
        inputs = nest.flatten(dataset._original_tensors)
        return TFOptimizer._from_grads(loss=loss, sess=sess, inputs=inputs, grads=grads,
                                       variables=variables, dataset=dataset, metrics=metrics,
                                       tensor_with_value=tensor_with_value,
                                       optim_method=FakeOptimMethod(),
                                       session_config=session_config, updates=updates,
                                       model_dir=model_dir, train_op=train_op)
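Unlike Example #1, this variant takes no explicit inputs/labels split: it flattens the dataset's whole nested tensor structure into a single inputs list. A quick illustration of that flattening, using tf.nest (the public counterpart of the internal nest module imported here; the constants are made up):

import tensorflow as tf

nested = ((tf.constant(1.0), tf.constant(2.0)), tf.constant(3.0))
flat = tf.nest.flatten(nested)
print(len(flat))  # 3: nesting removed, left-to-right order preserved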
Example #3
    def train(self, input_fn, end_trigger):

        with tf.Graph().as_default() as g:

            dataset = input_fn()

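            # dataset.tensors yields a pair: inputs for the generator (e.g. noise)
            # and real samples for the discriminator.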
            generator_inputs = dataset.tensors[0]
            real_data = dataset.tensors[1]

            counter = tf.train.get_or_create_global_step()

            period = self._discriminator_steps + self._generator_steps

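            # Alternate phases by global step: the first `_discriminator_steps`
            # of each period update the discriminator, the rest update the generator.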
            is_discriminator_phase = tf.less(tf.mod(counter, period), self._discriminator_steps)

            with tf.variable_scope("Generator"):
                gen_data = self._call_fn_maybe_with_counter(self._generator_fn, counter,
                                                            generator_inputs)

            with tf.variable_scope("Discriminator"):
                fake_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn,
                                                                  counter,
                                                                  gen_data, generator_inputs)

            with tf.variable_scope("Discriminator", reuse=True):
                real_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn,
                                                                  counter,
                                                                  real_data, generator_inputs)

            with tf.name_scope("Generator_loss"):
                generator_loss = self._call_fn_maybe_with_counter(self._generator_loss_fn,
                                                                  counter,
                                                                  fake_d_outputs)
                gen_reg_loss = tf.losses.get_regularization_loss("Generator")

                generator_loss = generator_loss + gen_reg_loss

            with tf.name_scope("Discriminator_loss"):
                discriminator_loss = self._call_fn_maybe_with_counter(self._discriminator_loss_fn,
                                                                      counter,
                                                                      real_d_outputs,
                                                                      fake_d_outputs)
                dis_reg_loss = tf.losses.get_regularization_loss("Discriminator")
                discriminator_loss = discriminator_loss + dis_reg_loss

            generator_variables = tf.trainable_variables("Generator")
            discriminator_variables = tf.trainable_variables("Discriminator")

            def run_gen_compute():
                gen_grads_vars = self._gen_opt.compute_gradients(generator_loss,
                                                                 var_list=generator_variables)
                gen_grads = [grad for grad, var in gen_grads_vars]
                dis_grads = [tf.zeros_like(var) for var in discriminator_variables]

                return gen_grads + dis_grads

            def run_dis_compute():
                dis_grads_vars = self._dis_opt.compute_gradients(discriminator_loss,
                                                                 var_list=discriminator_variables)
                dis_grads = [grad for grad, var in dis_grads_vars]
                gen_grads = [tf.zeros_like(var) for var in generator_variables]
                return gen_grads + dis_grads

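            # Both tf.cond branches must return gradients for every variable so
            # their outputs line up; the inactive network's grads are zeros.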
            grads = tf.cond(is_discriminator_phase, run_dis_compute, run_gen_compute)

            grads_vars = list(zip(grads, generator_variables + discriminator_variables))

            gen_grads_vars = grads_vars[:len(generator_variables)]
            dis_grads_vars = grads_vars[len(generator_variables):]

            grads = [grad for grad, var in grads_vars]

            _train_op = tf.cond(is_discriminator_phase,
                                lambda: self._dis_opt.apply_gradients(dis_grads_vars),
                                lambda: self._gen_opt.apply_gradients(gen_grads_vars))

            variables = generator_variables + discriminator_variables

            loss = tf.cond(is_discriminator_phase,
                           lambda: discriminator_loss,
                           lambda: generator_loss)

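            # Sequence the counter increment after the apply op so one run of
            # train_op performs exactly one phase step.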
            with tf.control_dependencies([_train_op]):
                increase_counter = tf.assign_add(counter, 1)

            with tf.control_dependencies([increase_counter]):
                train_op = tf.no_op()

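            # Restore the latest checkpoint if present, run distributed training
            # via TFOptimizer, then save a counter-stamped checkpoint.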
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                kpt = tf.train.latest_checkpoint(self.model_dir)
                if kpt is not None:
                    saver.restore(sess, kpt)
                opt = TFOptimizer._from_grads(loss, sess,
                                              inputs=nest.flatten(dataset._original_tensors),
                                              labels=[],
                                              grads=grads, variables=variables, dataset=dataset,
                                              optim_method=FakeOptimMethod(),
                                              session_config=self._session_config,
                                              model_dir=os.path.join(self.model_dir, "tmp"),
                                              train_op=train_op)
                opt.optimize(end_trigger)
                saver = tf.train.Saver()
                saver.save(sess, self.checkpoint_path, global_step=counter)
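The core trick in Example #3, a single fused train_op that tf.cond switches between two optimizers based on the global step, can be reduced to a self-contained TF1-style sketch. The variables and quadratic losses below are illustrative stand-ins, not part of the estimator:

import tensorflow as tf

g_var = tf.get_variable("g", initializer=1.0)
d_var = tf.get_variable("d", initializer=1.0)
step = tf.train.get_or_create_global_step()

d_steps, g_steps = 1, 1
is_d_phase = tf.less(tf.mod(step, d_steps + g_steps), d_steps)

g_loss = tf.square(g_var)        # stand-in generator loss
d_loss = tf.square(d_var - 2.0)  # stand-in discriminator loss

g_opt = tf.train.GradientDescentOptimizer(0.1)
d_opt = tf.train.GradientDescentOptimizer(0.1)

# As in the estimator, the apply ops are created inside the cond branches so
# only the selected phase executes at each step.
apply_op = tf.cond(is_d_phase,
                   lambda: d_opt.minimize(d_loss, var_list=[d_var]),
                   lambda: g_opt.minimize(g_loss, var_list=[g_var]))

with tf.control_dependencies([apply_op]):
    train_op = tf.assign_add(step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(4):
        sess.run(train_op)  # alternates D, G, D, G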