예제 #1
0
def main(argv):
  """Train a MAML reinforcement-learning model until training reports done.

  Builds the configuration named by ``FLAGS.config`` (looked up as an
  attribute of ``config_maml``), constructs the algorithm, and calls
  ``algo.train`` in a loop until the first element of its return value is
  truthy.

  Args:
    argv: Positional command-line arguments; unused.
  """
  del argv  # Unused
  config = DotMap(getattr(config_maml, FLAGS.config))
  # The config name is both printed to stdout and written to the TF log.
  print('MAML config: %s' % FLAGS.config)
  tf.logging.info('MAML config: %s', FLAGS.config)
  algo = maml_rl.MAMLReinforcementLearning(config)
  sess_config = tf.ConfigProto(allow_soft_placement=True)
  sess_config.gpu_options.allow_growth = True

  with tf.Session(config=sess_config) as sess:
    algo.init_logging(sess)
    # Once logging has been started, always stop it again, even when
    # initialization or a training call raises.
    try:
      sess.run(tf.global_variables_initializer())
      done = False
      while not done:
        # The second argument (10) is passed through to algo.train as-is;
        # NOTE(review): presumably a per-call iteration count -- confirm.
        done, _ = algo.train(sess, 10)
    finally:
      algo.stop_logging()
예제 #2
0
def main(argv):
    """Evaluate a restored MAML policy on a single test task.

    Loads the configuration, restores the model from ``FLAGS.model_dir``,
    creates a task via ``config.task_generator()`` and applies the attribute
    overrides found in ``config.task_env_modifiers[FLAGS.test_task_index]``.
    Depending on the flags it then:

    * ``FLAGS.eval_meta``: rolls out the first training policy once and
      prints the returned total reward.
    * ``FLAGS.eval_finetune``: runs ``FLAGS.num_finetune_steps`` rounds of
      ``algo.finetune`` and, after each round, rolls out the policy and
      prints the returned total reward.

    Rollout outputs are directed to
    ``<FLAGS.output_dir>/task_<index>/{meta,finetune_<step>}``.

    Args:
        argv: Positional command-line arguments; unused.
    """
    del argv  # Unused
    config = _load_config()
    algo = maml_rl.MAMLReinforcementLearning(config,
                                             logdir=FLAGS.model_dir,
                                             save_config=False)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Restore after initialization so the restored values are the ones
        # left in the session.
        algo.restore(sess, FLAGS.model_dir)
        task = config.task_generator()
        task_modifier = config.task_env_modifiers[FLAGS.test_task_index]
        for attr in task_modifier:
            setattr(task, attr, task_modifier[attr])

        def rollout(subdir):
            # Roll out the first training policy on the task and return the
            # total reward. The path is built here, at call time, so
            # FLAGS.output_dir is only read when an evaluation actually runs.
            output_path = os.path.join(
                FLAGS.output_dir,
                'task_{}'.format(FLAGS.test_task_index),
                subdir)
            return _rollout_and_save(
                sess,
                task,
                algo.train_policies[0],
                output_path,
                max_rollout_len=config.max_rollout_len)

        if FLAGS.eval_meta:
            sum_reward = rollout('meta')
            print('Total reward for meta policy is: {}'.format(sum_reward))

        if FLAGS.eval_finetune:
            for step in range(FLAGS.num_finetune_steps):
                tf.logging.info('Finetune step: %s', step)
                algo.finetune(sess, task_modifier)
                sum_reward = rollout('finetune_{}'.format(step))
                print('Total reward for fine-tuned policy is: {}'.format(
                    sum_reward))