def create_train_step(loss,
                      optimizer,
                      global_step=_USE_GLOBAL_STEP,
                      total_loss_fn=None,
                      update_ops=None,
                      variables_to_train=None,
                      transform_grads_fn=None,
                      summarize_gradients=False,
                      gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
                      aggregation_method=None,
                      check_numerics=True):
  """Creates a train_step that evaluates the gradients and returns the loss.

  Args:
    loss: A (possibly nested tuple of) `Tensor` or function representing the
      loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then tf.train.get_or_create_global_step() is used.
    total_loss_fn: Function to call on loss value to access the final item to
      minimize.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, a
      warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it
      will default to all tf.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the
      updated list.
    summarize_gradients: Whether or not add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    check_numerics: Whether or not we apply check_numerics.

  Returns:
    In graph mode: A (possibly nested tuple of) `Tensor` that when evaluated,
      calculates the current loss, computes the gradients, applies the
      optimizer, and returns the current loss.
    In eager mode: A lambda function that when is called, calculates the loss,
      then computes and applies the gradients and returns the original loss
      values.

  Raises:
    ValueError: if loss is not callable.
    RuntimeError: if resource variables are not enabled.
  """
  if total_loss_fn is None:
    total_loss_fn = lambda x: x
  if not callable(total_loss_fn):
    raise ValueError('`total_loss_fn` should be a function.')
  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)

  if not tf.executing_eagerly():
    # Graph mode: `loss` (and optionally `variables_to_train`) may be lazy
    # callables; resolve them now so ops are built in the current graph.
    if callable(loss):
      loss = loss()
    if callable(variables_to_train):
      variables_to_train = variables_to_train()
    # Calculate loss first, then calculate train op, then return the original
    # loss conditioned on executing the train op.
    with tf.control_dependencies(tf.nest.flatten(loss)):
      loss = tf.nest.map_structure(
          lambda t: tf.identity(t, 'loss_pre_train'), loss)
    train_op = create_train_op(
        total_loss_fn(loss),
        optimizer,
        global_step=global_step,
        update_ops=update_ops,
        variables_to_train=variables_to_train,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=summarize_gradients,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        check_numerics=check_numerics)
    with tf.control_dependencies([train_op]):
      return tf.nest.map_structure(
          lambda t: tf.identity(t, 'loss_post_train'), loss)

  # Eager mode from here on.
  if global_step is _USE_GLOBAL_STEP:
    global_step = tf.compat.v1.train.get_or_create_global_step()

  if not callable(loss):
    raise ValueError('`loss` should be a function in eager mode.')

  if not isinstance(loss, Future):
    logging.warning('loss should be an instance of eager_utils.Future')

  with tf.GradientTape() as tape:
    loss_value = loss()
    total_loss_value = total_loss_fn(loss_value)
  if variables_to_train is None:
    # Default to every variable the tape saw while computing the loss.
    variables_to_train = tape.watched_variables()
  elif callable(variables_to_train):
    variables_to_train = variables_to_train()
  variables_to_train = tf.nest.flatten(variables_to_train)
  grads = tape.gradient(total_loss_value, variables_to_train)
  # Materialize as a list: in Python 3 `zip` returns a one-shot iterator, so
  # if `add_gradients_summaries` (or `transform_grads_fn`) consumed it below,
  # `apply_gradients` would receive an exhausted iterator and silently apply
  # no gradients. A list also matches the documented `transform_grads_fn`
  # contract ("a list of gradient to variable pairs").
  grads_and_vars = list(zip(grads, variables_to_train))

  if transform_grads_fn:
    grads_and_vars = transform_grads_fn(grads_and_vars)

  if summarize_gradients:
    with tf.name_scope('summarize_grads'):
      add_gradients_summaries(grads_and_vars, global_step)

  if check_numerics:
    with tf.name_scope('train_op'):
      # In eager mode this raises immediately if the loss is inf/nan.
      tf.debugging.check_numerics(total_loss_value, 'Loss is inf or nan')

  optimizer.apply_gradients(grads_and_vars, global_step=global_step)
  return loss_value
def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop=1,
    additional_metrics=(),
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10,
    policy_checkpoint_interval=10,
    log_interval=10,
    summary_interval=10):
  """A training driver.

  Builds a graph-mode (TF1 `Session`) collect/train loop: a
  `DynamicStepDriver` gathers `steps_per_loop` steps per environment batch
  entry into a replay buffer, the agent trains on a single deterministic pass
  over that data, and the buffer is cleared before the next loop. Checkpoints,
  metric summaries, and a steps/sec summary are written under
  `root_dir/train`.

  Args:
    root_dir: Base directory; summaries and checkpoints go in its `train/`
      subdirectory.
    agent: The agent to train. Must provide `policy`, `train(experience=...)`,
      and `initialize()` — presumably a TF-Agents `TFAgent`; confirm against
      callers.
    environment: Batched environment providing `batch_size` — assumed to be a
      TF environment compatible with `DynamicStepDriver`.
    training_loops: Number of collect/train iterations to run.
    steps_per_loop: Environment steps collected per batch entry each loop.
    additional_metrics: Extra metric objects appended to the default train
      metrics.
    train_checkpoint_interval: Save the train checkpoint when
      `global_step % interval == 0`.
    policy_checkpoint_interval: Save the policy checkpoint when
      `global_step % interval == 0`.
    log_interval: Log loss and steps/sec when `global_step % interval == 0`.
    summary_interval: Record summaries only when
      `global_step % interval == 0`.

  Raises:
    RuntimeError: if resource variables are not enabled.
  """
  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()
  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries on steps that hit the summary interval.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # First two metrics (episodes, env steps) double as step metrics for the
    # summaries below (`train_metrics[:2]`).
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # Add to replay buffer and other agent specific observers.
    replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                        steps_per_loop)
    agent_observers = [replay_buffer.add_batch] + train_metrics

    # Each driver.run() collects steps_per_loop steps per batch entry.
    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=agent_observers)

    collect_op, _ = driver.run()
    batch_size = driver.env.batch_size
    # Single deterministic pass: the whole buffer contents come back as one
    # dataset element, consumed below via get_single_element.
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps_per_loop,
        single_deterministic_pass=True)
    trajectories, unused_info = tf.data.experimental.get_single_element(dataset)
    train_op = agent.train(experience=trajectories)
    # Buffer is cleared after every train step (on-policy style usage).
    clear_replay_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=1,
        agent=agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    # max_to_keep=None: keep every policy checkpoint.
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=None,
        policy=agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    init_agent_op = agent.initialize()

    config_saver = utils.GinConfigSaverHook(train_dir, summarize_config=True)
    config_saver.begin()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph. Restore-or-init must run before touching any
      # variables in this session.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)
      config_saver.after_create_session(sess)

      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()

      sess.run(train_summary_writer.init())
      # NOTE(review): one collect pass is run before the training loop —
      # presumably to seed the replay buffer; confirm intent.
      sess.run(collect_op)

      if global_step_val == 0:
        # Save an initial checkpoint so the evaluator runs for global_step=0.
        policy_checkpointer.save(global_step=global_step_val)
        sess.run(init_agent_op)

      # Pre-baked callables avoid per-loop sess.run graph setup overhead.
      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_replay_call = sess.make_callable(clear_replay_op)

      timed_at_step = global_step_val
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(training_loops):
        # Collect and train.
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_replay_call()
        global_step_val = global_step_call()

        time_acc += time.time() - start_time
        # agent.train returns a LossInfo-like structure; extract the scalar.
        total_loss = total_loss.loss
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          # Reset the timing window at each log point.
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)