Example #1
def train():
    num_parallel_environments = 2
    collect_episodes_per_iteration = 2  # 30

    # ParallelPyEnvironment expects callables that build *py* environments;
    # the outer TFPyEnvironment does the TF wrapping (cf. Example #8).
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: suite_gym.wrap_env(RectEnv())] * num_parallel_environments))

    print(tf_env.time_step_spec())
    print(tf_env.action_spec())
    print(tf_env.observation_spec())

    preprocessing_layers = {
        'target':
        tf.keras.models.Sequential([
            # tf.keras.applications.MobileNetV2(
            #     input_shape=(64, 64, 1), include_top=False, weights=None),
            tf.keras.layers.Conv2D(1, 6),
            tf.keras.layers.Flatten()
        ]),
        'canvas':
        tf.keras.models.Sequential([
            # tf.keras.applications.MobileNetV2(
            #     input_shape=(64, 64, 1), include_top=False, weights=None),
            tf.keras.layers.Conv2D(1, 6),
            tf.keras.layers.Flatten()
        ]),
        'coord':
        tf.keras.models.Sequential(
            [tf.keras.layers.Dense(5),
             tf.keras.layers.Flatten()])
    }
    preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
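    # NOTE: the same preprocessing layer instances are passed to both the actor
    # and value networks below, so those encoder weights are shared between the
    # two networks; create separate instances if independent encoders are intended.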

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner)
    value_net = value_network.ValueNetwork(
        tf_env.observation_spec(),
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner)

    tf_agent = ppo_agent.PPOAgent(tf_env.time_step_spec(),
                                  tf_env.action_spec(),
                                  tf.compat.v1.train.AdamOptimizer(),
                                  actor_net=actor_net,
                                  value_net=value_net,
                                  normalize_observations=False,
                                  use_gae=False)
    tf_agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec, batch_size=num_parallel_environments)
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        tf_agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_episodes=collect_episodes_per_iteration)

    # print(tf_agent.collect_data_spec)

    def train_step():
        trajectories = replay_buffer.gather_all()
        return tf_agent.train(experience=trajectories)

    collect_driver.run = common.function(collect_driver.run, autograph=False)
    # tf_agent.train = common.function(tf_agent.train, autograph=False)
    # train_step = common.function(train_step)

    # for _ in range(10):
    collect_driver.run()
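    # A minimal continuation sketch (not in the original snippet): PPO is
    # on-policy, so a typical iteration trains on everything just collected
    # and then clears the buffer.
    loss_info = train_step()
    replay_buffer.clear()
    print('loss:', loss_info.loss.numpy())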
Example #2
    # env = get_env(name='kitchen')
    env = get_env(name='playpen_reduced',
                  task_list='rc_o',
                  reward_type='sparse')

    base_dir = os.path.abspath(
        'experiments/env_logs/playpen_reduced/symmetric/')
    env_log_dir = os.path.join(base_dir, 'rc_o/traj1/')
    # env = ResetFreeWrapper(env, reset_goal_frequency=500, full_reset_frequency=max_episode_steps)
    env = GoalTerminalResetWrapper(
        env,
        episodes_before_full_reset=max_episode_steps // 500,
        goal_reset_frequency=500)
    # env = Monitor(env, env_log_dir, video_callable=lambda x: x % 1 == 0, force=True)

    env = wrap_env(env)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    tf_env.render = env.render
    time_step_spec = tf_env.time_step_spec()
    action_spec = tf_env.action_spec()
    policy = random_tf_policy.RandomTFPolicy(action_spec=action_spec,
                                             time_step_spec=time_step_spec)
    collect_data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=policy.info_spec,
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount)
    offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=int(1e5))  # capacity is an assumption; the original value is truncated here
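    # Hedged sketch (not part of the original, truncated snippet): fill the
    # offline buffer by running the random policy with a step driver; the
    # dynamic_step_driver import and the step count are assumptions.
    offline_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        policy,
        observers=[offline_data.add_batch],
        num_steps=1000)
    offline_collect_driver.run()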
Example #3
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
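        # NOTE: config_file is not a parameter of train_eval; it is assumed to be
        # defined at module scope (e.g., via a command-line flag) in the original
        # source.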
        gym_env = gym.make(env_name, config_file=config_file)
        py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_gym_env = gym.make(env_name, config_file=config_file)
        eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        #tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        #eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name), config_file=config_file)

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step.numpy()).zfill(9))
                saved_model.save(saved_model_path)

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
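# A hypothetical invocation of the routine above; the directory and iteration
# count are placeholders, and config_file must be defined at module scope as
# noted earlier.
if __name__ == '__main__':
    train_eval(root_dir='~/experiments/solventx_dqn', num_iterations=20000)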
Example #4
        next_dyn_obs = self._dynamics.predict_state(timesteps=cur_dyn_obs, actions=actions.cpu().numpy())
        reward = self._env.compute_reward(cur_dyn_obs, next_dyn_obs, info=DADSEnv.OBS_TYPE.DYNAMICS_OBS)
        return torch.tensor(-reward, device=self._device)

    def start_episode(self):
        self._planner.reset()

    def get_skill(self, ts: TimeStep) -> np.ndarray:
        dyn_obs = self._env.to_dynamics_obs(ts.observation)
        return self._planner.command(state=dyn_obs).cpu().numpy()


class SimpleDynamics:
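    """Toy stand-in for a learned dynamics model: nudges the state by a small
    fraction of the action and clips the result to [-1, 1]."""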
    _action_size = 2

    @staticmethod
    def predict_state(timesteps, actions):
        return np.clip(timesteps + 0.1*actions, -1, 1)


if __name__ == '__main__':
    env = wrap_env(make_point2d_dads_env())
    provider = MPPISkillProvider(env=env, dynamics=SimpleDynamics(), skills_to_plan=1)
    while True:
        ts = env.reset()
        provider.start_episode()
        for _ in range(30):
            env.render("human")
            action = provider.get_skill(ts=ts)
            ts = env.step(action)
Example #5
def main(_):
    # setting up
    start_time = time.time()
    tf.compat.v1.enable_resource_variables()
    tf.compat.v1.disable_eager_execution()
    logging.set_verbosity(logging.INFO)
    global observation_omit_size, goal_coord, sample_count, iter_count, episode_size_buffer, episode_return_buffer

    root_dir = os.path.abspath(os.path.expanduser(FLAGS.logdir))
    if not tf.io.gfile.exists(root_dir):
        tf.io.gfile.makedirs(root_dir)
    log_dir = os.path.join(root_dir, FLAGS.environment)

    if not tf.io.gfile.exists(log_dir):
        tf.io.gfile.makedirs(log_dir)
    save_dir = os.path.join(log_dir, 'models')
    if not tf.io.gfile.exists(save_dir):
        tf.io.gfile.makedirs(save_dir)

    print('directory for recording experiment data:', log_dir)

    # restore bookkeeping counters in case training was paused and is being resumed
    try:
        sample_count = np.load(os.path.join(log_dir,
                                            'sample_count.npy')).tolist()
        iter_count = np.load(os.path.join(log_dir, 'iter_count.npy')).tolist()
        episode_size_buffer = np.load(
            os.path.join(log_dir, 'episode_size_buffer.npy')).tolist()
        episode_return_buffer = np.load(
            os.path.join(log_dir, 'episode_return_buffer.npy')).tolist()
    except (OSError, ValueError):
        # start fresh if the bookkeeping files are missing or unreadable
        sample_count = 0
        iter_count = 0
        episode_size_buffer = []
        episode_return_buffer = []

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(log_dir, 'train', 'in_graph_data'),
        flush_millis=10 * 1000)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(True):
        # environment related stuff
        env = do.get_environment(env_name=FLAGS.environment)
        py_env = wrap_env(
            skill_wrapper.SkillWrapper(
                env,
                num_latent_skills=FLAGS.num_skills,
                skill_type=FLAGS.skill_type,
                preset_skill=None,
                min_steps_before_resample=FLAGS.min_steps_before_resample,
                resample_prob=FLAGS.resample_prob),
            max_episode_steps=FLAGS.max_env_steps)

        # all specifications required for all networks and agents
        py_action_spec = py_env.action_spec()
        tf_action_spec = tensor_spec.from_spec(
            py_action_spec)  # policy, critic action spec
        env_obs_spec = py_env.observation_spec()
        py_env_time_step_spec = ts.time_step_spec(
            env_obs_spec)  # replay buffer time_step spec
        if observation_omit_size > 0:
            agent_obs_spec = array_spec.BoundedArraySpec(
                (env_obs_spec.shape[0] - observation_omit_size, ),
                env_obs_spec.dtype,
                minimum=env_obs_spec.minimum,
                maximum=env_obs_spec.maximum,
                name=env_obs_spec.name)  # policy, critic observation spec
        else:
            agent_obs_spec = env_obs_spec
        py_agent_time_step_spec = ts.time_step_spec(
            agent_obs_spec)  # policy, critic time_step spec
        tf_agent_time_step_spec = tensor_spec.from_spec(
            py_agent_time_step_spec)

        if not FLAGS.reduced_observation:
            skill_dynamics_observation_size = (
                py_env_time_step_spec.observation.shape[0] - FLAGS.num_skills)
        else:
            skill_dynamics_observation_size = FLAGS.reduced_observation

        # TODO(architsh): Shift co-ordinate hiding to actor_net and critic_net (good for further image-based processing as well)
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_agent_time_step_spec.observation,
            tf_action_spec,
            fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
            continuous_projection_net=do._normal_projection_net)

        critic_net = critic_network.CriticNetwork(
            (tf_agent_time_step_spec.observation, tf_action_spec),
            observation_fc_layer_params=None,
            action_fc_layer_params=None,
            joint_fc_layer_params=(FLAGS.hidden_layer_size, ) * 2)

        reweigh_batches_flag = (
            FLAGS.skill_dynamics_relabel_type is not None and
            'importance_sampling' in FLAGS.skill_dynamics_relabel_type and
            FLAGS.is_clip_eps > 1.0)

        agent = dads_agent.DADSAgent(
            # DADS parameters
            save_dir,
            skill_dynamics_observation_size,
            observation_modify_fn=do.process_observation,
            restrict_input_size=observation_omit_size,
            latent_size=FLAGS.num_skills,
            latent_prior=FLAGS.skill_type,
            prior_samples=FLAGS.random_skills,
            fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
            normalize_observations=FLAGS.normalize_data,
            network_type=FLAGS.graph_type,
            num_mixture_components=FLAGS.num_components,
            fix_variance=FLAGS.fix_variance,
            reweigh_batches=reweigh_batches_flag,
            skill_dynamics_learning_rate=FLAGS.skill_dynamics_lr,
            # SAC parameters
            time_step_spec=tf_agent_time_step_spec,
            action_spec=tf_action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            target_update_tau=0.005,
            target_update_period=1,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=FLAGS.agent_gamma,
            reward_scale_factor=1. / (FLAGS.agent_entropy + 1e-12),
            gradient_clipping=None,
            debug_summaries=FLAGS.debug,
            train_step_counter=global_step)

        # evaluation policy
        eval_policy = py_tf_policy.PyTFPolicy(agent.policy)

        # collection policy
        if FLAGS.collect_policy == 'default':
            collect_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)
        elif FLAGS.collect_policy == 'ou_noise':
            collect_policy = py_tf_policy.PyTFPolicy(
                ou_noise_policy.OUNoisePolicy(agent.collect_policy,
                                              ou_stddev=0.2,
                                              ou_damping=0.15))
        else:
            raise ValueError('Unknown collect_policy: %s' % FLAGS.collect_policy)

        # relabelling policy deals with batches of data, unlike collect and eval
        relabel_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)

        # constructing a replay buffer, need a python spec
        policy_step_spec = policy_step.PolicyStep(action=py_action_spec,
                                                  state=(),
                                                  info=())

        if (FLAGS.skill_dynamics_relabel_type is not None and
                'importance_sampling' in FLAGS.skill_dynamics_relabel_type and
                FLAGS.is_clip_eps > 1.0):
            policy_step_spec = policy_step_spec._replace(
                info=policy_step.set_log_probability(
                    policy_step_spec.info,
                    array_spec.ArraySpec(
                        shape=(), dtype=np.float32, name='action_log_prob')))

        trajectory_spec = from_transition(py_env_time_step_spec,
                                          policy_step_spec,
                                          py_env_time_step_spec)
        capacity = FLAGS.replay_buffer_capacity
        # for all the data collected
        rbuffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
            capacity=capacity, data_spec=trajectory_spec)

        if FLAGS.train_skill_dynamics_on_policy:
            # for on-policy data (if something special is required)
            on_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
                capacity=FLAGS.initial_collect_steps + FLAGS.collect_steps + 10,
                data_spec=trajectory_spec)

        # insert experience manually with relabelled rewards and skills
        agent.build_agent_graph()
        agent.build_skill_dynamics_graph()
        agent.create_savers()

        # saving this way requires the savers to live outside the agent object
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(save_dir, 'agent'),
            agent=agent,
            global_step=global_step)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(save_dir, 'policy'),
            policy=agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(save_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=rbuffer)

        setup_time = time.time() - start_time
        print('Setup time:', setup_time)

        with tf.compat.v1.Session().as_default() as sess:
            eval_policy.session = sess
            eval_policy.initialize(None)
            eval_policy.restore(os.path.join(FLAGS.logdir, 'models', 'policy'))

            plotdir = os.path.join(FLAGS.logdir, "plots")
            if not os.path.exists(plotdir):
                os.mkdir(plotdir)
            do.FLAGS = FLAGS
            do.eval_loop(eval_dir=plotdir,
                         eval_policy=eval_policy,
                         plot_name="plot")
Example #6
def get_env(name='frozen_lake', max_episode_steps=50):
  if name == 'frozen_lake':
    return wrap_env(
        FrozenLakeCont(map_name='4x4', continual=False, reset_reward=True),
        max_episode_steps=max_episode_steps)
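# A minimal usage sketch (assumes tf_py_environment is imported as in the
# other examples): wrap the returned py environment for TF-based agents.
frozen_lake_tf_env = tf_py_environment.TFPyEnvironment(
    get_env(name='frozen_lake', max_episode_steps=50))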
Example #7
def get_env(name='sawyer_push',
            max_episode_steps=None,
            gym_env_wrappers=(),
            **env_kwargs):

    reset_state_shape = None
    reset_states = None

    eval_metrics = []
    train_metrics = []
    train_metrics_rev = []  # metrics for reverse policy
    value_metrics = []

    # if name == 'sawyer_push':
    #   env = SawyerObject(
    #       random_init=True,
    #       task_type='push',
    #       obs_type='with_goal',
    #       goal_low=(-0.1, 0.8, 0.05),
    #       goal_high=(0.1, 0.9, 0.3),
    #       liftThresh=0.04,
    #       sampleMode='equal',
    #       rewMode='orig',
    #       rotMode='fixed')
    #   env.set_camera_view(view='topview')
    #   env.set_max_path_length(int(1e8))
    #   eval_metrics += [
    #       FailedEpisodes(
    #           failure_function=functools.partial(
    #               sawyer_push_success, episodic=True),
    #           name='EvalSuccessfulEpisodes')
    #   ]
    #   train_metrics += [
    #       FailedEpisodes(
    #           failure_function=functools.partial(
    #               sawyer_push_success, episodic=False),
    #           name='TrainSuccessfulStates')
    #   ]

    #   if name == 'sawyer_door':
    #     env = SawyerDoor(random_init=True, obs_type='with_goal')
    #     env.set_camera_view(view='topview')
    #     env.set_max_path_length(int(1e8))
    #     env.set_reward_type(reward_type=env_kwargs.get('reward_type', 'dense'))
    #     eval_metrics += [
    #         FailedEpisodes(
    #             failure_function=functools.partial(
    #                 sawyer_door_success, episodic=True),
    #             name='EvalSuccessfulEpisodes')
    #     ]
    #     train_metrics += [
    #         FailedEpisodes(
    #             failure_function=functools.partial(
    #                 sawyer_door_success, episodic=False),
    #             name='TrainSuccessfulStates')
    #     ]
    #     # metrics for reverse policy
    #     train_metrics_rev += [
    #         FailedEpisodes(
    #             failure_function=functools.partial(
    #                 sawyer_door_success, episodic=False),
    #             name='TrainSuccessfulStatesRev')
    #     ]
    #     reset_state_shape = (6,)
    #     reset_states = np.array(
    #         [[-0.00356643, 0.4132358, 0.2534339, -0.21, 0.69, 0.15]])
    #     train_metrics += [
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, :2],
    #             state_max=1.,
    #             num_bins=20,
    #             name='EndEffectorHeatmap',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 3:5],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.4, 0.9),
    #             num_bins=20,
    #             name='DoorXYHeatmap',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 6:8],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.4, 0.9),
    #             num_bins=20,
    #             name='EndEffectorGoalHeatmap',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 9:11],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.4, 0.9),
    #             num_bins=20,
    #             name='DoorXYGoalHeatmap',
    #         ),
    #     ]

    #     # metrics to visualize the value function
    #     value_metrics += [
    #         ValueFunctionHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 3:5],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.4, 0.9),
    #             num_bins=20,
    #             name='DoorXYGoalValueHeatmap',
    #         ),
    #         # ValueFunctionHeatmap(
    #         #     trajectory_to_xypos=lambda x: x[:, 3:5],
    #         #     state_max=None,
    #         #     x_range=(-0.25, 0.25),
    #         #     y_range=(0.4, 0.9),
    #         #     num_bins=20,
    #         #     name='DoorXYCombinedHeatmap',
    #         # ),
    #     ]

    #     # metrics for reverse policy
    #     train_metrics_rev += [
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, :2],
    #             state_max=1.,
    #             num_bins=20,
    #             name='EndEffectorHeatmapRev',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 3:5],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.1, 0.7),
    #             num_bins=20,
    #             name='DoorXYHeatmapRev',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 6:8],
    #             state_max=1.,
    #             num_bins=20,
    #             name='EndEffectorGoalHeatmapRev',
    #         ),
    #         StateVisitationHeatmap(
    #             trajectory_to_xypos=lambda x: x[:, 9:11],
    #             state_max=None,
    #             x_range=(-0.25, 0.25),
    #             y_range=(0.1, 0.7),
    #             num_bins=20,
    #             name='DoorXYGoalHeatmapRev',
    #         ),
    #     ]

    if name == 'pusher2d_simple':
        env = PusherEnv()
        eval_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                pusher2d_simple_success, episodic=True),
                           name='EvalSuccessfulEpisodes')
        ]
        train_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                pusher2d_simple_success, episodic=False),
                           name='TrainSuccessfulStates')
        ]

    if name == 'point_mass':
        env = PointMassEnv(**env_kwargs)
        eval_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=True),
                           name='EvalSuccessfulEpisodes')
        ]
        train_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=False),
                           name='TrainSuccessfulStates')
        ]

        # reverse metrics
        train_metrics_rev += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=False),
                           name='TrainSuccessfulStatesRev')
        ]
        reset_state_shape = (2, )
        reset_state_by_env_type = {
            'default': np.array([0.0, 0.0]),
            't': np.array([0.0, 0.0]),
            'y': np.array([0.0, 8.0]),
            'skewed_square': np.array([0.0, -8.0])
        }
        reset_states = np.expand_dims(
            reset_state_by_env_type[env_kwargs.get('env_type', 'default')],
            axis=0)

        train_metrics += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=10.,
                num_bins=20,
                name='StateVisitationHeatmap',
            ),
            # distribution of goals: goals are always the last two dimensions
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, -2:],
                state_max=10.,
                num_bins=20,
                name='SelectedGoalHeatmap',
            )
        ]

        train_metrics_rev += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],  # pylint: disable=invalid-sequence-index
                state_max=10.,
                num_bins=20,
                name='StateVisitationHeatmapRev',
            ),
            # distribution of goals: goals are always the last two dimensions
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, -2:],  # pylint: disable=invalid-sequence-index
                state_max=10.,
                num_bins=20,
                name='SelectedGoalHeatmapRev',
            )
        ]

    if name == 'point_mass_full_goal':
        env = PointMassFullGoalEnv(**env_kwargs)
        eval_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=True),
                           name='EvalSuccessfulEpisodes')
        ]
        train_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=False),
                           name='TrainSuccessfulStates')
        ]

        # reverse metrics
        train_metrics_rev += [
            FailedEpisodes(failure_function=functools.partial(
                point_mass_success, episodic=False),
                           name='TrainSuccessfulStatesRev')
        ]
        reset_state_shape = (6, )
        reset_state_by_env_type = {
            'default': np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            't': np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            'y': np.array([0.0, 8.0, 0.0, 0.0, 0.0, 0.0]),
            'skewed_square': np.array([0.0, -8.0, 0.0, 0.0, 0.0, 0.0])
        }
        reset_states = np.expand_dims(
            reset_state_by_env_type[env_kwargs.get('env_type', 'default')],
            axis=0)

        train_metrics += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=10.,
                num_bins=20,
                name='StateVisitationHeatmap',
            ),
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 6:8],
                state_max=10.,
                num_bins=20,
                name='SelectedGoalHeatmap',
            )
        ]

        train_metrics_rev += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=10.,
                num_bins=20,
                name='StateVisitationHeatmapRev',
            ),
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 6:8],
                state_max=10.,
                num_bins=20,
                name='SelectedGoalHeatmapRev',
            )
        ]


#   if name == 'kitchen':
#     env = kitchen.Kitchen(**env_kwargs)

#     eval_metrics += [
#         FailedEpisodes(
#             failure_function=functools.partial(
#                 kitchen_microwave_success, episodic=True),
#             name='EvalSuccessfulEpisodes')
#     ]
#     train_metrics += [
#         FailedEpisodes(
#             failure_function=functools.partial(
#                 kitchen_microwave_success, episodic=False),
#             name='TrainSuccessfulStates')
#     ]
#     reset_state_shape = kitchen.initial_state.shape[1:]
#     reset_states = kitchen.initial_state

    if name == 'playpen':
        env = playpen.ContinuousPlayPen(**env_kwargs)

        eval_metrics += [
            FailedEpisodes(failure_function=functools.partial(playpen_success,
                                                              episodic=True),
                           name='EvalSuccessfulAtLastStep'),
            AnyStepGoalMetric(goal_success_fn=functools.partial(
                playpen_success, episodic=False),
                              name='EvalSuccessfulAtAnyStep')
        ]
        train_metrics += [
            FailedEpisodes(failure_function=functools.partial(playpen_success,
                                                              episodic=False),
                           name='TrainSuccessfulStates')
        ]
        reset_state_shape = playpen.initial_state.shape[1:]
        reset_states = playpen.initial_state.copy()

        # heatmap visualization
        task_list = env_kwargs.get('task_list', 'rc_o').split('-')
        interest_objects = []
        for task in task_list:
            subtask_list = task.split('__')
            interest_objects += [subtask[:2] for subtask in subtask_list]

        train_metrics += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=3.,
                num_bins=20,
                name='GripperHeatmap',
            ),
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 12:14],
                state_max=3.,
                num_bins=20,
                name='GripperGoalHeatmap',
            ),
        ]

        # metrics to visualize the value function
        value_metrics += [
            ValueFunctionHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],  # pylint: disable=invalid-sequence-index
                state_max=3.,
                num_bins=20,
                name='GripperGoalValueHeatmap',
            ),
        ]

        obj_to_name = {
            'rc': 'RedCube',
            'bc': 'BlueCube',
            'ks': 'BlackSphere',
            'yr': 'YellowCylinder'
        }
        obj_to_idx = {'rc': [2, 4], 'bc': [4, 6], 'ks': [6, 8], 'yr': [8, 10]}
        for obj_code in list(set(interest_objects)):
            state_ids = obj_to_idx[obj_code]
            heatmap_name = obj_to_name[obj_code]
            # bind state_ids at definition time (default argument) to avoid the
            # late-binding pitfall with lambdas created inside a loop
            train_metrics += [
                StateVisitationHeatmap(
                    trajectory_to_xypos=lambda x, s=state_ids: x[:, s[0]:s[1]],
                    state_max=3.,
                    num_bins=20,
                    name=heatmap_name + 'Heatmap',
                ),
                StateVisitationHeatmap(
                    trajectory_to_xypos=lambda x, s=state_ids: x[:, s[0] + 12:s[1] + 12],
                    state_max=3.,
                    num_bins=20,
                    name=heatmap_name + 'GoalHeatmap',
                ),
            ]
            value_metrics += [
                ValueFunctionHeatmap(
                    trajectory_to_xypos=lambda x, s=state_ids: x[:, s[0]:s[1]],
                    state_max=3.,
                    num_bins=20,
                    name=heatmap_name + 'GoalValueHeatmap',
                ),
            ]

    if name == 'playpen_reduced':
        env = playpen_reduced.ContinuousPlayPen(**env_kwargs)

        eval_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                playpen_reduced_success, episodic=True),
                           name='EvalSuccessfulAtLastStep'),
            AnyStepGoalMetric(goal_success_fn=functools.partial(
                playpen_reduced_success, episodic=False),
                              name='EvalSuccessfulAtAnyStep')
        ]
        train_metrics += [
            FailedEpisodes(failure_function=functools.partial(
                playpen_reduced_success, episodic=False),
                           name='TrainSuccessfulStates')
        ]
        reset_state_shape = playpen_reduced.initial_state.shape[1:]
        reset_states = playpen_reduced.initial_state.copy()

        train_metrics += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=3.,
                num_bins=20,
                name='GripperHeatmap',
            ),
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 6:8],
                state_max=3.,
                num_bins=20,
                name='GripperGoalHeatmap',
            ),
        ]

        # metrics to visualize the value function
        value_metrics += [
            ValueFunctionHeatmap(
                trajectory_to_xypos=lambda x: x[:, :2],
                state_max=3.,
                num_bins=20,
                name='GripperGoalValueHeatmap',
            ),
        ]

        train_metrics += [
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 2:4],
                state_max=3.,
                num_bins=20,
                name='RedCubeHeatmap',
            ),
            StateVisitationHeatmap(
                trajectory_to_xypos=lambda x: x[:, 8:10],
                state_max=3.,
                num_bins=20,
                name='RedCubeGoalHeatmap',
            ),
        ]
        value_metrics += [
            ValueFunctionHeatmap(
                trajectory_to_xypos=lambda x: x[:, 2:4],
                state_max=3.,
                num_bins=20,
                name='RedCubeGoalValueHeatmap',
            ),
        ]

    wrapped_env = wrap_env(
        env,
        max_episode_steps=max_episode_steps,
        gym_env_wrappers=gym_env_wrappers)
    aux_info = {
        'reset_state_shape': reset_state_shape,
        'reset_states': reset_states,
        'train_metrics_rev': train_metrics_rev,
        'value_fn_viz_metrics': value_metrics
    }
    return wrapped_env, train_metrics, eval_metrics, aux_info
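# A hypothetical usage sketch for the function above (the environment name,
# episode limit, and env_type are placeholders; the TFPyEnvironment wrapping
# follows the pattern of the other examples):
env, train_metrics, eval_metrics, extra = get_env(
    name='point_mass', max_episode_steps=100, env_type='default')
tf_env = tf_py_environment.TFPyEnvironment(env)
reset_states = extra['reset_states']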
Example #8
def train_eval(
        root_dir,
        # env_name='HalfCheetah-v2',
        # env_load_fn=suite_mujoco.load,
        env_load_fn=None,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=int(1e7),
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,  # use_tf_functions=False,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        # eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        # tf_env = tf_py_environment.TFPyEnvironment(
        #     parallel_py_environment.ParallelPyEnvironment(
        #         [lambda: env_load_fn(env_name)] * num_parallel_environments))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.wrap_env(RectEnv()))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_gym.wrap_env(RectEnv())] *
                num_parallel_environments))

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        preprocessing_layers = {
            'target':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False, weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
            'canvas':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False, weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
            'coord':
            tf.keras.models.Sequential([
                tf.keras.layers.Dense(64),
                tf.keras.layers.Dense(64),
                tf.keras.layers.Flatten()
            ])
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
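        # NOTE: the same preprocessing layer instances are passed to both the
        # actor and value networks below, so those encoder weights are shared;
        # create separate instances if independent encoders are intended.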

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            max_to_keep=5,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            max_to_keep=5,
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
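# A hypothetical invocation of the PPO routine above (the directory is a
# placeholder and the environment/episode counts are reduced for a quick run):
if __name__ == '__main__':
    train_eval(root_dir='/tmp/rect_ppo',
               num_parallel_environments=4,
               collect_episodes_per_iteration=4)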
Example #9
    print(f'Found policy of type:{type(saved_policy)}!')
    print(f'Testing for {num_episodes} episodes!')
    returns = []
    for episode in range(num_episodes):
        time_step = eval_tf_env.reset()
        print(f'Initial Time step:\n{time_step}')
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = eval_tf_env.step(action_step.action)
            print(
                f'Reward:{time_step.reward},Observation:{time_step.observation}'
            )
            episode_return += time_step.reward

        print(f'Total return at episode {episode+1}:{episode_return}')
        returns.append(episode_return.numpy())
    print(f'List of returns after {num_episodes} episodes:{returns}')
    print(
        f'Average return:{np.mean(returns):.3f},Standard deviation:{np.std(returns):.3f}'
    )


eval_gym_env = gym.make(env_name, config_file=config_file)
eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

saved_policy = tf.compat.v2.saved_model.load(policy_dir)

test_agent(saved_policy, eval_tf_env)