Example #1

The test methods below exercise config_util.merge_external_params_with_configs from the TensorFlow Object Detection API: each test writes a TrainEvalPipelineConfig proto to disk, reloads it, and verifies that values supplied through tf.HParams override the corresponding config fields.
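Every test follows the same flow, condensed here as a sketch; pipeline_config_path and the override names are placeholders drawn from the tests themselves:

hparams = tf.HParams(batch_size=16, learning_rate=0.15)  # values to override
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
configs = config_util.merge_external_params_with_configs(configs, hparams)
# configs["train_config"], configs["model"], ... now reflect the overrides.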
  def testNewFocalLossParameters(self):
    """Tests that the focal loss parameters are updated appropriately."""
    original_alpha = 1.0
    original_gamma = 1.0
    new_alpha = 0.3
    new_gamma = 2.0
    hparams = tf.HParams(focal_loss_alpha=new_alpha,
                         focal_loss_gamma=new_gamma)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    classification_loss = pipeline_config.model.ssd.loss.classification_loss
    classification_loss.weighted_sigmoid_focal.alpha = original_alpha
    classification_loss.weighted_sigmoid_focal.gamma = original_gamma
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    classification_loss = configs["model"].ssd.loss.classification_loss
    self.assertAlmostEqual(
        new_alpha, classification_loss.weighted_sigmoid_focal.alpha)
    self.assertAlmostEqual(
        new_gamma, classification_loss.weighted_sigmoid_focal.gamma)

  def testNewBatchSizeWithClipping(self):
    """Tests that batch size is clipped to 1 from below."""
    original_batch_size = 2
    hparams = tf.HParams(batch_size=0.5)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    new_batch_size = configs["train_config"].batch_size
    self.assertEqual(1, new_batch_size)  # Clipped to 1.0.

  def testNewBatchSize(self):
    """Tests that batch size is updated appropriately."""
    original_batch_size = 2
    hparams = tf.HParams(batch_size=16)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    new_batch_size = configs["train_config"].batch_size
    self.assertEqual(16, new_batch_size)

  def _assertOptimizerWithNewLearningRate(self, optimizer_name):
    """Asserts successful updating of all learning rate schemes."""
    original_learning_rate = 0.7
    learning_rate_scaling = 0.1
    hparams = tf.HParams(learning_rate=0.15)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    # Constant learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_constant_learning_rate(optimizer,
                                                  original_learning_rate)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    constant_lr = optimizer.learning_rate.constant_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, constant_lr.learning_rate)

    # Exponential decay learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_exponential_decay_learning_rate(
        optimizer, original_learning_rate)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    exponential_lr = optimizer.learning_rate.exponential_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           exponential_lr.initial_learning_rate)

    # Manual step learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_manual_step_learning_rate(
        optimizer, original_learning_rate, learning_rate_scaling)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    manual_lr = optimizer.learning_rate.manual_step_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           manual_lr.initial_learning_rate)
    for i, schedule in enumerate(manual_lr.schedule):
      self.assertAlmostEqual(hparams.learning_rate * learning_rate_scaling**i,
                             schedule.learning_rate)

  def testNewMomentumOptimizerValue(self):
    """Tests that new momentum value is updated appropriately."""
    original_momentum_value = 0.4
    hparams = tf.HParams(momentum_optimizer_value=1.1)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer_config = pipeline_config.train_config.optimizer.rms_prop_optimizer
    optimizer_config.momentum_optimizer_value = original_momentum_value
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
    new_momentum_value = optimizer_config.momentum_optimizer_value
    self.assertAlmostEqual(1.0, new_momentum_value)  # Clipped to 1.0.

  def testNewClassificationLocalizationWeightRatio(self):
    """Tests that the loss weight ratio is updated appropriately."""
    original_localization_weight = 0.1
    original_classification_weight = 0.2
    new_weight_ratio = 5.0
    hparams = tf.HParams(
        classification_localization_weight_ratio=new_weight_ratio)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.loss.localization_weight = (
        original_localization_weight)
    pipeline_config.model.ssd.loss.classification_weight = (
        original_classification_weight)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    loss = configs["model"].ssd.loss
    self.assertAlmostEqual(1.0, loss.localization_weight)
    self.assertAlmostEqual(new_weight_ratio, loss.classification_weight)
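The tests above rely on a _write_config helper that is not part of this excerpt. A minimal sketch of what such a helper needs to do, assuming the standard protobuf text format and tf.gfile, is:

# Sketch of the _write_config helper assumed by the tests above; the exact
# implementation in the real test file may differ.
from google.protobuf import text_format
import tensorflow as tf

def _write_config(config, config_path):
  """Writes a pipeline config proto to disk in text format."""
  config_text = text_format.MessageToString(config)
  with tf.gfile.Open(config_path, "wb") as f:
    f.write(config_text)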
Example #7

run_experiment below drives a meta-Q-learning training loop on a CartPole task: it fills inner- and outer-loop replay buffers, accumulates meta-gradients across sampled tasks, periodically evaluates a greedy post-update policy, and optionally reports metrics to a Vizier tuner. Hyperparameters come from command-line flags, optionally overridden by a Vizier study via study_hparams.
def run_experiment(study_hparams=None, trial_handle=None, tuner=None):
    """Runs one meta-Q experiment, optionally reporting to a Vizier tuner."""
    FLAGS = deepcopy(tf.app.flags.FLAGS)

    if FLAGS.use_vizier:
        for key, val in study_hparams.values().items():
            setattr(FLAGS, key, val)

    tf.reset_default_graph()
    np.random.seed(FLAGS.random_seed)
    tf.set_random_seed(FLAGS.random_seed)

    # Initialize the training and evaluation environments.
    env_kwargs = {
        'goal_x': FLAGS.goal_x,
        'min_goal_x': FLAGS.min_goal_x,
        'max_goal_x': FLAGS.max_goal_x,
        'x_threshold': FLAGS.x_threshold,
        'max_reward_for_dist': FLAGS.max_reward_for_dist,
        'reward_per_time_step': FLAGS.reward_per_time_step,
        'fixed_initial_state': FLAGS.fixed_initial_state,
        'reweight_rewards': FLAGS.reweight_rewards
    }
    env = cartpole.make_env(env_kwargs)
    eval_env = cartpole.make_env(env_kwargs)

    if not FLAGS.fixed_env:
        env.env.randomize()

    if trial_handle:
        tensorboard_path = os.path.join(FLAGS.output_dir, trial_handle)
    else:
        tensorboard_path = FLAGS.output_dir
    tf.gfile.MakeDirs(tensorboard_path)

    kwargs = dict(observation_shape=[None] + list(env.observation_space.shape),
                  action_dim=1)
    default_hps = MetaQ.get_default_config().values()

    for key in flags_def:
        if key in default_hps:
            kwargs[key] = getattr(FLAGS, key)

    hps = tf.HParams(**kwargs)

    meta_q = MetaQ(hps, fully_connected_net(FLAGS.nn_arch, FLAGS.activation))
    meta_q.build_graph()

    init_op = tf.global_variables_initializer()

    logger = TensorBoardLogger(tensorboard_path)

    with tf.Session() as sess:
        sess.run(init_op)
        meta_q.init_session(sess)

        inner_loop_buffer = MultiTaskReplayBuffer(len(env.env.goal_positions),
                                                  200000, FLAGS.random_seed)
        outer_loop_buffer = MultiTaskReplayBuffer(len(env.env.goal_positions),
                                                  200000, FLAGS.random_seed)

        pre_update_rewards = []
        post_update_rewards = []
        post_update_greedy_rewards = []
        post_update_q_func = None
        for outer_step in range(FLAGS.outer_loop_steps):
            print('State is ', env.env.state)
            if outer_step % FLAGS.on_policy_steps == 0:
                if FLAGS.fixed_env:
                    goal_positions = [env.env.goal_x]
                else:
                    goal_positions = env.env.goal_positions
                # NOTE: roughly 30 to 60 states per trajectory.
                inner_loop_buffer = collect_off_policy_data(
                    env, goal_positions, meta_q, post_update_q_func,
                    inner_loop_buffer, FLAGS.inner_loop_n_trajs,
                    FLAGS.inner_loop_data_collection,
                    FLAGS.inner_loop_greedy_epsilon,
                    FLAGS.inner_loop_bolzmann_temp)
                outer_loop_buffer = collect_off_policy_data(
                    env, goal_positions, meta_q, post_update_q_func,
                    outer_loop_buffer, FLAGS.outer_loop_n_trajs,
                    FLAGS.outer_loop_data_collection,
                    FLAGS.outer_loop_greedy_epsilon,
                    FLAGS.outer_loop_bolzmann_temp)

            post_update_greedy_rewards = []

            finetuned_policy = None
            for task_id in range(FLAGS.n_meta_tasks):
                # print('Task: {}'.format(task_id))

                if not FLAGS.fixed_env:
                    env.env.randomize()

                (inner_observations, inner_actions, inner_rewards,
                 inner_next_observations,
                 inner_dones) = inner_loop_buffer.sample(
                     env.env.task_id, FLAGS.inner_loop_n_states)
                # Evaluating true rewards
                post_update_q_func = meta_q.get_post_update_q_function(
                    inner_observations, inner_actions, inner_rewards,
                    inner_next_observations, inner_dones)

                policy = QPolicy(post_update_q_func, epsilon=0.0)

                if outer_step % FLAGS.report_steps == 0 or outer_step >= (
                        FLAGS.outer_loop_steps - 1):
                    _, _, greedy_rewards, _, _ = cartpole_utils.collect_data(
                        env,
                        n_trajs=FLAGS.outer_loop_greedy_eval_n_trajs,
                        policy=policy)
                    post_update_greedy_rewards.append(
                        np.sum(greedy_rewards) /
                        FLAGS.outer_loop_greedy_eval_n_trajs)

                finetuned_policy = policy

                (outer_observations, outer_actions, outer_rewards,
                 outer_next_observations,
                 outer_dones) = outer_loop_buffer.sample(
                     env.env.task_id, FLAGS.outer_loop_n_states)
                meta_q.accumulate_gradient(
                    inner_observations,
                    inner_actions,
                    inner_rewards,
                    inner_next_observations,
                    inner_dones,
                    outer_observations,
                    outer_actions,
                    outer_rewards,
                    outer_next_observations,
                    outer_dones,
                )

            pre_update_loss, post_update_loss = meta_q.run_train_step()

            if (not FLAGS.outer_loop_online_target and
                    outer_step % FLAGS.target_update_freq == 0):
                print("updating target network")
                meta_q.update_target_network()

            log_data = dict(
                pre_update_loss=pre_update_loss,
                post_update_loss=post_update_loss,
                goal_x=env.env.goal_x,
            )

            # TODO(hkannan): uncomment this later.
            if outer_step % FLAGS.report_steps == 0 or outer_step >= (
                    FLAGS.outer_loop_steps - 1):
                # reward_across_20_tasks = evaluate(
                #     policy, eval_env, meta_q,
                #     inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
                #     outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
                #     weight_rewards=FLAGS.weight_rewards)
                # log_data['reward_mean'] = np.mean(reward_across_20_tasks)
                # log_data['reward_variance'] = np.var(reward_across_20_tasks)
                log_data['post_update_greedy_reward'] = np.mean(
                    post_update_greedy_rewards)
                log_data['post_update_greedy_reward_variance'] = np.var(
                    post_update_greedy_rewards)

            print('Outer step: {}, '.format(outer_step), log_data)
            logger.log_dict(outer_step, log_data)
            # if outer_step % FLAGS.video_report_steps == 0 or outer_step >= (FLAGS.outer_loop_steps - 1):
            #   video_data = {
            #       'env_kwargs': env_kwargs,
            #       'inner_loop_data_collection': FLAGS.inner_loop_data_collection,
            #       'inner_loop_greedy_epsilon': FLAGS.inner_loop_greedy_epsilon,
            #       'inner_loop_bolzmann_temp': FLAGS.inner_loop_bolzmann_temp,
            #       'inner_loop_n_trajs': FLAGS.inner_loop_n_trajs,
            #       'meta_q_kwargs': kwargs,
            #       'weights': meta_q.get_current_weights(),
            #       'tensorboard_path': tensorboard_path,
            #       'filename': 'random_task'
            #   }
            #   reward_across_20_tasks = evaluate(
            #       policy, eval_env, meta_q,
            #       inner_loop_n_trajs=FLAGS.inner_loop_n_trajs,
            #       outer_loop_n_trajs=FLAGS.outer_loop_n_trajs, n=21,
            #       weight_rewards=FLAGS.weight_rewards, video_data=video_data)
            #   log_data['reward_mean'] = np.mean(reward_across_20_tasks)
            #   log_data['reward_variance'] = np.var(reward_across_20_tasks)
            #   logger.log_dict(outer_step, log_data)

            if outer_step >= (FLAGS.outer_loop_steps - 1):
                greedy_reward_path = os.path.join(tensorboard_path, 'reward')
                with gfile.Open(greedy_reward_path, mode='wb') as f:
                    f.write(pickle.dumps(
                        log_data['post_update_greedy_reward']))
            if FLAGS.use_vizier:
                for v in log_data.values():
                    if not np.isfinite(v):
                        tuner.report_done(
                            infeasible=True,
                            infeasible_reason='Nan or inf encountered')
                        return

                if outer_step % FLAGS.report_steps == 0 or outer_step >= (
                        FLAGS.outer_loop_steps - 1):
                    if FLAGS.vizier_objective == 'greedy_reward':
                        objective_value = log_data['post_update_greedy_reward']
                    elif FLAGS.vizier_objective == 'loss':
                        objective_value = post_update_loss
                    elif FLAGS.vizier_objective == 'reward':
                        objective_value = log_data['reward_mean']
                    else:
                        raise ValueError('Unsupported vizier objective!')
                    tuner.report_measure(objective_value=objective_value,
                                         global_step=outer_step,
                                         metrics=log_data)

    if FLAGS.use_vizier:
        tuner.report_done()
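
For context, a typical way to invoke run_experiment outside of a Vizier study is from a tf.app entry point. The sketch below is illustrative only and assumes the flags referenced above (use_vizier, output_dir, random_seed, ...) are defined elsewhere in the same module via tf.app.flags:

# Hypothetical entry point; flag definitions are assumed to live elsewhere
# in this module (they are accessed above through tf.app.flags.FLAGS).
def main(unused_argv):
    run_experiment()


if __name__ == '__main__':
    tf.app.run(main)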