Example #1
'''
For the main training loop, instead of calling the get_next() method, we will use
a tf.data.Dataset. This way, we can benefit from the power of the data API (parallelism
and prefetching). For this, we call the replay buffer's as_dataset() method.

We will sample batches of 64 trajectories at each training step, each with 2
consecutive steps (i.e., 2 steps = 1 full transition, including the next step's
observation). This dataset will process 3 elements in parallel and prefetch 3 batches.

'''
dataset = replay_buffer.as_dataset(sample_batch_size=64,
                                   num_steps=2,
                                   num_parallel_calls=3).prefetch(3)

# To speed up training, convert the main functions to TF Functions.
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)


# We will save the agent's policy periodically.
def save_agent_policy():
    now = datetime.datetime.now()
    policy_dir = config.POLICIES_PATH + now.strftime("%m%d%Y-%H%M%S")
    os.mkdir(policy_dir)
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)
    tf_policy_saver.save(policy_dir)
    print(">>>Policy saved in ", policy_dir)


'''
And now we are ready to run the main loop!
'''
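
# A minimal sketch of that main loop (an illustration, not the author's original code).
# It assumes the driver's environment is available as tf_env and that n_iterations is
# defined; the save interval below is also illustrative.
iterator = iter(dataset)
time_step = None
policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
for iteration in range(n_iterations):
    time_step, policy_state = collect_driver.run(time_step, policy_state)
    trajectories, buffer_info = next(iterator)
    train_loss = agent.train(trajectories)
    print("\r{} loss: {:.5f}".format(iteration, train_loss.loss.numpy()), end="")
    if iteration % 1000 == 0:
        save_agent_policy()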
Example #2
def train_agent(iterations, modeldir, logdir, policydir):
    """Train and convert the model using TF Agents."""

    train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2
    )
    eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2
    )

    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # Alternatively you could use ActorDistributionNetwork as actor_net
    actor_net = tfa.networks.Sequential(
        [
            tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]),
            tf.keras.layers.Dense(FC_LAYER_PARAMS, activation="relu"),
            tf.keras.layers.Dense(BOARD_SIZE**2),
            tf.keras.layers.Lambda(lambda t: tfp.distributions.Categorical(logits=t)),
        ],
        input_spec=train_py_env.observation_spec(),
    )
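    # The Sequential actor network above flattens the BOARD_SIZE x BOARD_SIZE board
    # observation, passes it through one fully connected hidden layer, and emits one
    # logit per board cell; the final Lambda layer turns those logits into a
    # categorical distribution over strike positions.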

    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    train_step_counter = tf.Variable(0)

    tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        optimizer=optimizer,
        normalize_returns=True,
        train_step_counter=train_step_counter,
    )

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )  # specify signature here for validation at insertion time

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server,
    )

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME, REPLAY_BUFFER_CAPACITY
    )

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return_and_steps(
        eval_env, tf_agent.policy, NUM_EVAL_EPISODES
    )

    summary_writer = tf.summary.create_file_writer(logdir)

    for i in range(iterations):
        # Collect a few episodes using collect_policy and save to the replay buffer.
        collect_episode(
            train_py_env,
            collect_policy,
            COLLECT_EPISODES_PER_ITERATION,
            replay_buffer_observer,
        )

        # Use data from the buffer and update the agent's network.
        iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
        trajectories, _ = next(iterator)
        tf_agent.train(experience=trajectories)
        replay_buffer.clear()

        logger = tf.get_logger()
        if i % EVAL_INTERVAL == 0:
            avg_return, avg_episode_length = compute_avg_return_and_steps(
                eval_env, eval_policy, NUM_EVAL_EPISODES
            )
            with summary_writer.as_default():
                tf.summary.scalar("Average return", avg_return, step=i)
                tf.summary.scalar("Average episode length", avg_episode_length, step=i)
                summary_writer.flush()
            logger.info(
                "iteration = {0}: Average Return = {1}, Average Episode Length = {2}".format(
                    i, avg_return, avg_episode_length
                )
            )

    summary_writer.close()

    tf_policy_saver.save(policydir)
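

# For reference, a minimal sketch of the collect_episode() helper used in the training
# loop above. This is an assumption modeled on the standard TF-Agents pattern, not code
# from the original example: it drives the Python environment with the collect policy
# and forwards complete episodes to the Reverb observer.
from tf_agents.drivers import py_driver
from tf_agents.policies import py_tf_eager_policy


def collect_episode(environment, policy, num_episodes, observer):
    driver = py_driver.PyDriver(
        environment,
        py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
        [observer],
        max_episodes=num_episodes)
    initial_time_step = environment.reset()
    driver.run(initial_time_step)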
Example #3
def simulate():
    # Set up the environments for the agent to train and test its performance
    envTrain = ComputerSnake.Snake()
    envEval = ComputerSnake.Snake(persistence = True)

    # Convert and wrap in TFPyEnvironment training and evaluation environments
    train_env = tf_py_environment.TFPyEnvironment(envTrain)
    eval_env = tf_py_environment.TFPyEnvironment(envEval)

    # Set up q network with necessary parameters
    fc_layer_params = (100,)
    q_net = q_network.QNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=fc_layer_params
    )
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) # look up
    train_step_counter = tf.Variable(0)

    # Set up and initialize the DQN learning agent. It takes in the time_step spec,
    # action spec, the q network, the optimizer, a loss function, and train_step_counter
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer, # look up
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter
    )
    agent.initialize()

    # Set up policies the agent can use
    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    # Policy which randomly selects actions for each step
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    # Buffer to store previous states
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    # The dataset generates trajectories with shape [Bx2x...] so that the agent has access
    # to both the current and next step when computing the loss. Parallel calls and
    # prefetching are used to optimize the input pipeline.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    agent.train = common.function(agent.train)

    # Reset the train step
    agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)

    # Initially fill the replay buffer with 5,000 random-policy steps to bootstrap the agent
    collect_data(train_env, random_policy, replay_buffer, steps=5000)
    train_env.reset()

    # Here, we run the simulation to train the agent
    scores_list = []
    num_steps_arr = []
    for currStep in range(num_iterations):
        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        # Number of training steps so far
        step = agent.train_step_counter.numpy()

        # Print progress every log_interval training steps
        if step % log_interval == 0:
            print('Moves made = {0}'.format(step))

        # Evaluate the agent's policy every eval_interval steps, print the results,
        # and save them so they can be plotted later
        if step % eval_interval == 0:
            avg_return = 0
            for i in range(num_eval_episodes):
                curr_return = compute_avg_return(eval_env, agent.policy, 1)
                scores_list.append(curr_return)
                num_steps_arr.append(currStep)
                avg_return += curr_return
            avg_return = avg_return / num_eval_episodes
            print('step = {0}: Average Return = {1}'.format(step, avg_return))
    plt.scatter(num_steps_arr, scores_list)
    plt.xlabel('Number of Steps Trained')
    plt.ylabel('Score')
    plt.title('Snake Reinforcement Learning')
    plt.show()
Example #4
    # initial_collect_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())
    #
    # initial_driver = dynamic_step_driver.DynamicStepDriver(
    #     tf_env,
    #     initial_collect_policy,
    #     observers=[replay_buffer.add_batch, ShowProgress(INITIAL_COLLECT_STEPS)],
    #     num_steps=INITIAL_COLLECT_STEPS
    # )
    # final_time_step, final_policy_state = initial_driver.run()

    dataset = replay_buffer.as_dataset(
        sample_batch_size=BATCH_SIZE, num_steps=2, num_parallel_calls=3).prefetch(3)

    agent.train = common.function(agent.train)

    all_train_loss = []
    all_metrics = []
    returns = []

    checkpoint_dir = "checkpoints/checkpoint_50k"
    train_checkpointer = common.Checkpointer(ckpt_dir=checkpoint_dir,
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             replay_buffer=replay_buffer,
                                             global_step=train_step)
    # train_checkpointer.initialize_or_restore()
    # train_step = tf.compat.v1.train.get_global_step()
    policy_save_handler = policy_saver.PolicySaver(agent.policy)
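

# The Snake example above calls collect_step() and collect_data() without showing them.
# A minimal sketch of those helpers, assumed to follow the standard TF-Agents DQN
# tutorial pattern (not code from the original example):
from tf_agents.trajectories import trajectory


def collect_step(environment, policy, buffer):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    buffer.add_batch(traj)


def collect_data(environment, policy, buffer, steps):
    for _ in range(steps):
        collect_step(environment, policy, buffer)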
Example #5
def main():
    # Create train and evaluation environments for Tensorflow
    train_py_env = Environment.Environment()
    train_env = tf_py_environment.TFPyEnvironment(train_py_env)

    eval_py_env = Environment.Environment()
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)


    # utils.validate_py_environment(train_py_env, episodes=5)

    # Set up an agent
    # Decide on layers of a network
    fc_layer_params = (50, 200, 25, 6)
    #conv_layer_params = [(4, 4, 1), (8, 4, 2)]
    # QNetwork predicts QValues (expected returns) for all actions based on observation on the given environment
    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               #conv_layer_params=conv_layer_params,
                               fc_layer_params=fc_layer_params)
    # Initialize DQN Agent on the train environment steps, actions, QNetwork, Adam Optimizer, loss function & train step counter
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    # Variable maintains shared, persistent state manipulated by a program.
    # 0 is the initial value.
    # After construction, the type and shape of the variable are fixed.
    train_step_counter = tf.Variable(0)
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=0.4,  #TODO tune this
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter,
        #boltzmann_temperature=0.1,
        summarize_grads_and_vars=True)

    agent.initialize()

    # Policies

    """A policy defines the way an agent acts in an environment. 
    Typically, the goal of RL is to train the underlying model until the policy produces the desired outcome.
    
    Agents contain two policies:
    agent.policy — The main policy that is used for evaluation and deployment.
    agent.collect_policy — A second policy that is used for data collection.
    """

    # tf_agents.policies.random_tf_policy creates a policy which will randomly select an action for each time_step (independent of agent)
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())

    # Baseline average return of the moves based on random_policy (random actions of an agent)
    print(compute_avg_return(eval_env, random_policy, num_eval_episodes))

    # Replay buffer

    # The replay buffer keeps track of data collected from the environment.
    # This tutorial uses tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer, as it is the most common.
    # The constructor requires the specs for the data it will be collecting.
    # This is available from the agent using the collect_data_spec method.
    # The batch size and maximum buffer length are also required.

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_max_length,
        dataset_window_shift=1)

    # The agent needs access to the replay buffer.
    # This is provided by creating an iterable tf.data.Dataset pipeline which will feed data to the agent.
    # Each row of the replay buffer only stores a single observation step.
    # But since the DQN Agent needs both the current and next observation to compute the loss,
    # the dataset pipeline will sample two adjacent rows for each item in the batch (num_steps=2).
    # This dataset is also optimized by running parallel calls and prefetching data.

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        single_deterministic_pass=False,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # Train the agent

    agent.train = common.function(agent.train)

    # Reset the train step
    agent.train_step_counter.assign(0)

    collect_data(train_env, random_policy, replay_buffer, steps=10000)

    for _ in range(num_iterations):

        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        step = agent.train_step_counter.numpy()
        if step % log_interval == 0:
            avg_return = compute_avg_return(eval_env, agent.policy, 3) #TODO
            if not os.path.exists("eval_data"):
                os.makedirs("eval_data")
            path = os.path.join("eval_data", f'Eval_data.step{step // log_interval}.txt')
            with open(path, 'w') as f:
                for move in eval_py_env.all_moves:
                    print(str(move), file=f)
            eval_py_env.all_moves = []
            print('step = {0}: loss = {1}, Average Return: {2}'.format(step, train_loss, avg_return))
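

# compute_avg_return() is used for evaluation throughout these examples but is not shown.
# A minimal sketch, assumed to follow the standard TF-Agents tutorials (not code from the
# original examples):
def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]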
Example #6
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64, ),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for PPO."""
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load multiagent gym environment and determine number of agents
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))
        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=random_seed * 1234 + i))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments)
        ]

        logging.info('Creating agent...')
        tf_agent = multiagent_ppo.MultiagentPPO(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            n_agents=n_agents,
            learning_rate=learning_rate,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            direction_fc=direction_fc,
            entropy_regularization=entropy_regularization,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            inactive_agent_ids=inactive_agent_ids,
            use_attention_networks=use_attention_networks)
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is
        # reinitialized. The other agents are novices.
        # Otherwise, all agents are reinitialized from train_dir.
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = multiagent_ppo.MultiagentPPO(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                non_learning_agents=list(range(n_agents - 1)),
                use_attention_networks=use_attention_networks)
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        logging.info('Successfully initialized policy saver.')

        print('Using TFDriver')
        if use_attention_networks:
            collect_driver = utils.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # Number of consecutive steps for which the loss has diverged.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                results = multiagent_metrics.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                    use_function=use_tf_functions,
                    use_attention_networks=use_attention_networks)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                multiagent_metrics.log_metrics(eval_metrics)

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute attention
                # weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if not inactive_agent_ids or a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)
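

# An illustrative way to launch the train/eval loop above. The output directory and the
# shortened step budget are assumptions for this sketch, not values from the original code.
if __name__ == '__main__':
    train_eval(
        root_dir='/tmp/multigrid_ppo',  # hypothetical output directory
        env_name='MultiGrid-Empty-5x5-v0',
        num_environment_steps=100000,  # shortened run for illustration
        num_parallel_environments=2,
        debug=True)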
Example #7
def train_eval(
    root_dir,
    env_name="HalfCheetah-v2",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=100000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=500000000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    relabel_type=None,
    num_future_states=4,
    max_episode_steps=100,
    random_seed=0,
    eval_task_list=None,
    constant_task=None,  # Whether to train on a single task
    clip_critic=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    if relabel_type == "none":
        relabel_type = None
    assert relabel_type in [None, "future", "last", "soft", "random"]
    if constant_task:
        assert relabel_type is None
    if eval_task_list is None:
        eval_task_list = []
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        utils.AverageSuccessMetric(max_episode_steps=max_episode_steps,
                                   buffer_size=num_eval_episodes),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, task_distribution = utils.get_env(env_name,
                                                  constant_task=constant_task)
        eval_tf_env, _ = utils.get_env(env_name,
                                       max_episode_steps,
                                       constant_task=constant_task)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=utils.normal_projection_net,
        )
        if isinstance(clip_critic, float):
            output_activation_fn = lambda x: clip_critic * tf.sigmoid(x)
        elif isinstance(clip_critic, tuple):
            assert len(clip_critic) == 2
            min_val, max_val = clip_critic
            output_activation_fn = (
                lambda x:  # pylint: disable=g-long-lambda
                (max_val - min_val) * tf.sigmoid(x) + min_val)
        else:
            output_activation_fn = None
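        # Note: a float clip_critic squashes the critic's Q-value estimates into
        # (0, clip_critic) via a scaled sigmoid, a (min, max) tuple squashes them
        # into (min, max), and None leaves the critic output unbounded.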
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            output_activation_fn=output_activation_fn,
        )

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = relabelling_replay_buffer.GoalRelabellingReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
            task_distribution=task_distribution,
            actor=actor_net,
            critic=critic_net,
            gamma=gamma,
            relabel_type=relabel_type,
            sample_batch_size=batch_size,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            num_future_states=num_future_states,
        )

        env_steps = tf_metrics.EnvironmentSteps(prefix="Train")
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix="Train"),
            env_steps,
            utils.AverageSuccessMetric(
                prefix="Train",
                max_episode_steps=max_episode_steps,
                buffer_size=num_eval_episodes,
            ),
            tf_metrics.AverageReturnMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )

        train_checkpointer.initialize_or_restore()

        data_collector = utils.DataCollector(
            tf_env,
            tf_agent.collect_policy,
            replay_buffer,
            max_episode_steps=max_episode_steps,
            observers=train_metrics,
        )

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train)
        else:
            tf.config.experimental_run_functions_eagerly(True)

        # Save the config string as late as possible to catch
        # as many object instantiations as possible.
        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        # Collect initial replay data.
        logging.info(
            "Initializing replay buffer by collecting experience for %d steps with "
            "a random policy.",
            initial_collect_steps,
        )
        for _ in range(initial_collect_steps):
            data_collector.step(initial_collect_policy)
        data_collector.reset()
        logging.info("Replay buffer initial size: %d",
                     replay_buffer.num_frames())

        logging.info("Computing initial eval metrics")
        for task in [None] + eval_task_list:
            with utils.FixedTask(eval_tf_env, task):
                prefix = "Metrics" if task is None else "Metrics-%s" % str(
                    task)
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix=prefix,
                )
                metric_utils.log_metrics(eval_metrics)

        time_acc = 0
        env_time_acc = 0
        train_time_acc = 0
        env_steps_before = env_steps.result().numpy()

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train)

        logging.info("Starting training")
        for _ in range(num_iterations):
            start_time = time.time()
            data_collector.step()
            env_time_acc += time.time() - start_time
            train_time_start = time.time()
            for _ in range(train_steps_per_iteration):
                experience = replay_buffer.get_batch()
                train_loss = tf_agent.train(experience)
                total_loss = train_loss.loss
            train_time_acc += time.time() - train_time_start
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step.numpy(),
                             total_loss)

                combined_steps_per_sec = (env_steps.result().numpy() -
                                          env_steps_before) / time_acc
                train_steps_per_sec = (env_steps.result().numpy() -
                                       env_steps_before) / train_time_acc
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / env_time_acc
                logging.info(
                    "%.3f combined steps / sec: %.3f env steps/sec, %.3f train steps/sec",
                    combined_steps_per_sec,
                    env_steps_per_sec,
                    train_steps_per_sec,
                )
                tf.compat.v2.summary.scalar(
                    name="combined_steps_per_sec",
                    data=combined_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="env_steps_per_sec",
                    data=env_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="train_steps_per_sec",
                    data=train_steps_per_sec,
                    step=env_steps.result(),
                )
                time_acc = 0
                env_time_acc = 0
                train_time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:

                for task in [None] + eval_task_list:
                    with utils.FixedTask(eval_tf_env, task):
                        prefix = "Metrics" if task is None else "Metrics-%s" % str(
                            task)
                        logging.info(prefix)
                        metric_utils.eager_compute(
                            eval_metrics,
                            eval_tf_env,
                            eval_policy,
                            num_episodes=num_eval_episodes,
                            train_step=global_step,
                            summary_writer=eval_summary_writer,
                            summary_prefix=prefix,
                        )
                        metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

        return train_loss
Example #8
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)


for _ in range(initial_collect_steps):
    collect_step(train_env, random_policy)

dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                   sample_batch_size=batch_size,
                                   num_steps=2).prefetch(3)
iterator = iter(dataset)

import time
t0 = time.time()

tf_agent.train = common.function(tf_agent.train)
tf_agent.train_step_counter.assign(0)
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):
    collect_step(train_env, tf_agent.collect_policy)
    experience, unused_info = next(iterator)
    train_loss = tf_agent.train(experience)
    step = tf_agent.train_step_counter.numpy()
    if step % log_interval == 0:
        print('step = {}: loss = {}'.format(step, train_loss.loss))
    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                        num_eval_episodes)
        print('step = {}: Average Return = {}'.format(step, avg_return))
        returns.append(avg_return)
        # eval_env._env.envs[0]._env.render()
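
# A small follow-up sketch (an assumption, not part of the original snippet): report the
# elapsed time and plot the evaluation returns gathered above. It assumes matplotlib is
# available and that num_iterations is a multiple of eval_interval.
import matplotlib.pyplot as plt

print('Training took {:.1f} seconds'.format(time.time() - t0))
plt.plot(range(0, num_iterations + 1, eval_interval), returns)
plt.xlabel('Training step')
plt.ylabel('Average return')
plt.show()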
Example #9
    def __init__(self,
                 policy: tf_policy.TFPolicy,
                 batch_size: Optional[int] = None,
                 use_nest_path_signatures: bool = True,
                 seed: Optional[types.Seed] = None,
                 train_step: Optional[tf.Variable] = None,
                 input_fn_and_spec: Optional[InputFnAndSpecType] = None,
                 metadata: Optional[Dict[Text, tf.Variable]] = None):
        """Initialize PolicySaver for  TF policy `policy`.

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the structure of the specs. Otherwise all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note since the train step must be a variable
        it is not safe to create it directly in TF1 so in that case this is a
        required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.
      metadata: A dictionary of `tf.Variables` to be saved along with the
        policy.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      TypeError: If `metadata` is not a dictionary of tf.Variables.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
        if not isinstance(policy, tf_policy.TFPolicy):
            raise TypeError('policy is not a TFPolicy.  Saw: %s' %
                            type(policy))
        if (batch_size is not None
                and (not isinstance(batch_size, int) or batch_size < 1)):
            raise ValueError(
                'Expected batch_size == None or python int > 0, saw: %s' %
                (batch_size, ))

        action_fn_input_spec = (policy.time_step_spec,
                                policy.policy_state_spec)
        if use_nest_path_signatures:
            action_fn_input_spec = _rename_spec_with_nest_paths(
                action_fn_input_spec)
        else:
            _check_spec(action_fn_input_spec)

        # Make a shallow copy as we'll be making some changes in-place.
        saved_policy = tf.Module()
        saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
        saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)

        if train_step is None:
            if not common.has_eager_been_enabled():
                raise ValueError('train_step is required in TF1 and must be a '
                                 '`tf.Variable`: %s' % train_step)
            train_step = tf.Variable(
                -1,
                trainable=False,
                dtype=tf.int64,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                shape=())
        elif not isinstance(train_step, tf.Variable):
            raise ValueError('train_step must be a TensorFlow variable: %s' %
                             train_step)

        # We will need the train step for the Checkpoint object.
        self._train_step = train_step
        saved_policy.train_step = self._train_step

        self._metadata = metadata or {}
        for key, value in self._metadata.items():
            if not isinstance(key, str):
                raise TypeError('Keys of metadata must be strings: %s' % key)
            if not isinstance(value, tf.Variable):
                raise TypeError('Values of metadata must be tf.Variable: %s' %
                                value)
        saved_policy.metadata = self._metadata

        if batch_size is None:
            get_initial_state_fn = policy.get_initial_state
            get_initial_state_input_specs = (tf.TensorSpec(
                dtype=tf.int32, shape=(), name='batch_size'), )
        else:
            get_initial_state_fn = functools.partial(policy.get_initial_state,
                                                     batch_size=batch_size)
            get_initial_state_input_specs = ()

        get_initial_state_fn = common.function()(get_initial_state_fn)

        original_action_fn = policy.action

        if seed is not None:

            def action_fn(time_step, policy_state):
                return original_action_fn(time_step, policy_state, seed=seed)
        else:
            action_fn = original_action_fn

        def distribution_fn(time_step, policy_state):
            """Wrapper for policy.distribution() in the SavedModel."""
            try:
                outs = policy.distribution(time_step=time_step,
                                           policy_state=policy_state)
                return tf.nest.map_structure(_composite_distribution, outs)
            except (TypeError, NotImplementedError) as e:
                # TODO(b/156526399): Move this to just the policy.distribution() call
                # once tfp.experimental.as_composite() properly handles LinearOperator*
                # components as well as TransformedDistributions.
                logging.error(
                    'Could not serialize policy.distribution() for policy "%s". '
                    'Calling saved_model.distribution() will raise the '
                    'assertion error: %s', policy, e)

                @common.function()
                def _raise():
                    tf.Assert(False, [str(e)])
                    return ()

                outs = _raise()

        # We call get_concrete_function() for its side effect: to ensure the proper
        # ConcreteFunction is stored in the SavedModel.
        get_initial_state_fn.get_concrete_function(
            *get_initial_state_input_specs)

        train_step_fn = common.function(
            lambda: saved_policy.train_step).get_concrete_function()
        get_metadata_fn = common.function(
            lambda: saved_policy.metadata).get_concrete_function()

        def add_batch_dim(spec):
            return tf.TensorSpec(shape=tf.TensorShape(
                [batch_size]).concatenate(spec.shape),
                                 name=spec.name,
                                 dtype=spec.dtype)

        batched_time_step_spec = tf.nest.map_structure(add_batch_dim,
                                                       policy.time_step_spec)
        batched_policy_state_spec = tf.nest.map_structure(
            add_batch_dim, policy.policy_state_spec)

        policy_step_spec = policy.policy_step_spec
        policy_state_spec = policy.policy_state_spec

        if use_nest_path_signatures:
            batched_time_step_spec = _rename_spec_with_nest_paths(
                batched_time_step_spec)
            batched_policy_state_spec = _rename_spec_with_nest_paths(
                batched_policy_state_spec)
            policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
            policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
        else:
            _check_spec(batched_time_step_spec)
            _check_spec(batched_policy_state_spec)
            _check_spec(policy_step_spec)
            _check_spec(policy_state_spec)

        if input_fn_and_spec is not None:
            # Store a signature based on input_fn_and_spec
            @common.function()
            def polymorphic_action_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return action_fn(*action_inputs)

            @common.function()
            def polymorphic_distribution_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return distribution_fn(*action_inputs)

            batched_input_spec = tf.nest.map_structure(add_batch_dim,
                                                       input_fn_and_spec[1])
            # We call get_concrete_function() for its side effect: to ensure the
            # proper ConcreteFunction is stored in the SavedModel.
            polymorphic_action_fn.get_concrete_function(
                example=batched_input_spec)
            polymorphic_distribution_fn.get_concrete_function(
                example=batched_input_spec)

            action_input_spec = (input_fn_and_spec[1], )

        else:
            action_input_spec = action_fn_input_spec
            if batched_policy_state_spec:
                # Store the signature with a required policy state spec
                polymorphic_action_fn = common.function()(action_fn)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)

                polymorphic_distribution_fn = common.function()(
                    distribution_fn)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
            else:
                # Create a polymorphic action_fn which you can call as
                #  restored.action(time_step)
                # or
                #  restored.action(time_step, ())
                # (without retracing the inner action twice)
                @common.function()
                def polymorphic_action_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return action_fn(time_step, policy_state)

                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

                @common.function()
                def polymorphic_distribution_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return distribution_fn(time_step, policy_state)

                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

        signatures = {
            # CompositeTensors aren't well supported by old-style signature
            # mechanisms, so we do not have a signature for policy.distribution.
            'action':
            _function_with_flat_signature(polymorphic_action_fn,
                                          input_specs=action_input_spec,
                                          output_spec=policy_step_spec,
                                          include_batch_dimension=True,
                                          batch_size=batch_size),
            'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
            'get_train_step':
            _function_with_flat_signature(train_step_fn,
                                          input_specs=(),
                                          output_spec=train_step.dtype,
                                          include_batch_dimension=False),
            'get_metadata':
            _function_with_flat_signature(get_metadata_fn,
                                          input_specs=(),
                                          output_spec=tf.nest.map_structure(
                                              lambda v: v.dtype,
                                              self._metadata),
                                          include_batch_dimension=False),
        }

        saved_policy.action = polymorphic_action_fn
        saved_policy.distribution = polymorphic_distribution_fn
        saved_policy.get_initial_state = get_initial_state_fn
        saved_policy.get_train_step = train_step_fn
        saved_policy.get_metadata = get_metadata_fn
        # Adding variables as an attribute to facilitate updating them.
        saved_policy.model_variables = policy.variables()

        # TODO(b/156779400): Move to a public API for accessing all trackable leaf
        # objects (once it's available).  For now, we have no other way of tracking
        # objects like Tables, Vocabulary files, etc.
        try:
            saved_policy._all_assets = policy._unconditional_checkpoint_dependencies  # pylint: disable=protected-access
        except AttributeError as e:
            if '_self_unconditional' in str(e):
                logging.warn(
                    'Unable to capture all trackable objects in policy "%s".  This '
                    'may be okay.  Error: %s', policy, e)
            else:
                raise e

        self._policy = saved_policy
        self._signatures = signatures
        self._action_input_spec = action_input_spec
        self._policy_step_spec = policy_step_spec
        self._policy_state_spec = policy_state_spec
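'''
The saver above attaches action, distribution, get_initial_state, get_train_step and
get_metadata to the exported SavedModel, so the policy can be reloaded without any
TF-Agents Python classes. A minimal sketch of that round trip follows; the export
directory and the 4-dimensional observation shape are assumptions for illustration only.
'''
import tensorflow as tf
from tf_agents.trajectories import time_step as ts

export_dir = "policy_saved_model/policy_000000500"  # hypothetical PolicySaver.save() output
restored = tf.saved_model.load(export_dir)

# Build a batched TimeStep matching the policy's time_step_spec (shape assumed).
observation = tf.zeros([1, 4], dtype=tf.float32)
time_step = ts.restart(observation, batch_size=1)

policy_state = restored.get_initial_state(batch_size=1)
action_step = restored.action(time_step, policy_state)
print("train step:", restored.get_train_step().numpy())
print("action:", action_step.action.numpy())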
Example #10
def train_agent(tf_agent, train_env, eval_env, num_iterations,
                num_eval_episodes, collect_episodes_per_iteration, v_n,
                model_W_train_epochs, model_sml_train_epochs,
                replay_buffer_capacity, optimizer, eval_interval,
                curiosity_interval, W, sml, actor_net, _run,
                w_max_dataset_size):
    """Train a tf.agent with sparse model."""

    decoder_layer_agent = actor_net.layers[0].layers[
        0]  # taking the copied layer with actual weights

    # sml initializes D properly
    decoder_layer_agent.set_weights([sml.D.numpy().T])

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity)

    curiosity_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=1000000)  # should never overflow

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Reset the train step
    tf_agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                    num_eval_episodes)
    train_avg_return = compute_avg_return(eval_env, tf_agent.collect_policy,
                                          num_eval_episodes)
    returns = [avg_return]
    train_returns = [train_avg_return]

    with tqdm(total=num_iterations, desc="Iterations") as pbar:
        for iteration in range(num_iterations):

            # Collect a few episodes using collect_policy and save to the replay buffer.
            collect_episode(train_env, tf_agent.collect_policy,
                            collect_episodes_per_iteration,
                            [replay_buffer, curiosity_replay_buffer])

            # Use data from the buffer and update the agent's network.
            experience = replay_buffer.gather_all()
            train_loss = tf_agent.train(experience)
            replay_buffer.clear()

            #print("Agent train step")

            step = tf_agent.train_step_counter.numpy()

            _run.log_scalar("agent.train_loss", train_loss.loss.numpy(),
                            iteration)

            if step % eval_interval == 0:
                avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                                num_eval_episodes)
                train_avg_return = compute_avg_return(train_env,
                                                      tf_agent.collect_policy,
                                                      num_eval_episodes)
                pbar.set_postfix(train_c_return=train_avg_return,
                                 eval_return=avg_return,
                                 agent_loss=train_loss.loss.numpy())
                returns.append(avg_return)
                train_returns.append(train_avg_return)
                _run.log_scalar("agent.eval_return", avg_return, iteration)
                _run.log_scalar("agent.train_return", train_avg_return,
                                iteration)

            if step % curiosity_interval == 0:
                #clear_output()
                xs, ys = buffer_to_dataset(curiosity_replay_buffer, v_n)

                _run.log_scalar("causal.total_dataset_size", len(xs),
                                iteration)

                # prevent the dataset from growing in an unbounded way
                if len(xs) > w_max_dataset_size:
                    idxes = np.random.choice(range(len(xs)),
                                             w_max_dataset_size,
                                             replace=False)
                    xs = np.array(xs)[idxes]
                    ys = np.array(ys)[idxes]

                _run.log_scalar("causal.used_dataset_size", len(xs), iteration)

                # fitting on observational data...
                losses = W.fit(xs=xs, ys=ys, epochs=model_W_train_epochs)
                for l in losses:
                    _run.log_scalar("W.fit", l)
                #W.plot_loss()

                # setting weights from the agent to the model...
                sml.D.assign(decoder_layer_agent.get_weights()[0].T)

                # setting the new observation transition matrix
                sml.set_WoWa(*W.get_Wo_Wa())

                def sml_callback(loss):
                    for k, v in loss.items():
                        _run.log_scalar("sml.%s" % k, v)

                # fitting the SML model
                sml.fit(epochs=model_sml_train_epochs,
                        loss_callback=sml_callback)

                # setting weights from the model to the agent...
                decoder_layer_agent.set_weights([sml.D.numpy().T])

                #agent_replay_buffer.clear()
                # observations are actually the same

                #print("Model train step")
            pbar.update(1)

    return returns, train_returns
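'''
compute_avg_return is not defined in this snippet; the calls above match the helper used
throughout the TF-Agents tutorials, so a sketch along those lines (assuming evaluation on
a TFPyEnvironment with batch size 1) looks like this.
'''
def compute_avg_return(environment, policy, num_episodes=10):
    """Average undiscounted return of `policy` over `num_episodes` episodes."""
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]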
Example #11
def test(binance, model):

    symbol = "BTCUSDT"

    df = pd.read_csv("..\\Data\\" + symbol + "_data.csv",
                     index_col=0,
                     parse_dates=True)
    df = make_dataset.make_reinforcement_dataset(df)

    train_env_py = TradingEnv(df)
    train_env_py.set_init_balance(1000)
    eval_env_py = TradingEnv(df)
    eval_env_py.set_init_balance(1000)
    #train_env_py = suite_gym.load('CartPole-v0')
    #eval_env_py = suite_gym.load('CartPole-v0')

    train_env = tf_py_environment.TFPyEnvironment(train_env_py)
    eval_env = tf_py_environment.TFPyEnvironment(eval_env_py)

    #q_net = q_rnn_network.QRnnNetwork(
    #    train_env.observation_spec(),
    #    train_env.action_spec(),
    #    input_fc_layer_params=(128,64,16),
    #    output_fc_layer_params=(128,64,16),
    #    lstm_size=(128,64,16))

    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               fc_layer_params=(256, 128, 64, 32, 16))

    global_step = tf.Variable(0, name="global_step", trainable=False)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)

    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        epsilon_greedy=0.2)

    agent.initialize()

    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    collect_data(train_env, random_policy, replay_buffer, steps=10000)

    policy_checkpointer = common.Checkpointer(
        ckpt_dir="ReinforcementLearnData/Checkpoint",
        agent=agent,
        policy=agent.policy,
        replay_buffer=replay_buffer,
        global_step=global_step)
    policy_checkpointer.initialize_or_restore()
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    print("1:Train 2:Evaluate 3:Simulation")
    mode = input(">")
    if mode == "1":
        pass
    elif mode == "2":
        evaluate(agent, eval_env)
        return
    elif mode == "3":
        simulation(binance, agent)
        return

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    agent.train = common.function(agent.train)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    returns = [avg_return]

    for _ in range(num_iterations):

        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        step = agent.train_step_counter.numpy()

        if step % log_interval == 0:
            print("step = {0}: loss = {1}".format(step, train_loss))

        if step % eval_interval == 0:
            avg_return = compute_avg_return(eval_env, agent.policy,
                                            num_eval_episodes)
            print("step = {0}: Average Return = {1}".format(step, avg_return))
            returns.append(avg_return)

    #save agent
    policy_checkpointer.save(global_step)
    tf_policy_saver.save("ReinforcementLearnData/Policy")

    x = range(0, num_iterations + 1, eval_interval)
    plt.plot(x, returns)
    plt.ylabel('Average Return')
    plt.xlabel('Iterations')
    plt.show()

    print("Result:")
    print(compute_avg_return(eval_env, agent.policy, num_eval_episodes))
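'''
collect_data and collect_step are likewise assumed to be defined elsewhere in this
project; a sketch consistent with the calls above, following the usual TF-Agents
DQN-tutorial pattern, could be:
'''
from tf_agents.trajectories import trajectory


def collect_step(environment, policy, buffer):
    """Take one step with `policy` and write the resulting transition into `buffer`."""
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    buffer.add_batch(traj)


def collect_data(env, policy, buffer, steps):
    for _ in range(steps):
        collect_step(env, policy, buffer)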
Example #12
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(
            batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(
        eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
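'''
This train_eval is normally launched from a script entry point that parses flags, but it
can also be called directly. A minimal smoke-test invocation is sketched below; every
argument value here is an illustrative assumption, not a recommended setting.
'''
train_eval(
    root_dir="/tmp/ppo_halfcheetah",       # required: checkpoints and summaries go here
    env_name="HalfCheetah-v2",
    num_environment_steps=100000,
    num_parallel_environments=4,
    collect_episodes_per_iteration=4,
    num_eval_episodes=5,
    eval_interval=50,
    use_tf_functions=True,
)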
Example #13
def train(dummy_env, tf_agent):

    #parallel_env = ParallelPyEnvironment([SpinQubitEnv(en_configs[i]) for i in range(6)])
    #train_env = tf_py_environment.TFPyEnvironment(parallel_env)
    #gammavals = np.linspace(.01,.99, 5000)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=dummy_env.batch_size,
        max_length=replay_buffer_capacity)

    tf_agent.train = common.function(tf_agent.train)

    # Reset the train step
    tf_agent.train_step_counter.assign(0)

    avg_return = compute_avg_return(dummy_env, tf_agent.policy, 10)
    returns = [avg_return]

    for x in np.linspace(0, np.pi, 10):
        for y in np.linspace(0, 2 * np.pi, 10):
            for target_unitary in [sigmax()]:
                for relative_detuning in [.1, 10]:
                    for _ in range(250):
                        # for _ in range(num_iterations):
                        # num_iterations
                        # Collect a few episodes using collect_policy and save to the replay buffer.
                        #train_env,eval_env = get_environments(x, y)

                        train_env = get_tf_environment(x, y, 0.1,
                                                       target_unitary, 1,
                                                       relative_detuning)
                        collect_episode(train_env, replay_buffer,
                                        tf_agent.collect_policy,
                                        collect_episodes_per_iteration)

                        # Use data from the buffer and update the agent's network.
                        experience = replay_buffer.gather_all()
                        train_loss = tf_agent.train(experience)
                        replay_buffer.clear()

                        step = tf_agent.train_step_counter.numpy()

                        if step % 100 == 0:
                            print('step = {0}: loss = {1}'.format(
                                step, train_loss.loss))

                        if True or step % eval_interval == 0:

                            #eval_py_env = SpinQubitEnv(environment_configs)
                            ##eval_gamma = [random.uniform(0,0.5),random.uniform(0,0.5)]
                            # eval_py_env.set_gamma(eval_gamma)
                            #eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
                            eval_env = get_tf_environment(
                                x, y, 0.1, target_unitary, 1,
                                relative_detuning)
                            avg_return = compute_avg_return(
                                eval_env, tf_agent.policy, num_eval_episodes)
                            print('step = {0}: Average Return = {1}'.format(
                                step, avg_return))
                            returns.append(avg_return)
    return (step, train_loss, returns, fidelity)
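'''
collect_episode is also not shown here; note that, unlike Example #10, this project passes
the replay buffer as the second argument. A sketch matching that call order, in the spirit
of the TF-Agents REINFORCE-tutorial helper, might be:
'''
from tf_agents.trajectories import trajectory


def collect_episode(environment, buffer, policy, num_episodes):
    """Run `policy` for `num_episodes` full episodes, writing every step to `buffer`."""
    episode_counter = 0
    environment.reset()
    while episode_counter < num_episodes:
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        traj = trajectory.from_transition(time_step, action_step, next_time_step)
        buffer.add_batch(traj)
        if traj.is_boundary():
            episode_counter += 1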
Example #14
File: ppo_toy.py  Project: w121211/agents
def train():
    num_parallel_environments = 2
    collect_episodes_per_iteration = 2  # 30

    # Each callable must return a (non-TF) PyEnvironment; TFPyEnvironment then wraps
    # the batched parallel environment once.
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: suite_gym.wrap_env(RectEnv())] * num_parallel_environments))

    print(tf_env.time_step_spec())
    print(tf_env.action_spec())
    print(tf_env.observation_spec())

    preprocessing_layers = {
        'target':
        tf.keras.models.Sequential([
            # tf.keras.applications.MobileNetV2(
            #     input_shape=(64, 64, 1), include_top=False, weights=None),
            tf.keras.layers.Conv2D(1, 6),
            tf.keras.layers.Flatten()
        ]),
        'canvas':
        tf.keras.models.Sequential([
            # tf.keras.applications.MobileNetV2(
            #     input_shape=(64, 64, 1), include_top=False, weights=None),
            tf.keras.layers.Conv2D(1, 6),
            tf.keras.layers.Flatten()
        ]),
        'coord':
        tf.keras.models.Sequential(
            [tf.keras.layers.Dense(5),
             tf.keras.layers.Flatten()])
    }
    preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner)
    value_net = value_network.ValueNetwork(
        tf_env.observation_spec(),
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner)

    tf_agent = ppo_agent.PPOAgent(tf_env.time_step_spec(),
                                  tf_env.action_spec(),
                                  tf.compat.v1.train.AdamOptimizer(),
                                  actor_net=actor_net,
                                  value_net=value_net,
                                  normalize_observations=False,
                                  use_gae=False)
    tf_agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec, batch_size=num_parallel_environments)
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        tf_agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_episodes=collect_episodes_per_iteration)

    # print(tf_agent.collect_data_spec)

    def train_step():
        trajectories = replay_buffer.gather_all()
        return tf_agent.train(experience=trajectories)

    collect_driver.run = common.function(collect_driver.run, autograph=False)
    # tf_agent.train = common.function(tf_agent.train, autograph=False)
    # train_step = common.function(train_step)

    # for _ in range(10):
    collect_driver.run()
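'''
The toy script above stops after a single collect_driver.run(). A typical continuation,
mirroring the collect/train/clear cycle of Example #12, is sketched below; it is assumed
to be appended at the end of train(), where collect_driver, train_step and replay_buffer
are in scope, and the iteration count of 10 is an arbitrary choice.
'''
for _ in range(10):
    collect_driver.run()                       # gather full on-policy episodes
    loss_info = train_step()                   # one PPO update over gather_all()
    replay_buffer.clear()                      # PPO is on-policy, so drop the used data
    print("PPO loss:", loss_info.loss.numpy())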
Example #15
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories,
                            train_step_counter=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                    global_step_val = global_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True)
Example #16
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            real_initial_collect_driver.run = common.function(
                real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
Example #17
def train_eval(
        root_dir,
        env_name=ENV_NAME,
        num_iterations=ITERATIONS,
        fc_layer_params=LAYER_PARAMETERS,
        # Parameters for collect
        initial_collect_steps=COLLECT_STEPS,
        collect_steps_per_iteration=COLLECT_STEPS,
        epsilon_greedy=GREEDY,
        replay_buffer_capacity=BUFFER,
        # Parameters for target update
        target_update_tau=TAU,
        target_update_period=UPDATE_PERIOD,
        # Parameters for train
        train_steps_per_iteration=TRAIN_ITERATIONS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARN_RATE,
        gamma=GAMMA,
        reward_scale_factor=SCALE,
        gradient_clipping=GRADIENT,
        # Parameters for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Parameters for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        train_op = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            global_step_call = sess.make_callable(global_step)
            train_step_call = sess.make_callable([train_op, summary_ops])

            timed_at_step = global_step_call()
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                train_time += time.time() - start_time

                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info('%.3f steps/sec', steps_per_sec)
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
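
# A minimal eager-mode (TF2) sketch of the same collect/train pairing shown above
# in TF1 session style. It assumes the same `replay_buffer`, `tf_agent`,
# hyperparameters, and a DynamicStepDriver named `collect_driver`; it simply
# mirrors the eager examples later in this document and is not part of the
# original snippet.
dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                   sample_batch_size=batch_size,
                                   num_steps=2).prefetch(3)
iterator = iter(dataset)

def train_step():
    experience, _ = next(iterator)
    return tf_agent.train(experience)

train_step = common.function(train_step)

for _ in range(num_iterations):
    collect_driver.run()
    loss_info = train_step()
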
예제 #18
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    q_net = q_network.QNetwork(observation_spec,
                               action_spec,
                               fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Copy Q-network weights to the target Q-network.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Reset timing to avoid counting eval time.
                timed_at_step = global_step_val
                start_time = time.time()
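
# The example above relies on a collect_step() helper that is not shown. Below is
# a minimal sketch of what such a helper typically looks like (the exact
# implementation used by the original code is an assumption): take one action with
# `policy`, store the transition in the replay buffer as a Trajectory, and return
# the new time step so the caller can continue the rollout.
from tf_agents.trajectories import trajectory


def collect_step(environment, time_step, policy, replay_buffer):
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    # Package (time_step, action, next_time_step) into a single Trajectory item.
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)
    return next_time_step
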
예제 #19
0
    def __init__(self,
                 policy,
                 batch_size=None,
                 use_nest_path_signatures=True,
                 seed=None,
                 train_step=None,
                 input_fn_and_spec=None):
        """Initialize PolicySaver for  TF policy `policy`.

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the structure of the specs. Otherwise all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note that since the train step must be a
        variable, it is not safe to create it directly in TF1, so in that case
        this is a required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
        if not isinstance(policy, tf_policy.TFPolicy):
            raise TypeError('policy is not a TFPolicy.  Saw: %s' %
                            type(policy))
        if (batch_size is not None
                and (not isinstance(batch_size, int) or batch_size < 1)):
            raise ValueError(
                'Expected batch_size == None or python int > 0, saw: %s' %
                (batch_size, ))

        action_fn_input_spec = (policy.time_step_spec,
                                policy.policy_state_spec)
        if use_nest_path_signatures:
            action_fn_input_spec = _rename_spec_with_nest_paths(
                action_fn_input_spec)
        else:
            _check_spec(action_fn_input_spec)

        # Make a shallow copy as we'll be making some changes in-place.
        policy = copy.copy(policy)
        if train_step is None:
            if not common.has_eager_been_enabled():
                raise ValueError('train_step is required in TF1 and must be a '
                                 '`tf.Variable`: %s' % train_step)
            train_step = tf.Variable(
                -1,
                trainable=False,
                dtype=tf.int64,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                shape=())
        elif not isinstance(train_step, tf.Variable):
            raise ValueError('train_step must be a TensorFlow variable: %s' %
                             train_step)
        policy.train_step = train_step

        # We will need the train step for the Checkpoint object.
        self._train_step = train_step

        if batch_size is None:
            get_initial_state_fn = policy.get_initial_state
            get_initial_state_input_specs = (tf.TensorSpec(
                dtype=tf.int32, shape=(), name='batch_size'), )
        else:
            get_initial_state_fn = functools.partial(policy.get_initial_state,
                                                     batch_size=batch_size)
            get_initial_state_input_specs = ()

        get_initial_state_fn = common.function()(get_initial_state_fn)

        original_action_fn = policy.action
        if seed is not None:

            def action_fn(time_step, policy_state):
                return original_action_fn(time_step, policy_state, seed=seed)
        else:
            action_fn = original_action_fn

        # We call get_concrete_function() for its side effect.
        get_initial_state_fn.get_concrete_function(
            *get_initial_state_input_specs)

        train_step_fn = common.function(
            lambda: policy.train_step).get_concrete_function()

        action_fn = common.function()(action_fn)

        def add_batch_dim(spec):
            return tf.TensorSpec(shape=tf.TensorShape(
                [batch_size]).concatenate(spec.shape),
                                 name=spec.name,
                                 dtype=spec.dtype)

        batched_time_step_spec = tf.nest.map_structure(add_batch_dim,
                                                       policy.time_step_spec)
        batched_policy_state_spec = tf.nest.map_structure(
            add_batch_dim, policy.policy_state_spec)

        policy_step_spec = policy.policy_step_spec
        policy_state_spec = policy.policy_state_spec

        if use_nest_path_signatures:
            batched_time_step_spec = _rename_spec_with_nest_paths(
                batched_time_step_spec)
            batched_policy_state_spec = _rename_spec_with_nest_paths(
                batched_policy_state_spec)
            policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
            policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
        else:
            _check_spec(batched_time_step_spec)
            _check_spec(batched_policy_state_spec)
            _check_spec(policy_step_spec)
            _check_spec(policy_state_spec)

        if input_fn_and_spec is not None:
            # Store a signature based on input_fn_and_spec
            @common.function()
            def polymorphic_action_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return action_fn(*action_inputs)

            batched_input_spec = tf.nest.map_structure(add_batch_dim,
                                                       input_fn_and_spec[1])
            # We call get_concrete_function() for its side effect.
            polymorphic_action_fn.get_concrete_function(
                example=batched_input_spec)

            action_input_spec = (input_fn_and_spec[1], )

        else:
            action_input_spec = action_fn_input_spec
            if batched_policy_state_spec:
                # Store the signature with a required policy state spec
                polymorphic_action_fn = action_fn
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
            else:
                # Create a polymorphic action_fn which you can call as
                #  restored.action(time_step)
                # or
                #  restored.action(time_step, ())
                # (without retracing the inner action twice)
                @common.function()
                def polymorphic_action_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return action_fn(time_step, policy_state)

                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

        signatures = {
            'action':
            _function_with_flat_signature(polymorphic_action_fn,
                                          input_specs=action_input_spec,
                                          output_spec=policy_step_spec,
                                          include_batch_dimension=True,
                                          batch_size=batch_size),
            'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
            'get_train_step':
            _function_with_flat_signature(train_step_fn,
                                          input_specs=(),
                                          output_spec=train_step.dtype,
                                          include_batch_dimension=False),
        }

        policy.action = polymorphic_action_fn
        policy.get_initial_state = get_initial_state_fn
        policy.get_train_step = train_step_fn
        # Adding variables as an attribute to facilitate updating them.
        policy.model_variables = policy.variables()

        self._policy = policy
        self._signatures = signatures
        self._action_input_spec = action_input_spec
        self._policy_step_spec = policy_step_spec
        self._policy_state_spec = policy_state_spec
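
# A minimal usage sketch for the saver whose constructor is shown above; the
# `agent`, `eval_env` and output directory names here are assumptions used only
# for illustration. The saved model exposes the `action`, `get_initial_state`
# and `get_train_step` signatures registered in __init__.
saver = policy_saver.PolicySaver(agent.policy, batch_size=None)
saver.save('/tmp/saved_policy')

loaded_policy = tf.saved_model.load('/tmp/saved_policy')
time_step = eval_env.reset()
policy_state = loaded_policy.get_initial_state(batch_size=eval_env.batch_size)
action_step = loaded_policy.action(time_step, policy_state)
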
예제 #20
0
def train_level(level,
                consecutive_wins_flag=5,
                collect_random_steps=True,
                max_iterations=num_iterations):
    """
    create DQN agent to train a level of the game
    :param level: level of the game
    :param consecutive_wins_flag: number of consecutive wins in evaluation
    signifying the training is done
    :param collect_random_steps: whether to collect random steps at the beginning,
    always set to 'True' when the global step is 0.
    :param max_iterations: stop the training when it reaches the max iteration
    regardless of the result
    """
    global saving_time
    cells = query_level(level)
    size = len(cells)
    env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))
    eval_env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    fc_layer_params = (neuron_num_mapper[size], )

    q_net = q_network.QNetwork(env.observation_spec()[0],
                               env.action_spec(),
                               fc_layer_params=fc_layer_params,
                               activation_fn=tf.keras.activations.relu)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DdqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        observation_and_action_constraint_splitter=GameEnv.
        obs_and_mask_splitter)
    agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=env.batch_size,
        max_length=replay_buffer_max_length)

    # drivers
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy=agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=collect_steps_per_iteration)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        eval_env,
        policy=agent.policy,
        observers=eval_metrics,
        num_episodes=num_eval_episodes)

    # checkpointer of the replay buffer and policy
    train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        dir_path, 'trained_policies/train_lv{0}'.format(level)),
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             global_step=global_step,
                                             replay_buffer=replay_buffer)

    # policy saver
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    train_checkpointer.initialize_or_restore()

    # optimize by wrapping some of the code in a graph using TF function
    agent.train = common.function(agent.train)
    collect_driver.run = common.function(collect_driver.run)
    eval_driver.run = common.function(eval_driver.run)

    # collect initial replay data
    if collect_random_steps:
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec=env.time_step_spec(),
            action_spec=env.action_spec(),
            observation_and_action_constraint_splitter=GameEnv.
            obs_and_mask_splitter)

        dynamic_step_driver.DynamicStepDriver(
            env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # train until `consecutive_wins_flag` consecutive evaluations have an average return greater than 10, or max_iterations is reached
    consecutive_eval_win = 0
    train_iterations = 0
    while consecutive_eval_win < consecutive_wins_flag and train_iterations < max_iterations:
        collect_driver.run()

        for _ in range(collect_steps_per_iteration):
            experience, _ = next(iterator)
            train_loss = agent.train(experience).loss

        # evaluate the training at intervals
        step = global_step.numpy()
        if step % eval_interval == 0:
            eval_driver.run()
            average_return = eval_metrics[0].result().numpy()
            average_len = eval_metrics[1].result().numpy()
            print("level: {0} step: {1} AverageReturn: {2} AverageLen: {3}".
                  format(level, step, average_return, average_len))

            # evaluate consecutive wins
            if average_return > 10:
                consecutive_eval_win += 1
            else:
                consecutive_eval_win = 0

        if step % save_interval == 0:
            start = time.time()
            train_checkpointer.save(global_step=step)
            saving_time += time.time() - start

        train_iterations += 1

    # save the policy
    train_checkpointer.save(global_step=global_step.numpy())
    tf_policy_saver.save(
        os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
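
# GameEnv.obs_and_mask_splitter, referenced above, is not shown in this example.
# In TF-Agents, an observation_and_action_constraint_splitter is simply a callable
# that splits the raw observation into (network_observation, action_mask). The
# sketch below is hypothetical, for an environment whose observation is a
# (board, valid_action_mask) pair, which is consistent with the q_network above
# being built on env.observation_spec()[0].
def obs_and_mask_splitter(observation):
    board, valid_action_mask = observation
    # The Q-network only sees the board; the mask restricts greedy and
    # epsilon-greedy action selection to valid moves.
    return board, valid_action_mask
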
예제 #21
0
    # driver to collect observations for training the agent
    agent_collect_driver = dynamic_step_driver.DynamicStepDriver(
        train_env,
        agent_collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=hyperparams['collect_steps_per_iteration'])

    # Dataset generates trajectories with shape [BatchSizex2x...]
    agent_dataset = replay_buffer.as_dataset(
        num_parallel_calls=4,
        sample_batch_size=hyperparams['batch_size'],
        num_steps=2).prefetch(3)
    agent_iterator = iter(agent_dataset)

    agent.train = common.function(agent.train)
    agent_collect_driver.run = common.function(agent_collect_driver.run)

    # Reset the train step
    agent.train_step_counter.assign(0)
    step = agent.train_step_counter.numpy()

    # for collecting results and losses
    critic_loss = []
    actor_loss = []
    average_rewards = [0.]
    eval_results = dict()
    eval_results['params'] = hyperparams
    train_checkpointer = common.Checkpointer(
        ckpt_dir=hyperparams['checkpoint_dir'],
        max_to_keep=1,
예제 #22
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        train_sequence_length=20,
        critic_learning_rate=3e-4,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        if env_name.startswith('Humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0


    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size * num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'episodes with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
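
# The actor network above references a normal_projection_net helper that is not
# defined in this snippet. Below is a sketch matching the helper commonly used in
# the TF-Agents SAC examples; the exact settings in the original are assumptions.
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        mean_transform=None,
        state_dependent_std=True,
        init_means_output_factor=init_means_output_factor,
        std_transform=sac_agent.std_clip_transform,
        scale_distribution=True)
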
예제 #23
0
def train_agent(tf_agent,
                train_env,
                eval_env,
                num_iterations,
                returns,
                losses,
                collect_episodes_per_iteration,
                log_interval,
                eval_interval,
                policy_checkpoint_interval,
                replay_buffer_capacity,
                num_eval_episodes,
                direc,
                verbose=False):
    """
	Main training function, which optimizes an agent in a given meta-MDP environment.
	Inputs:
	tf_agent: either a REINFORCE or PPO agent, as created by load_reinforce_agent or load_ppo_agent.
	Note that train_agent does not require num_epochs or actor/value/preprocessing layers as arguments, since
	they are implicitly provided through tf_agent.
	train_env: environment in which the agent collects episodes to train on.
	eval_env: environment in which the agent collects episodes to monitor its performance (i.e., its learning curve). It is good practice to keep these two environments separate even though they are both instances of the same object with the same settings.
	Note that neither train_env nor eval_env should be a MetaMDPEnv directly; each should be wrapped in a tf_py_environment.TFPyEnvironment(),
	which tf_agents provides to convert the actions/observations of a gym environment to tensors.
	returns, losses: lists to which the return and loss are appended at regular intervals. Training typically starts with returns = [] and losses = [].
	collect_episodes_per_iteration: number of episodes to collect per training iteration. Note that the algorithm will train on this data for multiple epochs.
	log_interval: how often the training algorithm should report the training loss. Note that this interval is measured in training steps, which increment by num_epochs after each call to tf_agent.train(). Therefore, log_interval should be an integer multiple of num_epochs.
	eval_interval: same, but for evaluation of expected returns.
	policy_checkpoint_interval: same, but for regular dumps of the policy parameters.
	replay_buffer_capacity: size of the buffer in which train_agent stores episodes. Note that the size is measured in steps, not episodes. For example, if replay_buffer_capacity is 1000 and each episode is 10 steps, the buffer will store 100 episodes.
	num_eval_episodes: number of episodes to collect when evaluating the expected return. More episodes give more accurate estimates but also take more time.
	direc: directory in which output will be saved. If the directory does not exist, it will be created.
	verbose: flag which controls the amount of output; default False.
	"""

    if not os.path.exists(direc):
        os.mkdir(direc)

    with open(
            os.path.join(direc, 'results.txt'), 'w', buffering=1
    ) as f:  #buffering=1 causes the file to be flushed after every line
        print(inspect.getargvalues(inspect.currentframe()),
              file=f)  # log all argument values to the log file f

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=train_env.batch_size,
            max_length=replay_buffer_capacity)
        #initializing the replay buffer

        if policy_checkpoint_interval is not None:
            policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(direc, 'policies/'),
                policy=tf_agent.collect_policy)
            saved_model = policy_saver.PolicySaver(tf_agent.collect_policy)
            #initializing the objects needed to create regular dumps of the policy parameters

        tf_agent.train = common.function(tf_agent.train)

        # Set the training step counter to 0
        tf_agent.train_step_counter.assign(0)

        for _ in range(num_iterations):
            # Collect a few episodes using collect_policy and save to the replay buffer.
            if verbose:
                print("collecting " + str(collect_episodes_per_iteration) +
                      " episodes")
            collect_episode(replay_buffer, train_env, tf_agent.collect_policy,
                            collect_episodes_per_iteration, verbose)
            trajectories = replay_buffer.gather_all()
            # Use data from the buffer and update the agent's network.

            step = tf_agent.train_step_counter.numpy()

            if step % eval_interval == 0:
                avg_return, avg_length = compute_avg_return(
                    eval_env, tf_agent.collect_policy, num_eval_episodes)
                m = 'step = {0}: Average Collection Return = {1}'.format(
                    step, avg_return)
                print(m)
                print(m, file=f)
                m = 'step = {0}: Average Collection Length = {1}'.format(
                    step, avg_length)
                print(m)
                print(m, file=f)
                returns.append(avg_return)
                # log the average return and the average episode length (that is, the average number of observations the agent takes before terminating)

            if policy_checkpoint_interval is not None and step % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=step)
                saved_model_path = os.path.join(
                    direc, 'policies/policy_' + ('%d' % step).zfill(9))
                saved_model.save(saved_model_path)

            if verbose:
                print("training")
            train_loss = tf_agent.train(experience=trajectories)
            # Note that usually one would call
            # replay_buffer.clear()
            # to clear the buffer of episodes collected before this training step, but we keep them to
            # maintain a rolling buffer of episodes; when the buffer is full, adding new data evicts the oldest episodes automatically.

            if step % log_interval == 0:
                m = 'step = {0}: loss = {1}'.format(step, train_loss.loss)
                print(m)
                print(m, file=f)
                losses.append(train_loss.loss.numpy())

        return tf_agent, returns, losses
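
# A hypothetical invocation of train_agent illustrating the constraints in its
# docstring. MetaMDPEnv, load_ppo_agent and num_epochs=10 are assumed names and
# values used only for illustration; log_interval and eval_interval are chosen as
# integer multiples of num_epochs, as the docstring requires.
train_env = tf_py_environment.TFPyEnvironment(MetaMDPEnv())
eval_env = tf_py_environment.TFPyEnvironment(MetaMDPEnv())
ppo_agent = load_ppo_agent(train_env, num_epochs=10)  # assumed signature
ppo_agent, returns, losses = train_agent(
    ppo_agent,
    train_env,
    eval_env,
    num_iterations=1000,
    returns=[],
    losses=[],
    collect_episodes_per_iteration=10,
    log_interval=100,                  # multiple of num_epochs
    eval_interval=500,                 # multiple of num_epochs
    policy_checkpoint_interval=1000,
    replay_buffer_capacity=10000,
    num_eval_episodes=20,
    direc='output/run1',
    verbose=True)
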
예제 #24
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_update_period=2,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
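
# A minimal sketch of driving the TD3 train_eval above from a script entry point.
# The main() wrapper and the root_dir value are assumptions, not part of the
# original example.
from absl import app


def main(_):
    # assuming absl logging, as used by the logging.info calls in the example
    logging.set_verbosity(logging.INFO)
    train_eval(os.path.expanduser('~/td3_halfcheetah'),
               env_name='HalfCheetah-v2',
               num_iterations=2000000)


if __name__ == '__main__':
    app.run(main)
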
예제 #25
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        action_spec = tf_env.action_spec()
        num_actions = action_spec.maximum - action_spec.minimum + 1

        if train_sequence_length > 1:
            q_net = create_recurrent_network(input_fc_layer_params, lstm_size,
                                             output_fc_layer_params,
                                             num_actions)
        else:
            q_net = create_feedforward_network(fc_layer_params, num_actions)
            train_sequence_length = n_step_update
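            # With a feedforward Q-network, each sampled window only needs to cover the
            # n-step return, so the dataset created below samples
            # num_steps = train_sequence_length + 1 (= n_step_update + 1) steps.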

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape
        # [B x (train_sequence_length + 1) x ...].
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
예제 #26
0
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        eval_env_name=None,
        env_load_fn=suite_gym.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=5000,
        policy_checkpoint_interval=2500,
        rb_checkpoint_interval=25000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_env_name = eval_env_name or env_name
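        # NOTE: config_file is not defined in this snippet; it is assumed to be supplied
        # at module level (e.g. parsed from command-line flags) and forwarded to the
        # gym_solventx environment.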
        gym_env = gym.make(env_name, config_file=config_file)
        py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_gym_env = gym.make(eval_env_name, config_file=config_file)
        eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        #tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        #eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork))
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data, but only if the (possibly restored) replay
        # buffer is still empty.
        if replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
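        # A 2-step sample whose first step is an episode boundary would pair the end of
        # one episode with the start of the next, which is not a valid transition for
        # training, so such samples are dropped before batching.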
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
예제 #27
0
File: dqn.py Project: mshinji/mahjong
def main():
    print(' === MAIN ===')

    # Environment setup
    env_py = Env()
    env = tf_py_environment.TFPyEnvironment(env_py)
    print(' === ENV LOADED === ')

    # Network setup
    primary_network = MyQNetwork(env.observation_spec(), env.action_spec())
    print(' === NETWORK LOADED === ')

    # Agent setup
    n_step_update = 1
    agent = dqn_agent.DqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=primary_network,
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        n_step_update=n_step_update,
        target_update_period=100,
        gamma=0.99,
        train_step_counter=tf.Variable(0),
        epsilon_greedy=0.0
    )
    agent.initialize()
    agent.train = common.function(agent.train)
    print(' === AGENT LOADED === ')

    # Policy (action selection) setup
    policy = agent.collect_policy
    print(' === POLICY LOADED === ')

    # Replay buffer setup
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=env.batch_size,
        max_length=10**6
    )
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        sample_batch_size=16,
        num_steps=n_step_update + 1
    ).prefetch(tf.data.experimental.AUTOTUNE)
    iterator = iter(dataset)
    print(' === BUFFER LOADED === ')

    # Policy saver setup
    tf_policy_saver = policy_saver.PolicySaver(policy=agent.policy)
    print(' === SAVER LOADED === ')

    # Training
    num_episodes = 200
    decay_episodes = 70
    # Epsilon schedule: 1.0 for the first `decay_episodes` episodes, 0.1 afterwards.
    epsilon = np.concatenate([np.linspace(start=1.0, stop=1.0, num=decay_episodes),
                              0.1 * np.ones(shape=(num_episodes - decay_episodes,)), ], 0)

    action_step_counter = 0
    replay_start_size = 100

    episode_average_loss = []

    for episode in range(1, num_episodes + 1):
        policy._epsilon = epsilon[episode - 1]  # epsilon for the epsilon-greedy schedule
        env.reset()

        previous_time_step = None
        previous_policy_step = None

        while not env.game_end:  # repeat until the game ends
            current_time_step = env.current_time_step()
            if previous_time_step is None:  # no training data is built for the first move
                pass
            else:
                previous_step_reward = tf.constant([env.reward, ], dtype=tf.float32)
                current_time_step = current_time_step._replace(reward=previous_step_reward)

                traj = trajectory.from_transition(
                    previous_time_step, previous_policy_step, current_time_step)  # build the trajectory
                replay_buffer.add_batch(traj)  # store it in the replay buffer

                if action_step_counter >= 2 * replay_start_size:  # wait until enough initial data is collected
                    experience, _ = next(iterator)
                    loss_info = agent.train(experience=experience)  # train
                    episode_average_loss.append(loss_info.loss.numpy())
                else:
                    action_step_counter += 1
            if random.random() < epsilon[episode - 1]:  # random action (epsilon-greedy)
                policy_step = random_policy_step(env.random_action)  # user-defined random policy
            else:
                policy_step = policy.action(current_time_step)  # choose the action from the current state

            previous_time_step = current_time_step  # remember the previous time step
            previous_policy_step = policy_step  # remember the previous policy step

            env.step(policy_step.action)  # apply the chosen action to the environment

        print(env.stats)

        # Display training progress (every 100 episodes)
        if episode % 100 == 0:
            print('==== Episode {}: rank: {} ===='.format(
                episode, env.stats
            ))
            episode_average_loss = []

        if episode % (num_episodes // 10) == 0:
            tf_policy_saver.save(f"policy_{episode}")
예제 #28
0
def train_eval(
        root_dir,
        env_name='SocialBot-ICubWalkPID-v0',
        num_iterations=10000000,
        actor_fc_layers=(256, 128),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 128),
        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        num_parallel_environments=12,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=5e-4,
        critic_learning_rate=5e-4,
        alpha_learning_rate=5e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_socialbot.load(env_name, wrap_with_process=False)]
                * num_parallel_environments))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_socialbot.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'the collect policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3, sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (
                    global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
예제 #29
0
def train_eval(
        root_dir,
        env_name='MaskedCartPole-v0',
        num_iterations=100000,
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        train_sequence_length=10,
        # Params for collect
        initial_collect_steps=50,
        collect_episodes_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=10,
        batch_size=128,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        q_net = q_rnn_network.QRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        loss_info = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            logging.info('Collecting initial experience.')
            sess.run(initial_collect_op)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(loss_info)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
예제 #30
0
File: dqn_tutorial.py Project: kbajay/ml
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)

iterator = iter(dataset)

"""## Training the agent

The training loop involves both collecting data from the environment and optimizing the agent's networks. Along the way, we will occasionally evaluate the agent's policy to see how we are doing.

The following will take ~5 minutes to run.
"""

#@test {"skip": true}
# %%time

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  for _ in range(collect_steps_per_iteration):
    collect_step(train_env, tf_agent.collect_policy)

  # Sample a batch of data from the buffer and update the agent's network.
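  # (Not part of the original file.) The snippet is cut off here; below is a minimal
  # sketch of how the standard TF-Agents DQN tutorial loop typically continues: train on
  # one sampled batch, log the loss, and periodically re-evaluate the policy. It assumes
  # log_interval and eval_interval were defined with the other hyperparameters earlier.
  experience, unused_info = next(iterator)
  train_loss = tf_agent.train(experience).loss

  step = tf_agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)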
예제 #31
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=1000,
    actor_fc_layers=(100,),
    value_net_fc_layers=(100,),
    use_value_network=False,
    use_tf_functions=True,
    # Params for collect
    collect_episodes_per_iteration=2,
    replay_buffer_capacity=2000,
    # Params for train
    learning_rate=1e-3,
    gamma=0.9,
    gradient_clipping=None,
    normalize_returns=True,
    value_estimation_loss_coef=0.2,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=100,
    # Params for checkpoints, summaries, and logging
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for Reinforce."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    if use_value_network:
      value_net = value_network.ValueNetwork(
          tf_env.time_step_spec().observation,
          fc_layer_params=value_net_fc_layers)

    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        value_network=value_net if use_value_network else None,
        value_estimation_loss_coef=value_estimation_loss_coef,
        gamma=gamma,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      experience = replay_buffer.gather_all()
      return tf_agent.train(experience)

    if use_tf_functions:
      # To speed up collect use TF function.
      collect_driver.run = common.function(collect_driver.run)
      # To speed up train use TF function.
      tf_agent.train = common.function(tf_agent.train)
      train_step = common.function(train_step)

    # Compute evaluation metrics.
    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    # TODO(b/126590894): Move this functionality into eager_compute_summaries
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      total_loss = train_step()
      replay_buffer.clear()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        # TODO(b/126590894): Move this functionality into
        # eager_compute_summaries.
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step_val)