Example #1
    def testCreateTrainOpWithTotalLossFn(self):
        inputs, labels = input_fn()
        model = Model('model', Network())
        loss = model.loss_fn(inputs, labels)
        model_2 = Model('model_2', Network())
        loss_2 = model_2.loss_fn(inputs, labels)

        @eager_utils.future_in_eager_mode
        def tuple_loss(loss, loss_2):
            return (loss() if callable(loss) else loss,
                    loss_2() if callable(loss_2) else loss_2)

        tuple_loss_value = tuple_loss(loss, loss_2)

        def first_element(tuple_value):
            return tuple_value[0]

        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
        loss = eager_utils.create_train_step(tuple_loss_value,
                                             optimizer,
                                             total_loss_fn=first_element)
        expected_loss = 1.098612
        self.evaluate(tf.compat.v1.global_variables_initializer())
        train_step_model_0, train_step_model_1 = self.evaluate(loss)
        self.assertAllClose(train_step_model_0, expected_loss)
        self.assertAllClose(train_step_model_1, expected_loss)
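The fixtures input_fn, Network, and Model are not shown in these examples. The expected loss of 1.098612 is ln 3, the cross-entropy of uniform predictions over three classes, which suggests a three-class classifier whose logits start at zero. A minimal sketch of plausible fixtures under that assumption (names, shapes, and the zero initializer are guesses, not the original code):

    import numpy as np
    import tensorflow as tf

    def input_fn():
        # Hypothetical fixture: a tiny constant 3-class dataset.
        inputs = tf.constant(np.random.RandomState(0).rand(4, 2), tf.float32)
        labels = tf.constant([0, 1, 2, 0], tf.int32)
        return inputs, labels

    class Network(tf.keras.layers.Layer):
        # Hypothetical fixture: one dense layer emitting 3 logits. Zero
        # initialization makes the initial loss exactly ln 3 = 1.098612.
        def __init__(self):
            super(Network, self).__init__()
            self._layer = tf.keras.layers.Dense(
                3, kernel_initializer='zeros', bias_initializer='zeros')

        def call(self, inputs):
            return self._layer(inputs)

    class Model(object):
        # Hypothetical fixture: wraps a network, exposing loss_fn and the
        # network's variables (kernel and bias, hence the length-2 checks).
        def __init__(self, name, network):
            self._name = name
            self._network = network

        def __call__(self, inputs):
            return self._network(inputs)

        @property
        def trainable_variables(self):
            return self._network.trainable_variables

        def loss_fn(self, inputs, labels):
            # In the eager tests this would need to return a callable; a
            # plain tensor is used here for brevity.
            logits = self._network(inputs)
            return tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits))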
Example #2
    def testCreateTrainOpWithTotalLossFn(self):
        inputs, labels = input_fn()
        model = Model('model', Network())
        loss = model.loss_fn(inputs, labels)
        model_2 = Model('model_2', Network())
        loss_2 = model_2.loss_fn(inputs, labels)

        @eager_utils.future_in_eager_mode
        def tuple_loss(loss, loss_2):
            return (loss() if callable(loss) else loss,
                    loss_2() if callable(loss_2) else loss_2)

        tuple_loss_value = tuple_loss(loss, loss_2)

        def first_element(tuple_value):
            return tuple_value[0]

        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
        loss = eager_utils.create_train_step(tuple_loss_value,
                                             optimizer,
                                             total_loss_fn=first_element)
        initial_loss = 1.098612
        final_loss = 1.064379
        self.evaluate(tf.compat.v1.global_variables_initializer())
        train_step_model_0, train_step_model_1 = self.evaluate(loss)
        self.assertAllClose(train_step_model_0, initial_loss)
        self.assertAllClose(train_step_model_1, initial_loss)
        train_step_model_0, train_step_model_1 = self.evaluate(loss)
        self.assertAllClose(train_step_model_0, final_loss)
        # The second model was not updated: total_loss_fn picked only the
        # first element of the tuple, so only that loss was optimized.
        self.assertAllClose(train_step_model_1, initial_loss)
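Why does the second model stay at the initial loss? Because total_loss_fn=first_element tells create_train_step to differentiate only the first element of the tuple, while both elements are still evaluated and returned. A standalone sketch of that mechanic in plain TF1-style code (an illustration, not the eager_utils implementation):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    v1 = tf.compat.v1.get_variable('v1', initializer=1.0)
    v2 = tf.compat.v1.get_variable('v2', initializer=1.0)
    loss_tuple = (tf.square(v1), tf.square(v2))

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
    # Minimize only the first element, as total_loss_fn=first_element does.
    train_op = optimizer.minimize(loss_tuple[0])

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(train_op)
        # v1 moved from 1.0 to 0.8, so its loss dropped to 0.64;
        # v2 was never touched and its loss is still 1.0.
        print(sess.run(loss_tuple))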
Example #3
    def _train(self, experience, weights=None, train_step_counter=None):
        # TODO(sfishman): Support batch dimensions >1.
        if experience.step_type.shape[0] != 1:
            raise NotImplementedError(
                'ReinforceAgent does not yet support batch '
                'dimensions greater than 1.')
        experience = nest.map_structure(lambda t: tf.squeeze(t, 0), experience)
        returns = common.compute_returns(experience.reward,
                                         experience.discount)
        if self._debug_summaries:
            tf.contrib.summary.histogram('rewards', experience.reward)
            tf.contrib.summary.histogram('discounts', experience.discount)
            tf.contrib.summary.histogram('returns', returns)

        # TODO(kbanoop): replace with tensor normalizer.
        if self._normalize_returns:
            ret_mean, ret_var = tf.nn.moments(returns, axes=[0])
            returns = (returns - ret_mean) / (tf.sqrt(ret_var) + 1e-6)
            if self._debug_summaries:
                tf.contrib.summary.histogram('normalized_returns', returns)

        # TODO(kbanoop): remove after changing network interface to accept
        # observations and step_types, instead of time_steps.
        time_step = ts.TimeStep(experience.step_type,
                                tf.zeros_like(experience.reward),
                                tf.zeros_like(experience.discount),
                                experience.observation)
        # TODO(kbanoop): Filter boundary steps.

        loss_info = self._loss(time_step,
                               experience.action,
                               tf.stop_gradient(returns),
                               weights=weights)

        clip_gradients = (tf.contrib.training.clip_gradient_norms_fn(
            self._gradient_clipping) if self._gradient_clipping else None)

        # TODO(sguada): create_train_step should not return a Future.
        loss_info = eager_utils.create_train_step(
            loss_info,
            self._optimizer,
            total_loss_fn=lambda loss_info: loss_info.loss,
            global_step=train_step_counter,
            transform_grads_fn=clip_gradients,
            summarize_gradients=self._summarize_grads_and_vars,
            variables_to_train=lambda: self._actor_network.trainable_weights,
        )

        if isinstance(loss_info, eager_utils.Future):
            loss_info = loss_info()

        if self._summarize_grads_and_vars:
            with tf.name_scope('Variables/'):
                for var in self._actor_network.trainable_weights:
                    tf.contrib.summary.histogram(var.name.replace(':', '_'),
                                                 var)

        return loss_info
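common.compute_returns is referenced but not shown. It presumably implements the standard backward recursion for discounted returns, G[t] = r[t] + discount[t] * G[t+1]; a minimal NumPy sketch of that recursion (the actual TF-Agents implementation and signature may differ):

    import numpy as np

    def compute_returns(rewards, discounts):
        # Walk the episode backwards, accumulating the discounted return.
        returns = np.zeros_like(rewards)
        next_return = 0.0
        for t in reversed(range(len(rewards))):
            next_return = rewards[t] + discounts[t] * next_return
            returns[t] = next_return
        return returns

    print(compute_returns(np.array([1., 1., 1.]), np.array([0.9, 0.9, 0.0])))
    # [2.71 1.9  1.  ]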
Example #4
    def testMultipleCallsTrainStep(self):
        inputs, labels = input_fn()
        model = Model('model', Network())
        loss = model.loss_fn(inputs, labels)
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
        train_step = eager_utils.create_train_step(loss, optimizer)
        initial_loss = 1.098612
        final_loss = 1.033917
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(train_step), initial_loss)
        if context.executing_eagerly():
            # In eager mode, each create_train_step call applies one update.
            for _ in range(5):
                train_step = eager_utils.create_train_step(loss, optimizer)
            train_step = eager_utils.create_train_step(loss, optimizer)
            self.assertAllClose(self.evaluate(train_step), final_loss)
        else:
            # In graph mode, each evaluation of the same op applies one update.
            for _ in range(5):
                self.evaluate(train_step)
            self.assertAllClose(self.evaluate(train_step), final_loss)
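In graph mode the same train_step tensor can be evaluated repeatedly: every session run applies one more gradient update, and the value returned is the loss computed before that run's update (the gradients consume the forward pass, so the loss is necessarily evaluated pre-update). A minimal sketch of that mechanic in plain TF1-style code (not the eager_utils implementation):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    v = tf.compat.v1.get_variable('v', initializer=2.0)
    loss = tf.square(v)
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
    # Returning the loss with a control dependency on the update mimics what
    # evaluating a train step yields: the pre-update loss.
    with tf.control_dependencies([train_op]):
        train_step = tf.identity(loss)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        print(sess.run(train_step))  # 4.0: loss before the first update
        print(sess.run(train_step))  # 2.56: v has moved from 2.0 to 1.6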
Example #5
    def testLossDecreasesAfterTrainOp(self):
        inputs, labels = input_fn()
        model = Model('model', Network())
        loss = model.loss_fn(inputs, labels)
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
        train_step = eager_utils.create_train_step(loss, optimizer)
        initial_loss = 1.098612
        final_loss = 1.064379
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(train_step), initial_loss)
        self.assertAllClose(self.evaluate(loss), final_loss)
Example #6
    def testVariablesToTrain(self):
        inputs, labels = input_fn()
        model = Model('model', Network())
        if context.executing_eagerly():
            # Variables are created lazily on the first call, so pass a
            # callable that resolves them at train time.
            variables_to_train = lambda: model.trainable_variables
        else:
            # In graph mode, calling the model once builds its variables,
            # so a static list can be passed directly.
            model(inputs)
            variables_to_train = model.trainable_variables
            self.assertEqual(len(variables_to_train), 2)
        loss = model.loss_fn(inputs, labels)
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
        train_step = eager_utils.create_train_step(
            loss, optimizer, variables_to_train=variables_to_train)
        expected_loss = 1.098612
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.assertAllClose(self.evaluate(train_step), expected_loss)
        self.assertEqual(len(model.trainable_variables), 2)
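The eager branch passes a callable because Keras-style layers create their variables lazily, on the first call; before the model has been called there is nothing to put in a static list. A quick demonstration of that lazy creation:

    import tensorflow as tf

    layer = tf.keras.layers.Dense(3)
    print(len(layer.trainable_variables))  # 0: no variables exist yet
    layer(tf.zeros([1, 2]))                # first call builds kernel and bias
    print(len(layer.trainable_variables))  # 2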
Example #7
    def _train(self, experience, weights=None):
        loss_info = self._loss(experience, weights=weights)

        transform_grads_fn = None
        if self._gradient_clipping is not None:
            transform_grads_fn = eager_utils.clip_gradient_norms_fn(
                self._gradient_clipping)

        loss_info = eager_utils.create_train_step(
            loss_info,
            self._optimizer,
            total_loss_fn=lambda loss_info: loss_info.loss,
            global_step=self.train_step_counter,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=self._summarize_grads_and_vars,
            variables_to_train=lambda: self._cloning_network.trainable_weights,
        )

        return loss_info
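Examples #3, #7, and #8 all pass a transform_grads_fn built by a clip_gradient_norms_fn helper. Such a function takes and returns a list of (gradient, variable) pairs; a minimal sketch of the contract, assuming per-gradient norm clipping (an illustration, not the library implementation):

    import tensorflow as tf

    def clip_gradient_norms_fn(max_norm):
        # Returns a transform_grads_fn that clips each gradient's norm.
        def transform_grads_fn(grads_and_vars):
            return [(tf.clip_by_norm(g, max_norm) if g is not None else None, v)
                    for g, v in grads_and_vars]
        return transform_grads_fn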
Example #8
    def _train(self, experience, train_step_counter=None, weights=None):
        time_steps, actions, next_time_steps = self._experience_to_transitions(
            experience)

        loss_info = self._loss(time_steps,
                               actions,
                               next_time_steps,
                               td_errors_loss_fn=self._td_errors_loss_fn,
                               gamma=self._gamma,
                               reward_scale_factor=self._reward_scale_factor,
                               weights=weights)

        transform_grads_fn = None
        if self._gradient_clipping is not None:
            transform_grads_fn = tf.contrib.training.clip_gradient_norms_fn(
                self._gradient_clipping)

        loss_info = eager_utils.create_train_step(
            loss_info,
            self._optimizer,
            total_loss_fn=lambda loss_info: loss_info.loss,
            global_step=train_step_counter,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=self._summarize_grads_and_vars,
            variables_to_train=lambda: self._q_network.trainable_weights,
        )

        if isinstance(loss_info, eager_utils.Future):
            loss_info = loss_info()

        # Make sure the periodic update_targets op is only created once.
        if self._target_update_train_op is None:
            with tf.control_dependencies([loss_info.loss]):
                self._target_update_train_op = self._update_targets(
                    self._target_update_tau, self._target_update_period)

        with tf.control_dependencies([self._target_update_train_op]):
            loss_info = nest.map_structure(
                lambda t: tf.identity(t, name='loss_info'), loss_info)

        return loss_info
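_update_targets is not shown either. It presumably performs the usual soft target-network update, target <- (1 - tau) * target + tau * source, gated to run every target_update_period steps; a hedged sketch of the core assignment (the name soft_update and the omission of the periodic gating are assumptions):

    import tensorflow as tf

    def soft_update(target_vars, source_vars, tau):
        # Blend each target variable toward its source counterpart.
        return tf.group(*[
            t.assign((1.0 - tau) * t + tau * s)
            for t, s in zip(target_vars, source_vars)
        ])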