Example #1
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    with tf.device('/CPU:0'):  # due to b/128333994
        action_reward_fns = (
            environment_utilities.sliding_linear_reward_fn_generator(
                CONTEXT_DIM, NUM_ACTIONS, REWARD_NOISE_VARIANCE))

        env = sspe.StationaryStochasticPyEnvironment(functools.partial(
            environment_utilities.context_sampling_fn,
            batch_size=BATCH_SIZE,
            context_dim=CONTEXT_DIM),
                                                     action_reward_fns,
                                                     batch_size=BATCH_SIZE)
        environment = tf_py_environment.TFPyEnvironment(env)

        optimal_reward_fn = functools.partial(
            environment_utilities.tf_compute_optimal_reward,
            per_action_reward_fns=action_reward_fns)

        optimal_action_fn = functools.partial(
            environment_utilities.tf_compute_optimal_action,
            per_action_reward_fns=action_reward_fns)

        q_net = q_network.QNetwork(environment.observation_spec(),
                                   environment.action_spec(),
                                   fc_layer_params=(50, 50))

        agent = dqn_agent.DqnAgent(
            environment.time_step_spec(),
            environment.action_spec(),
            q_network=q_net,
            epsilon_greedy=0.1,
            target_update_tau=0.05,
            target_update_period=5,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-2),
            td_errors_loss_fn=common.element_wise_squared_loss)

        regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
        suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
            optimal_action_fn)

        trainer.train(
            root_dir=FLAGS.root_dir,
            agent=agent,
            environment=environment,
            training_loops=TRAINING_LOOPS,
            steps_per_loop=STEPS_PER_LOOP,
            additional_metrics=[regret_metric, suboptimal_arms_metric])
Example #2
    def _init_agent(self, mac, learning_rate, q_net, train_step_counter,
                    train_env):
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        self.agent[mac] = dqn_agent.DqnAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=train_step_counter)
        self.agent[mac].initialize()

        return
Example #3
    def create_agent(self):
        q_net = q_network.QNetwork(self.env.observation_spec(),
                                   self.env.action_spec(),
                                   fc_layer_params=self.fc_layer_params)
        self.tf_agent = dqn_agent.DqnAgent(
            self.env.time_step_spec(),
            self.env.action_spec(),
            q_network=q_net,
            optimizer=self.optimizer,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            train_step_counter=self.train_step_counter,
            gamma=self.gamma)

        self.init_steps = 0
        self.episode_steps = 0
Example #4
def get_agent(train_env):
    fc_layer_params = (100, )

    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               fc_layer_params=fc_layer_params)
    train_step_counter = tf.Variable(0)
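    # NOTE: `optimizer` is assumed to be defined at module scope (e.g. an Adam optimizer).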

    return dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter)
Example #5
    def def_agent(self):
        """Define the agent (architecture of the solver)"""

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=self.learning_rate)
        train_step_counter = tf.Variable(0)

        self.agent = dqn_agent.DqnAgent(
            self.train_env.time_step_spec(),
            self.train_env.action_spec(),
            q_network=self.q_agent,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=train_step_counter)

        self.agent.initialize()
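Example #6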
def train(num_iterations):
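    # NOTE: `Cliff`, `collect_steps`, and `compute_average_reward` are assumed to be defined elsewhere in this example.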
    train_env = tf_py_environment.TFPyEnvironment(Cliff())
    test_env = tf_py_environment.TFPyEnvironment(Cliff())
    counter = tf.Variable(0)

    # Build network
    network = q_network.QNetwork(train_env.observation_spec(),
                                 train_env.action_spec(),
                                 fc_layer_params=(100, ))
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=network,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=counter)

    agent.initialize()
    agent.train = common.function(agent.train)
    agent.train_step_counter.assign(0)

    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=100)
    dataset = buffer.as_dataset(sample_batch_size=32, num_steps=2)
    iterator = iter(dataset)
    first_reward = compute_average_reward(train_env,
                                          agent.policy,
                                          num_episodes=10)
    print(f'Before training: {first_reward}')
    rewards = [first_reward]

    for _ in range(num_iterations):
        for _ in range(2):
            collect_steps(train_env, agent.collect_policy, buffer)

        experience, info = next(iterator)
        loss = agent.train(experience).loss
        step_number = agent.train_step_counter.numpy()

        if step_number % 10 == 0:
            print(f'step={step_number}: loss={loss}')

        if step_number % 20 == 0:
            average_reward = compute_average_reward(test_env, agent.policy, 1)
            print(f'step={step_number}: Reward:={average_reward}')
Example #7
    def __init__(self, *args, **kwargs):
        self.env = tf_py_environment.TFPyEnvironment(CardGameEnv())
        self.q_net = q_network.QNetwork(self.env.observation_spec(),
                                        self.env.action_spec())
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.train_step_counter = tf.Variable(0)
        self.agent = dqn_agent.DqnAgent(
            self.env.time_step_spec(),
            self.env.action_spec(),
            q_network=self.q_net,
            optimizer=self.optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=self.train_step_counter)

        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.env.batch_size,
            max_length=100000)
        self.num_iterations = 10000
Example #8
def generic_dqn_agent(env: TFPyEnvironment) -> (dqn_agent.DqnAgent, q_network.QNetwork):
    """ Function that returns a generic dqn agent
    args:
        env (TFPyEnvironment) : The environment the agent will live in

    Returns:
        dqn_agent.DqnAgent: The agent to train
        q_network.QNetwork: The network used in the agent
    """

    inp = env.observation_spec().shape[0]
    q_net = q_network.QNetwork(
      env.observation_spec(),
      env.action_spec(),
      fc_layer_params=(20,20,20,20,20),
      activation_fn=tf.keras.activations.relu)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)

    agent = dqn_agent.DqnAgent(
      env.time_step_spec(),
      env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=tf.Variable(0),
      epsilon_greedy=0.1
    )

    """def observation_and_action_constraint_splitter(observation):
        action_mask = [1,1]
        if observation[0][-1] > 5:
            action_mask[0] = 1
        return observation, tf.convert_to_tensor(action_mask, dtype=np.int32)

    agent.policy._observation_and_action_constraint_splitter = (
        observation_and_action_constraint_splitter
    )"""
    #tf_agents.policies.greedy_policy.GreedyPolicy

    agent.initialize()

    return agent, q_net
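Example #9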
  def __init__(self, config, *args, **kwargs):
    """Initialize the agent."""
    self.observation_spec = config['observation_spec']
    self.time_step_spec = ts.time_step_spec(self.observation_spec)
    self.action_spec = config['action_spec']
    self.environment_batch_size = config['environment_batch_size']

    self.q_net = q_rnn_network.QRnnNetwork(
      self.observation_spec,
      self.action_spec,
      input_fc_layer_params=(256, ),
      lstm_size=(256, 256))

    self.optimizer = tf.keras.optimizers.Adam()

    self.agent = dqn_agent.DqnAgent(
      self.time_step_spec,
      self.action_spec,
      q_network=self.q_net,
      optimizer=self.optimizer,
      epsilon_greedy=0.3)

    self.agent.initialize()

    self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=1,
      dataset_drop_remainder=True)

    self.dataset = self.replay_buffer.as_dataset(
      sample_batch_size=1,
      num_steps=2,
      single_deterministic_pass=True)

    self.intention_classifier = tf.keras.models.Sequential(
      [tf.keras.layers.Dense(1, input_shape=self.observation_spec.shape, activation='sigmoid')])

    self.intention_classifier.compile(optimizer=self.optimizer, loss=tf.keras.losses.BinaryCrossentropy())

    self.intention_dataset_input = []
    self.intention_dataset_true = []

    self.num_hinted = 0
Example #10
 def initialize_agent(self):
     """ Instance TF agent with hparams
     """
     # Q network
     self.q_net = q_network.QNetwork(
         self.train_env.observation_spec(),
         self.train_env.action_spec(),
         fc_layer_params=self.qnet_fc_hidden_size)
     # DQN agent
     self.agent = dqn_agent.DqnAgent(
         self.train_env.time_step_spec(),
         self.train_env.action_spec(),
         q_network=self.q_net,
         optimizer=self.optimizer,
         epsilon_greedy=0.1,  # TODO: expose this as a hyperparameter
         td_errors_loss_fn=common.element_wise_squared_loss,
         train_step_counter=self.train_step_counter)
     self.agent.initialize()
     self.policy = self.agent.policy
Example #11
def create_agent(train_env):
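  # NOTE: `fc_layer_params` and `learning_rate` are assumed to be module-level globals.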
  q_net = q_network.QNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=fc_layer_params,
  )

  adam = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

  train_step_counter = tf.compat.v2.Variable(0)

  tf_agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      q_network=q_net,
      optimizer=adam,
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      train_step_counter=train_step_counter,
  )
  tf_agent.initialize()
  return tf_agent
Example #12
    def __init__(self,
                 train_environment,
                 eval_environment,
                 replay_buffer_capacity=1000,
                 fc_layer_params=(100, ),
                 learning_rate=1e-3):
        # Use TF Environment Wrappers to translate them for TF
        self.train_env = tf_py_environment.TFPyEnvironment(train_environment)
        self.eval_env = tf_py_environment.TFPyEnvironment(eval_environment)

        # Define Q-Network
        q_net = q_network.QNetwork(self.train_env.observation_spec(),
                                   self.train_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        train_step_counter = tf.compat.v2.Variable(0)
        # Define Agent
        self.agent = dqn_agent.DqnAgent(
            self.train_env.time_step_spec(),
            self.train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            train_step_counter=train_step_counter)

        self.agent.initialize()

        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.train_env.batch_size,
            max_length=replay_buffer_capacity)

        self.eval_policy = self.agent.policy
        self.collect_policy = self.agent.collect_policy

        self.random_policy = random_tf_policy.RandomTFPolicy(
            self.train_env.time_step_spec(), self.train_env.action_spec())
Example #13
def run():
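    # NOTE: `SnakeEnv`, `train_dir`, `root_dir`, and `capture_run` are assumed to be defined elsewhere in this module.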
    env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=1000))

    ## Must match the network architecture used during training
    q_net = q_network.QNetwork(
        env.observation_spec(),
        env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(256, 100),
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    agent.initialize()

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )

    policy_checkpointer.initialize_or_restore()

    capture_run(
        os.path.join(root_dir, "snake" + str(global_counter.numpy()) + ".mp4"),
        env, agent.policy)
Example #14
  def build_and_run_actor():
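    # NOTE: `test_case`, `get_cartpole_env_and_specs`, `build_dummy_sequential_net`,
    # `build_actor`, `num_iterations`, and `reverb_server_port` are assumed to come from the enclosing test scope.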
    root_dir = test_case.create_tempdir().full_path
    env, action_tensor_spec, time_step_tensor_spec = (
        get_cartpole_env_and_specs())

    train_step = train_utils.create_train_step()

    q_net = build_dummy_sequential_net(fc_layer_params=(100,),
                                       action_spec=action_tensor_spec)

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        train_step_counter=train_step)

    _, rb_observer = (
        replay_buffer_utils.get_reverb_buffer_and_observer(
            agent.collect_data_spec,
            table_name=reverb_replay_buffer.DEFAULT_TABLE,
            sequence_length=2,
            reverb_server_address='localhost:{}'.format(reverb_server_port)))

    variable_container = reverb_variable_container.ReverbVariableContainer(
        server_address='localhost:{}'.format(reverb_server_port),
        table_names=[reverb_variable_container.DEFAULT_TABLE])

    test_actor = build_actor(
        root_dir, env, agent, rb_observer, train_step)

    variables_dict = {
        reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
        reverb_variable_container.TRAIN_STEP_KEY: train_step
    }
    variable_container.update(variables_dict)

    for _ in range(num_iterations):
      test_actor.run()
Example #15
def create_dqn_agent_and_dataset_fn(action_spec, time_step_spec, train_step,
                                    batch_size):
    """Builds and returns a dataset function for DQN Agent."""
    q_net = build_dummy_sequential_net(fc_layer_params=(100, ),
                                       action_spec=action_spec)

    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        train_step_counter=train_step)
    agent.initialize()

    def make_item(_):
        traj = tensor_spec.sample_spec_nest(agent.collect_data_spec,
                                            seed=123,
                                            outer_dims=[2])

        def scale_observation_only(item):
            # Scale float values in the sampled item by large value to avoid NaNs.
            if item.dtype == tf.float32:
                return tf.math.divide(item, 1.e+22)
            else:
                return item

        return tf.nest.map_structure(scale_observation_only, traj)

    l = [[i] for i in range(100)]
    dataset = tf.data.Dataset.zip(
        (tf.data.Dataset.from_tensor_slices(l).map(make_item),
         tf.data.experimental.Counter()))
    dataset_fn = lambda: dataset.batch(batch_size)

    return agent, dataset, dataset_fn, agent.collect_data_spec
Example #16
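# NOTE: this snippet is truncated; the lines above (not shown) open a `q_network.QNetwork(...)` call taking the environment's observation and action specs.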
                           activation_fn=tf.keras.activations.relu,
                           kernel_initializer=None,
                           batch_squash=True,
                           dtype=tf.float32,
                           name='QNetwork'
                           )
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(train_tf_env.time_step_spec(),
                           train_tf_env.action_spec(),
                           q_network=q_net,
                           epsilon_greedy=0.9,
                           observation_and_action_constraint_splitter=None,
                           # boltzmann_temperature= 0.5,
                           # emit_log_probability= True,
                           gamma=0.9,
                           optimizer=optimizer,
                           target_update_period=1000,
                           td_errors_loss_fn=common.element_wise_squared_loss,
                           train_step_counter=train_step_counter,
                           summarize_grads_and_vars=True,
                           debug_summaries=True)
agent.initialize()
##############
# Policy setup#
##############
# Todo: Design random policy
'''
eval_policy = agent.policy
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_tf_env.time_step_spec(),
Example #17
def bigtable_collect(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        num_episodes=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    global_step = tf.compat.v1.train.get_or_create_global_step()
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    q_net = q_network.QNetwork(tf_env.observation_spec(),
                               tf_env.action_spec(),
                               fc_layer_params=fc_layer_params)
    train_sequence_length = n_step_update

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    # Instantiate the CBT (Cloud Bigtable) table and GCS bucket
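    # NOTE: `SERVICE_ACCOUNT_FILE`, `SCOPES`, `args`, `VISUAL_OBS_SPEC`, `gcp_load_pipeline`,
    # and `BigtableReplayBuffer` are assumed to be defined at module level.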
    credentials = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE, scopes=SCOPES)
    cbt_table, gcs_bucket = gcp_load_pipeline(args.gcp_project_id,
                                              args.cbt_instance_id,
                                              args.cbt_table_name,
                                              args.bucket_id, credentials)
    max_row_bytes = (4 * np.prod(VISUAL_OBS_SPEC) + 64)
    cbt_batcher = cbt_table.mutations_batcher(flush_count=args.num_episodes,
                                              max_row_bytes=max_row_bytes)

    bigtable_replay_buffer = BigtableReplayBuffer(
        data_spec=tf_agent.collect_data_spec, max_size=replay_buffer_capacity)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[bigtable_replay_buffer.add_batch] + train_metrics,
        num_episodes=num_episodes)

    # train_checkpointer = common.Checkpointer(
    #     ckpt_dir=train_dir,
    #     agent=tf_agent,
    #     global_step=global_step,
    #     metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    # policy_checkpointer = common.Checkpointer(
    #     ckpt_dir=os.path.join(train_dir, 'policy'),
    #     policy=eval_policy,
    #     global_step=global_step)
    # rb_checkpointer = common.Checkpointer(
    #     ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
    #     max_to_keep=1,
    #     replay_buffer=bigreplay_buffer)

    # train_checkpointer.initialize_or_restore()
    # rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
        # To speed up collect use common.function.
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    for _ in range(num_iterations):
        collect_driver.run()
Example #18
def load_agents_and_create_videos(root_dir,
        env_name='CartPole-v0',
        num_iterations=NUM_ITERATIONS,
        max_ep_steps=1000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(128, 64, 32),
        # Params for QRnnNetwork
        input_fc_layer_params=(50,),
        lstm_size=(20,),
        output_fc_layer_params=(20,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=10000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval 
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None,
        random_metrics_callback=None):
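    # NOTE: `NUM_ITERATIONS` and `create_policy_eval_video` are assumed to be defined elsewhere in this module.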
    
    
    
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    random_dir = os.path.join(root_dir, 'random')
    
    eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    
    # Match the environments used in training
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name, max_episode_steps=max_ep_steps))
    eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    if train_sequence_length != 1 and n_step_update != 1:
        raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

    if train_sequence_length > 1:
        q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
    else:
        q_net = q_network.QNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=fc_layer_params)

        train_sequence_length = n_step_update

    # Match the agents used in training
    tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
    
    tf_agent.initialize()

    train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),]

    eval_policy = tf_agent.policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

    policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)

    rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

    # Load the data from training
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Define a random policy for comparison
    random_policy = random_tf_policy.RandomTFPolicy(eval_tf_env.time_step_spec(),
                                                    eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
    
    trained_filename = "trained-agent" + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)

    random_filename = 'random-agent ' + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, random_policy, random_filename)
Example #19
def main():
    parser = argparse.ArgumentParser()

    ## Essential parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model stats and checkpoints will be written."
    )
    parser.add_argument("--env",
                        default=None,
                        type=str,
                        required=True,
                        help="The environment to train the agent on")
    parser.add_argument("--approx_env_boundaries",
                        default=False,
                        type=bool,
                        help="Whether to get the env boundaries approximately")
    parser.add_argument("--max_horizon", default=4, type=int)
    parser.add_argument("--atari",
                        default=True,
                        type=bool,
                        help="Gets some data Types correctly")

    ##agent parameters
    parser.add_argument("--reward_scale_factor", default=1.0, type=float)
    parser.add_argument("--debug_summaries", default=False, type=bool)
    parser.add_argument("--summarize_grads_and_vars", default=False, type=bool)

    ##transformer parameters
    parser.add_argument("--d_model", default=64, type=int)
    parser.add_argument("--num_layers", default=2, type=int)
    parser.add_argument("--dff", default=256, type=int)

    ##Training parameters
    parser.add_argument('--num_iterations',
                        type=int,
                        default=2000000,
                        help="steps in the env")
    parser.add_argument('--num_iparallel',
                        type=int,
                        default=1,
                        help="how many envs should run in parallel")
    parser.add_argument("--collect_steps_per_iteration", default=4, type=int)
    parser.add_argument("--train_steps_per_iteration", default=1, type=int)

    ## Other parameters
    parser.add_argument("--num_eval_episodes", default=10, type=int)
    parser.add_argument("--eval_interval", default=10000, type=int)
    parser.add_argument("--log_interval", default=10000, type=int)
    parser.add_argument("--summary_interval", default=10000, type=int)
    parser.add_argument("--run_graph_mode", default=True, type=bool)
    parser.add_argument("--checkpoint_interval", default=100000, type=int)
    parser.add_argument("--summary_flush", default=10,
                        type=int)  # what exactly does this do?

    # HP opt params
    parser.add_argument("--doubleQ",
                        default=True,
                        type=bool,
                        help="Whether to use a  DoubleQ agent")
    parser.add_argument("--custom_last_layer", default=True, type=bool)
    parser.add_argument("--custom_layer_init", default=1.0, type=float)
    parser.add_argument("--initial_collect_steps", default=50000, type=int)
    parser.add_argument("--loss_function",
                        default="element_wise_huber_loss",
                        type=str)
    parser.add_argument("--num_heads", default=4, type=int)
    parser.add_argument("--normalize_env", default=False, type=bool)
    parser.add_argument('--custom_lr_schedule',
                        default="No",
                        type=str,
                        help="whether to use a custom LR schedule")
    parser.add_argument("--epsilon_greedy", default=0.3, type=float)
    parser.add_argument("--target_update_period", default=10000, type=int)
    parser.add_argument(
        "--rate", default=0.1, type=float
    )  # Dropout rate (may be unused, depending on the Q network). Setting this to 0.0 somehow breaks the code; just select a network without dropout instead.
    parser.add_argument("--gradient_clipping", default=True, type=bool)
    parser.add_argument("--replay_buffer_max_length",
                        default=1000000,
                        type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--encoder_type",
                        default=3,
                        type=int,
                        help="Which Type of encoder is used for the model")
    parser.add_argument("--layer_type",
                        default=3,
                        type=int,
                        help="Which Type of layer is used for the encoder")
    parser.add_argument("--target_update_tau", default=1, type=float)
    parser.add_argument("--gamma", default=0.99, type=float)

    args = parser.parse_args()
    # List of encoder modules which we can use to change encoder based on a variable
    global_step = tf.compat.v1.train.get_or_create_global_step()

    baseEnv = gym.make(args.env)
    env = suite_gym.load(args.env, gym_kwargs={"frameskip": 4})
    eval_env = suite_gym.load(args.env, gym_kwargs={"frameskip": 4})
    if args.normalize_env:
        env = NormalizeWrapper(env, args.approx_env_boundaries, args.env)
        eval_env = NormalizeWrapper(eval_env, args.approx_env_boundaries,
                                    args.env)
    env = PyhistoryWrapper(env, args.max_horizon, args.atari)
    eval_env = PyhistoryWrapper(eval_env, args.max_horizon, args.atari)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

    q_net = QTransformer(tf_env.observation_spec(),
                         baseEnv.action_space.n,
                         num_layers=args.num_layers,
                         d_model=args.d_model,
                         num_heads=args.num_heads,
                         dff=args.dff,
                         rate=args.rate,
                         encoderType=args.encoder_type,
                         enc_layer_type=args.layer_type,
                         max_horizon=args.max_horizon,
                         custom_layer=args.custom_layer_init,
                         custom_last_layer=args.custom_last_layer)

    if args.custom_lr_schedule == "Transformer":  # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(args.d_model,
                                       int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(learning_rate,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    elif args.custom_lr_schedule == "Transformer_low":  # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(
            int(args.d_model / 2),
            int(args.num_iterations /
                10))  # --> same schedule with lower general lr
        optimizer = tf.keras.optimizers.Adam(learning_rate,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    elif args.custom_lr_schedule == "Linear":
        lrs = LinearCustomSchedule(args.learning_rate, args.num_iterations)
        optimizer = tf.keras.optimizers.Adam(lrs,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    else:
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.learning_rate)

    if args.loss_function == "element_wise_huber_loss":
        lf = element_wise_huber_loss
    elif args.loss_function == "element_wise_squared_loss":
        lf = element_wise_squared_loss

    if not args.doubleQ:
        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            target_update_period=args.target_update_period,
            td_errors_loss_fn=lf,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    else:
        agent = dqn_agent.DdqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            td_errors_loss_fn=lf,
            target_update_period=args.target_update_period,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    agent.initialize()

    count_weights(q_net)

    train_eval(root_dir=args.output_dir,
               tf_env=tf_env,
               eval_tf_env=eval_tf_env,
               agent=agent,
               num_iterations=args.num_iterations,
               initial_collect_steps=args.initial_collect_steps,
               collect_steps_per_iteration=args.collect_steps_per_iteration,
               replay_buffer_capacity=args.replay_buffer_max_length,
               train_steps_per_iteration=args.train_steps_per_iteration,
               batch_size=args.batch_size,
               use_tf_functions=args.run_graph_mode,
               num_eval_episodes=args.num_eval_episodes,
               eval_interval=args.eval_interval,
               train_checkpoint_interval=args.checkpoint_interval,
               policy_checkpoint_interval=args.checkpoint_interval,
               rb_checkpoint_interval=args.checkpoint_interval,
               log_interval=args.log_interval,
               summary_interval=args.summary_interval,
               summaries_flush_secs=args.summary_flush)

    pickle.dump(args, open(args.output_dir + "/training_args.p", "wb"))
    print("Successfully trained and evaluation.")
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(
    ),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1
):  # Name of single observation channel, ['camera', 'lidar', 'birdeye']
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        ## Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    td_errors_loss_fn=tf.math.squared_difference,  # make critic loss dimensions compatible
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [B x (sequence_length + 1) x ...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
Exemplo n.º 21
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
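            # A feed-forward QNetwork needs no sequence context, so the dataset
            # built below only has to sample n_step_update + 1 consecutive steps.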
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
Exemplo n.º 22
0
def run():
    tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
    eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(512, 256, 128),
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    root_dir = os.path.join('/tf-logs', 'snake')
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length,
    )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=agent,
        global_step=global_counter,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )

    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver.run = common.function(collect_driver.run)
    agent.train = common.function(agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                    tf_env.action_spec())

    if replay_buffer.num_frames() >= initial_collect_steps:
        logging.info("We loaded memories, not doing random seed")
    else:
        logging.info("Capturing %d steps to seed with random memories",
                     initial_collect_steps)

        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

    train_summary_writer = tf.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()

    avg_returns = []
    avg_return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes)
    eval_metrics = [
        avg_return_metric,
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    logging.info("Running initial evaluation")
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_env,
        agent.policy,
        num_episodes=num_eval_episodes,
        train_step=global_counter,
        summary_writer=tf.summary.create_file_writer(eval_dir),
        summary_prefix='Metrics',
    )
    avg_returns.append(
        (global_counter.numpy(), avg_return_metric.result().numpy()))
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_counter.numpy()
    time_acc = 0

    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    @common.function
    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    for _ in range(num_iterations):
        start_time = time.time()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )

        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        step = global_counter.numpy()

        if step % log_interval == 0:
            logging.info("step = %d, loss = %f", step, train_loss.loss)
            steps_per_sec = (step - timed_at_step) / time_acc
            logging.info("%.3f steps/sec", steps_per_sec)
            timed_at_step = step
            time_acc = 0

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_counter,
                                      step_metrics=train_metrics[:2])

        if step % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=step)

        if step % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=step)

        if step % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=step)

        if step % capture_interval == 0:
            print("Capturing run:")
            capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                        eval_env, agent.policy)

        if step % eval_interval == 0:
            print("EVALUTION TIME:")
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                agent.policy,
                num_episodes=num_eval_episodes,
                train_step=global_counter,
                summary_writer=tf.summary.create_file_writer(eval_dir),
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)
            avg_returns.append(
                (global_counter.numpy(), avg_return_metric.result().numpy()))
Exemplo n.º 23
0
def ai_game():
    pygame.init()
    display = pygame.display.set_mode((HEIGHT, WIDTH))
    pygame.display.set_caption("Snake")
    font = pygame.font.SysFont("Times New Roman", 24)
    # snake_agent = SnakeAgent()
    # game(display, snake_agent)
    time.sleep(5)

    train_env = SnakeGameEnv(display, font)
    eval_env = SnakeGameEnv(display, font)
    # env = CardGameEnv()
    # utils.validate_py_environment(env)
    fc_layer_params = (100, 50)
    action_tensor_spec = tensor_spec.from_spec(train_env.action_spec())
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    train_env = tf_py_environment.TFPyEnvironment(train_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_env)

    # The QNetwork consists of a sequence of Dense layers followed by a final
    # Dense layer with `num_actions` units that generates one q_value per
    # available action as its output.
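    # `dense_layer` is not defined in this snippet. Below is a minimal sketch
    # of the helper assumed here, modeled on the standard TF-Agents DQN
    # tutorial; the exact activation/initializers in the original may differ.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))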
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(
            minval=-0.03, maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    train_step_counter = tf.Variable(0)
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter)
    agent.initialize()
    print('Initialized Agent')
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    print('Resetting environment')
    time_step = train_env.reset()
    random_policy.action(time_step)
    print('Successfully sampled an action from the random policy')
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=REPLAY_BUFFER_MAX_LEN)

    print('Created replay buffer, collecting data ... ')
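    # `collect_data` is not defined in this snippet. Below is a minimal sketch
    # of a compatible helper (driver-free collection, as in the TF-Agents DQN
    # tutorial); the local import is only needed if `trajectory` is not
    # already imported at module level.
    from tf_agents.trajectories import trajectory

    def collect_data(environment, policy, buffer, steps):
        for _ in range(steps):
            time_step = environment.current_time_step()
            action_step = policy.action(time_step)
            next_time_step = environment.step(action_step.action)
            # Package the transition as a Trajectory and add it to the buffer.
            buffer.add_batch(
                trajectory.from_transition(time_step, action_step,
                                           next_time_step))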
    collect_data(train_env, random_policy, replay_buffer, INITIAL_COLLECT_STEPS)
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=BATCH_SIZE,
        num_steps=2).prefetch(3)
    print('Collecting data complete')
    iterator = iter(dataset)
    # Reset the train step
    agent.train_step_counter.assign(0)
    avg_return = compute_avg_return(train_env, agent.policy, NUM_EVAL_EPISODES)
    returns = [avg_return]
    print('Beginning to train...')
    for i in range(NUM_ITERATIONS):
        collect_data(train_env, agent.collect_policy, replay_buffer, COLLECT_STEPS_PER_ITERATION)

        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss
        step = agent.train_step_counter.numpy()
        #print(train_env.time_step_spec())

        print(f"Training agent through iteration {(i / NUM_ITERATIONS) * 100:.2f}%...")
        if step % LOG_INTERVAL == 0:
            pass
            #print('step = {0}: loss = {1}'.format(step, train_loss))

        if step % EVAL_INTERVAL == 0:
            #avg_return = compute_avg_return(train_env, agent.policy, NUM_EVAL_EPISODES)
            #print('step = {0}: Average Return = {1}'.format(step, avg_return))
            #returns.append(avg_return)
            pass

    """
Exemplo n.º 24
0
def main():
    def compute_avg_return(environment, policy, num_episodes=10):
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0

            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return

        avg_return = total_return / num_episodes
        return avg_return.numpy()[0]

    class ShowProgress:
        def __init__(self, total):
            self.counter = 0
            self.total = total

        def __call__(self, trajectory):
            if not trajectory.is_boundary():
                self.counter += 1
            if self.counter % 100 == 0:
                print("\r{}/{}".format(self.counter, self.total), end="")

    def train_agent(n_iterations, save_each=10000, print_each=500):
        time_step = None
        policy_state = agent.collect_policy.get_initial_state(
            tf_env.batch_size)
        iterator = iter(dataset)

        for iteration in range(n_iterations):
            step = agent.train_step_counter.numpy()
            current_metrics = []

            time_step, policy_state = collect_driver.run(
                time_step, policy_state)
            trajectories, buffer_info = next(iterator)

            train_loss = agent.train(trajectories)
            all_train_loss.append(train_loss.loss.numpy())

            for i in range(len(train_metrics)):
                current_metrics.append(train_metrics[i].result().numpy())

            all_metrics.append(current_metrics)

            if iteration % print_each == 0:
                print("\nIteration: {}, loss:{:.2f}".format(
                    iteration, train_loss.loss.numpy()))

                for i in range(len(train_metrics)):
                    print('{}: {}'.format(train_metrics[i].name,
                                          train_metrics[i].result().numpy()))

            if step % EVAL_INTERVAL == 0:
                avg_return = compute_avg_return(eval_tf_env, agent.policy,
                                                NUM_EVAL_EPISODES)
                print(f'Step = {step}, Average Return = {avg_return}')
                returns.append((step, avg_return))

            if step % save_each == 0:
                print("Saving model")
                train_checkpointer.save(train_step)
                policy_save_handler.save("policy")
                with open("checkpoint/train_loss.pickle", "wb") as f:
                    pickle.dump(all_train_loss, f)
                with open("checkpoint/all_metrics.pickle", "wb") as f:
                    pickle.dump(all_metrics, f)
                with open("checkpoint/returns.pickle", "wb") as f:
                    pickle.dump(returns, f)

    eval_tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

    #tf_env = tf_py_environment.TFPyEnvironment(
    #   parallel_py_environment.ParallelPyEnvironment(
    #       [BombermanEnvironment] * N_PARALLEL_ENVIRONMENTS
    #   ))

    tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

    q_net = QNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     conv_layer_params=[(32, 3, 1), (32, 3, 1)],
                     fc_layer_params=[128, 64, 32])

    train_step = tf.Variable(0)
    update_period = 4
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # todo fine tune

    epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=0.7,
        decay_steps=25000 // update_period,
        end_learning_rate=0.01)

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=0.99,
        train_step_counter=train_step,
        epsilon_greedy=lambda: epsilon_fn(train_step))
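    # Passing a callable for `epsilon_greedy` means the schedule is
    # re-evaluated at every step, so exploration decays as `train_step` grows.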

    agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=10000)
    replay_buffer_observer = replay_buffer.add_batch

    train_metrics = [
        tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=tf_env.batch_size)
    ]

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer_observer] + train_metrics,
        num_steps=update_period)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[
            replay_buffer.add_batch,
            ShowProgress(INITIAL_COLLECT_STEPS)
        ],
        num_steps=INITIAL_COLLECT_STEPS)
    final_time_step, final_policy_state = initial_driver.run()

    dataset = replay_buffer.as_dataset(sample_batch_size=64,
                                       num_steps=2,
                                       num_parallel_calls=3).prefetch(3)

    agent.train = common.function(agent.train)

    all_train_loss = []
    all_metrics = []
    returns = []

    checkpoint_dir = "checkpoint/"
    train_checkpointer = common.Checkpointer(ckpt_dir=checkpoint_dir,
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             replay_buffer=replay_buffer,
                                             global_step=train_step)
    # train_checkpointer.initialize_or_restore()
    # train_step = tf.compat.v1.train.get_global_step()
    policy_save_handler = policy_saver.PolicySaver(agent.policy)

    # training here
    train_agent(2000)

    # save at end in every case

    policy_save_handler.save("policy")
Exemplo n.º 25
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    q_net = q_network.QNetwork(tensor_spec.from_spec(env.observation_spec()),
                               tensor_spec.from_spec(env.action_spec()),
                               fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()
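
    # `collect_step` is assumed to be defined elsewhere in the original
    # module. Below is a minimal sketch of a compatible helper, with the
    # signature inferred from the calls in this function; the local import is
    # only needed if `trajectory` is not already imported at module level.
    from tf_agents.trajectories import trajectory

    def collect_step(environment, time_step, py_policy, buffer):
        action_step = py_policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        # Store the transition and return the new time step so the caller
        # can keep stepping the environment.
        buffer.add_batch(
            trajectory.from_transition(time_step, action_step, next_time_step))
        return next_time_step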

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Copy the Q-network weights to the target Q-network.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Reset timing to avoid counting eval time.
                timed_at_step = global_step_val
                start_time = time.time()
Exemplo n.º 26
0
    def __init__(
            self,
            root_dir,
            env_name,
            num_iterations=200,
            max_episode_frames=108000,  # ALE frames
            terminal_on_life_loss=False,
            conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3),
                                                                  1)),
            fc_layer_params=(512, ),
            # Params for collect
            initial_collect_steps=80000,  # ALE frames
            epsilon_greedy=0.01,
            epsilon_decay_period=1000000,  # ALE frames
            replay_buffer_capacity=1000000,
            # Params for train
            train_steps_per_iteration=1000000,  # ALE frames
            update_period=16,  # ALE frames
            target_update_tau=1.0,
            target_update_period=32000,  # ALE frames
            batch_size=32,
            learning_rate=2.5e-4,
            n_step_update=1,
            gamma=0.99,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for eval
            do_eval=True,
            eval_steps_per_iteration=500000,  # ALE frames
            eval_epsilon_greedy=0.001,
            # Params for checkpoints, summaries, and logging
            log_interval=1000,
            summary_interval=1000,
            summaries_flush_secs=10,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            eval_metrics_callback=None):
        """A simple Atari train and eval for DQN.

    Args:
      root_dir: Directory to write log files to.
      env_name: Fully-qualified name of the Atari environment (e.g., Pong-v0).
      num_iterations: Number of train/eval iterations to run.
      max_episode_frames: Maximum length of a single episode, in ALE frames.
      terminal_on_life_loss: Whether to simulate an episode termination when a
        life is lost.
      conv_layer_params: Params for convolutional layers of QNetwork.
      fc_layer_params: Params for fully connected layers of QNetwork.
      initial_collect_steps: Number of ALE frames to process before beginning
        to train. Since this is in ALE frames, there will be
        initial_collect_steps/4 items in the replay buffer when training starts.
      epsilon_greedy: Final epsilon value to decay to for training.
      epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
        epsilon_greedy (defined above).
      replay_buffer_capacity: Maximum number of items to store in the replay
        buffer.
      train_steps_per_iteration: Number of ALE frames to run through for each
        iteration of training.
      update_period: Run a train operation every update_period ALE frames.
      target_update_tau: Coefficient for soft target network updates (1.0 ==
        hard updates).
      target_update_period: Period, in ALE frames, to copy the live network to
        the target network.
      batch_size: Number of frames to include in each training batch.
      learning_rate: RMS optimizer learning rate.
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Applies standard single-step updates when set to 1.
      gamma: Discount for future rewards.
      reward_scale_factor: Scaling factor for rewards.
      gradient_clipping: Norm length to clip gradients.
      do_eval: If True, run an eval every iteration. If False, skip eval.
      eval_steps_per_iteration: Number of ALE frames to run through for each
        iteration of evaluation.
      eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
        totally greedy policy).
      log_interval: Log stats to the terminal every log_interval training
        steps.
      summary_interval: Write TF summaries every summary_interval training
        steps.
      summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
        seconds.
      debug_summaries: If True, write additional summaries for debugging (see
        dqn_agent for which summaries are written).
      summarize_grads_and_vars: Include gradient and network variable summaries.
      eval_metrics_callback: A callback function that takes (metric_dict,
        global_step) as parameters. Called after every eval with the results of
        the evaluation.
    """
        self._update_period = update_period / ATARI_FRAME_SKIP
        self._train_steps_per_iteration = (train_steps_per_iteration /
                                           ATARI_FRAME_SKIP)
        self._do_eval = do_eval
        self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
        self._eval_epsilon_greedy = eval_epsilon_greedy
        self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
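        # For example, assuming the usual ATARI_FRAME_SKIP of 4 (consistent
        # with the docstring above), update_period=16 ALE frames means one
        # train op every 4 environment steps, and initial_collect_steps=80000
        # ALE frames becomes 20000 environment steps.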
        self._summary_interval = summary_interval
        self._num_iterations = num_iterations
        self._log_interval = log_interval
        self._eval_metrics_callback = eval_metrics_callback

        with gin.unlock_config():
            gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss',
                               terminal_on_life_loss)

        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')

        train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        self._train_summary_writer = train_summary_writer

        self._eval_summary_writer = None
        if self._do_eval:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]

        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
                self._global_step % self._summary_interval, 0)):
            self._env = suite_atari.load(
                env_name,
                max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
                gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
            self._env = batched_py_environment.BatchedPyEnvironment(
                [self._env])

            observation_spec = tensor_spec.from_spec(
                self._env.observation_spec())
            time_step_spec = ts.time_step_spec(observation_spec)
            action_spec = tensor_spec.from_spec(self._env.action_spec())

            with tf.device('/cpu:0'):
                epsilon = tf.compat.v1.train.polynomial_decay(
                    1.0,
                    self._global_step,
                    epsilon_decay_period / ATARI_FRAME_SKIP /
                    self._update_period,
                    end_learning_rate=epsilon_greedy)

            with tf.device('/gpu:0'):
                optimizer = tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=learning_rate,
                    decay=0.95,
                    momentum=0.0,
                    epsilon=0.00001,
                    centered=True)
                q_net = AtariQNetwork(observation_spec,
                                      action_spec,
                                      conv_layer_params=conv_layer_params,
                                      fc_layer_params=fc_layer_params)
                agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_net,
                    optimizer=optimizer,
                    epsilon_greedy=epsilon,
                    n_step_update=n_step_update,
                    target_update_tau=target_update_tau,
                    target_update_period=(target_update_period /
                                          ATARI_FRAME_SKIP /
                                          self._update_period),
                    td_errors_loss_fn=common.element_wise_huber_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=self._global_step)

                self._collect_policy = py_tf_policy.PyTFPolicy(
                    agent.collect_policy)

                if self._do_eval:
                    self._eval_policy = py_tf_policy.PyTFPolicy(
                        epsilon_greedy_policy.EpsilonGreedyPolicy(
                            policy=agent.policy,
                            epsilon=self._eval_epsilon_greedy))

                py_observation_spec = self._env.observation_spec()
                py_time_step_spec = ts.time_step_spec(py_observation_spec)
                py_action_spec = policy_step.PolicyStep(
                    self._env.action_spec())
                data_spec = trajectory.from_transition(py_time_step_spec,
                                                       py_action_spec,
                                                       py_time_step_spec)
                self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
                    data_spec=data_spec, capacity=replay_buffer_capacity)

            with tf.device('/cpu:0'):
                ds = self._replay_buffer.as_dataset(
                    sample_batch_size=batch_size, num_steps=n_step_update + 1)
                ds = ds.prefetch(4)
                ds = ds.apply(
                    tf.data.experimental.prefetch_to_device('/gpu:0'))

            with tf.device('/gpu:0'):
                self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
                experience = self._ds_itr.get_next()
                self._train_op = agent.train(experience)

                self._env_steps_metric = py_metrics.EnvironmentSteps()
                self._step_metrics = [
                    py_metrics.NumberOfEpisodes(),
                    self._env_steps_metric,
                ]
                self._train_metrics = self._step_metrics + [
                    py_metrics.AverageReturnMetric(buffer_size=10),
                    py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
                ]
                # The _train_phase_metrics average over an entire train iteration,
                # rather than the rolling average of the last 10 episodes.
                self._train_phase_metrics = [
                    py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                                   buffer_size=np.inf),
                    py_metrics.AverageEpisodeLengthMetric(
                        name='PhaseAverageEpisodeLength', buffer_size=np.inf),
                ]
                self._iteration_metric = py_metrics.CounterMetric(
                    name='Iteration')

                # Summaries written from python should run every time they are
                # generated.
                with tf.compat.v2.summary.record_if(True):
                    self._steps_per_second_ph = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name='steps_per_sec_ph')
                    self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                        name='global_steps_per_sec',
                        data=self._steps_per_second_ph,
                        step=self._global_step)

                    for metric in self._train_metrics:
                        metric.tf_summaries(train_step=self._global_step,
                                            step_metrics=self._step_metrics)

                    for metric in self._train_phase_metrics:
                        metric.tf_summaries(
                            train_step=self._global_step,
                            step_metrics=(self._iteration_metric, ))
                    self._iteration_metric.tf_summaries(
                        train_step=self._global_step)

                    if self._do_eval:
                        with self._eval_summary_writer.as_default():
                            for metric in self._eval_metrics:
                                metric.tf_summaries(
                                    train_step=self._global_step,
                                    step_metrics=(self._iteration_metric, ))

                self._train_dir = train_dir
                self._policy_exporter = policy_saver.PolicySaver(
                    agent.policy, train_step=self._global_step)
                self._train_checkpointer = common.Checkpointer(
                    ckpt_dir=train_dir,
                    agent=agent,
                    global_step=self._global_step,
                    optimizer=optimizer,
                    metrics=metric_utils.MetricsGroup(
                        self._train_metrics + self._train_phase_metrics +
                        [self._iteration_metric], 'train_metrics'))
                self._policy_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'policy'),
                    policy=agent.policy,
                    global_step=self._global_step)
                self._rb_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                    max_to_keep=1,
                    replay_buffer=self._replay_buffer)

                self._init_agent_op = agent.initialize()
Exemplo n.º 27
0
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

#optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate)

#train_step_counter = tf.Variable(0)
global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
#agent = dqn_agent.DdqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    #train_step_counter=train_step_counter,
    train_step_counter=global_step,
    epsilon_greedy=epsilon_greedy)

agent.initialize()

eval_policy = agent.policy
collect_policy = agent.collect_policy

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
Exemplo n.º 28
0
_eval_env = tf_py_environment.TFPyEnvironment(_eval_py_env)

print('Building Network...')
_fc_layer_params = (512, )

_q_net = q_network.QNetwork(_train_env.observation_spec(),
                            _train_env.action_spec(),
                            fc_layer_params=_fc_layer_params)

_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=_learning_rate)

_train_step_counter = tf.Variable(0)

_agent = dqn_agent.DqnAgent(_train_env.time_step_spec(),
                            _train_env.action_spec(),
                            q_network=_q_net,
                            optimizer=_optimizer,
                            td_errors_loss_fn=common.element_wise_squared_loss,
                            train_step_counter=_train_step_counter)

_agent.initialize()

_eval_policy = _agent.policy
_collect_policy = _agent.collect_policy

_random_policy = random_tf_policy.RandomTFPolicy(_train_env.time_step_spec(),
                                                 _train_env.action_spec())


def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in tqdm(range(num_episodes)):
Exemplo n.º 29
0
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.compat.v2.Variable(0)

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter)

tf_agent.initialize()

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

print("Batch Size: {}".format(train_env.batch_size))
Exemplo n.º 30
0
batch_size = 64
learning_rate = 1e-3
num_eval_episodes = 5
env_name = 'CartPole-v0'
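
# Not shown in this snippet; representative values, assumed so the code below
# runs as written.
initial_collect_steps = 1000
replay_buffer_max_length = 100000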

train_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_py_env = suite_gym.load(env_name)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=(100, ))
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
    td_errors_loss_fn=common.element_wise_squared_loss)
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
observers = [replay_buffer.add_batch]

dynamic_step_driver.DynamicStepDriver(train_env,
                                      random_policy,
                                      observers=observers,
                                      num_steps=initial_collect_steps).run()