コード例 #1
0
ファイル: gdqn.py プロジェクト: DailyActie/AI_APP_NLP-KG-A2C
def configure_logger(log_dir):
    """Set up the module-level logging globals for a run.

    Points the shared ``logger`` at *log_dir* with a plain log file,
    builds a ``Logger`` fanning out to tensorboard, csv and stdout
    (stored in the module global ``tb``), and aliases ``logger.log``
    as the module global ``log``.
    """
    global tb, log
    logger.configure(log_dir, format_strs=['log'])
    # One output sink per format, all rooted at the same directory.
    sinks = [logger.make_output_format(fmt, log_dir)
             for fmt in ('tensorboard', 'csv', 'stdout')]
    tb = logger.Logger(log_dir, sinks)
    log = logger.log
コード例 #2
0
ファイル: train.py プロジェクト: princeton-nlp/calm-textgame
def configure_logger(log_dir, add_tb=1, add_wb=1, args=None):
    """Initialise module-level logging globals for a training run.

    Always writes log/json/stdout output under *log_dir*. ``add_tb``
    additionally enables a tensorboard sink and ``add_wb`` a wandb sink
    (which receives *args*). The combined ``Logger`` is stored in the
    module global ``tb`` and ``logger.log`` is aliased as the module
    global ``log``.
    """
    global tb, log
    logger.configure(log_dir, format_strs=['log'])
    # Base sinks that are always active.
    sinks = [logger.make_output_format(fmt, log_dir)
             for fmt in ('log', 'json', 'stdout')]
    # Optional sinks, toggled by the flags.
    if add_tb:
        sinks.append(logger.make_output_format('tensorboard', log_dir))
    if add_wb:
        sinks.append(logger.make_output_format('wandb', log_dir, args=args))
    tb = logger.Logger(log_dir, sinks)
    log = logger.log
コード例 #3
0
ファイル: envs.py プロジェクト: arjunmanoharan/FIGAR
 def __init__(self, log_interval=503):
     """Initialise per-episode diagnostics state.

     Args:
         log_interval: number of local steps between diagnostic dumps.

     NOTE(review): ``log_dir``, ``tasks`` and ``lp`` are free names in
     this method — presumably module-level globals in the original file;
     confirm they are defined before instantiation.
     """
     super(DiagnosticsInfoI, self).__init__()
     # CSV output sink for this task's intrinsic-bonus diagnostics.
     self.log = lp.make_output_format('csv', log_dir, 'bonus' + str(tasks))
     self._episode_time = time.time()  # wall-clock start of current episode
     self._last_time = time.time()     # last time diagnostics were emitted
     self._local_t = 0                 # local step counter
     self._log_interval = log_interval
     self._episode_reward = 0
     self._episode_length = 0
     self._all_rewards = []            # rewards accumulated this episode
     self._num_vnc_updates = 0
     self._last_episode_id = -1
     self.once = True
     self.log_dict = dict()
コード例 #4
0
def run(args, server):
    """Worker entry point: build the env/trainer and run the A3C loop.

    Args:
        args: parsed CLI namespace; reads ``log_dir``, ``task``,
            ``env_id``, ``remotes`` and ``visualise``.
        server: a ``tf.train.Server`` whose target the managed session
            connects to.

    NOTE(review): ``lp``, ``create_env``, ``A3C``, ``FastSaver`` and
    ``use_tf12_api`` are defined elsewhere in the original file and are
    assumed to be in scope.
    """

    # CSV output sink for this task's bonus diagnostics.
    log = lp.make_output_format('csv', args.log_dir, 'bonus' + str(args.task))
    env = create_env(args.log_dir,
                     args.task,
                     args.env_id,
                     client_id=str(args.task),
                     remotes=args.remotes)
    trainer = A3C(env, args.task, args.visualise, log)

    # Variable names that start with "local" are not saved in checkpoints.
    if use_tf12_api:
        variables_to_save = [
            v for v in tf.global_variables() if not v.name.startswith("local")
        ]
        init_op = tf.variables_initializer(variables_to_save)
        init_all_op = tf.global_variables_initializer()
    else:
        # Pre-TF-0.12 spellings of the same collection/initializer calls.
        variables_to_save = [
            v for v in tf.all_variables() if not v.name.startswith("local")
        ]
        init_op = tf.initialize_variables(variables_to_save)
        init_all_op = tf.initialize_all_variables()
    saver = FastSaver(variables_to_save)

    # Log the trainable variables of the current scope for debugging.
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        # Run by the Supervisor on the chief when no checkpoint exists.
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    # Restrict each worker to the parameter server plus its own CPU device.
    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')

    # Per-task event directory: <log_dir>/train_<task>.
    if use_tf12_api:
        summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    else:
        summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)

    logger.info("Events directory: %s_%s", logdir, args.task)
    # Supervisor handles checkpointing/summaries; task 0 acts as chief.
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    # Hard cap on total training steps across all workers.
    num_global_steps = 256000000

    logger.info(
        "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. "
        +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified."
    )
    with sv.managed_session(server.target,
                            config=config) as sess, sess.as_default():
        # Pull the latest weights from the parameter server before training.
        sess.run(trainer.sync)
        trainer.start(sess, summary_writer)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps
                                        or global_step < num_global_steps):
            trainer.process(sess)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)