def configure_logger(log_dir):
    """Point the module-level logging at `log_dir`.

    Side effects: reconfigures the shared `logger`, rebinds the module
    globals `tb` (a Logger writing tensorboard/csv/stdout outputs under
    `log_dir`) and `log` (the plain `logger.log` function).
    """
    global tb, log
    logger.configure(log_dir, format_strs=['log'])
    # Build one output sink per format, all rooted at log_dir.
    sinks = [logger.make_output_format(fmt, log_dir)
             for fmt in ('tensorboard', 'csv', 'stdout')]
    tb = logger.Logger(log_dir, sinks)
    log = logger.log
def configure_logger(log_dir, add_tb=1, add_wb=1, args=None):
    """Point the module-level logging at `log_dir`.

    Always attaches log/json/stdout outputs; optionally adds tensorboard
    (`add_tb`) and wandb (`add_wb`, which receives `args`) sinks.

    Side effects: reconfigures the shared `logger`, rebinds the module
    globals `tb` (a Logger over the chosen sinks) and `log`
    (the plain `logger.log` function).
    """
    global tb, log
    logger.configure(log_dir, format_strs=['log'])
    # Baseline sinks that are always present.
    sinks = [logger.make_output_format(fmt, log_dir)
             for fmt in ('log', 'json', 'stdout')]
    if add_tb:
        sinks.append(logger.make_output_format('tensorboard', log_dir))
    if add_wb:
        # wandb needs the run arguments to label/configure the run.
        sinks.append(logger.make_output_format('wandb', log_dir, args=args))
    tb = logger.Logger(log_dir, sinks)
    log = logger.log
def __init__(self, log_interval=503):
    """Initialize per-episode diagnostics state.

    Args:
        log_interval: number of local steps between diagnostic log flushes.
    """
    super(DiagnosticsInfoI, self).__init__()
    # NOTE(review): `log_dir` and `tasks` are free names here — they are not
    # parameters and are not defined in this chunk. Presumably module-level
    # globals set elsewhere; confirm they exist before this class is
    # instantiated, otherwise this raises NameError.
    self.log = lp.make_output_format('csv', log_dir, 'bonus' + str(tasks))
    # Wall-clock anchors: start of the current episode and of the last
    # logging window, respectively.
    self._episode_time = time.time()
    self._last_time = time.time()
    self._local_t = 0                 # local step counter
    self._log_interval = log_interval
    self._episode_reward = 0
    self._episode_length = 0
    self._all_rewards = []            # per-step rewards of the current episode
    self._num_vnc_updates = 0
    self._last_episode_id = -1        # -1 = no episode seen yet
    self.once = True                  # one-shot flag (purpose set by callers)
    self.log_dict = dict()
def run(args, server):
    """Run one distributed A3C worker: build the trainer, then train under a
    tf.train.Supervisor-managed session until the global step budget is hit.

    Args:
        args: parsed CLI namespace — uses log_dir, task, env_id, remotes,
            visualise.
        server: tf.train.Server whose target this worker's session connects to.
    """
    # Per-task CSV sink for bonus/diagnostic values.
    log = lp.make_output_format('csv', args.log_dir, 'bonus' + str(args.task))
    env = create_env(args.log_dir, args.task, args.env_id,
                     client_id=str(args.task), remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise, log)

    # Variable names that start with "local" are not saved in checkpoints.
    # The two branches are the same logic under the pre-/post-TF-0.12 APIs.
    if use_tf12_api:
        variables_to_save = [v for v in tf.global_variables()
                             if not v.name.startswith("local")]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        variables_to_save = [v for v in tf.all_variables()
                             if not v.name.startswith("local")]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()

    saver = FastSaver(variables_to_save)
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info(' %s %s', v.name, v.get_shape())

    def init_fn(ses):
        # Run by the Supervisor on the chief only, the first time the model
        # needs initializing.
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    # Restrict this worker to the PS job and its own task's CPU device.
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')
    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)

    # Task 0 is the chief: it checkpoints (every 30s) and initializes the
    # shared parameters; other tasks wait on ready_op.
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    # Hard cap on shared training steps across all workers.
    num_global_steps = 256000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. " +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        # Pull the latest shared weights into the local copy before training.
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or
                                        global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)