def testSampleDoesNotCrossHead(self):
  np.random.seed(12345)
  data_spec = array_spec.ArraySpec((), np.int32)
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      data_spec=data_spec, capacity=10)

  # Seed RB with 5 elements to move head to position 5.
  for _ in range(5):
    replay_buffer.add_batch(np.array([0]))

  # Fill RB with elements 0-9.
  for i in range(10):
    replay_buffer.add_batch(np.array([i]))

  # Sample with num_steps = 2. We should never sample (9, 0) since this is an
  # invalid transition. With 10000 samples, the probability of sampling (9, 0)
  # if it were not protected against would be (1 - (9/10)^10000) ~= 1.
  sample_frequency = [0 for _ in range(10)]
  for _ in range(10000):
    (first, second) = replay_buffer.get_next(num_steps=2, time_stacked=False)
    self.assertNotEqual(np.array(9), first)
    self.assertNotEqual(np.array(0), second)
    sample_frequency[first] += 1

  # 0-8 should all have been sampled about 10000/9 ~= 1111 times. We allow a
  # delta of 150 off of 1111 -- the chance each sample frequency is within this
  # range is 99.9998% (computed using the pmf of the binomial distribution).
  # And since we fix the random seed, this test is repeatable.
  for i in range(9):
    self.assertAlmostEqual(10000 / 9, sample_frequency[i], delta=150)
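# Companion sketch (not part of the original test): the test above passes
# time_stacked=False to get the two sampled steps as separate arrays. With the
# default time_stacked=True, get_next stacks the num_steps window along a
# leading time axis instead.
def sample_stacked_pair(replay_buffer):
  # Assumes a scalar int32 data_spec as in the test above.
  pair = replay_buffer.get_next(num_steps=2)  # time_stacked defaults to True
  first, second = pair[0], pair[1]            # pair has shape (2,)
  return first, second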
def __init__(self, playerIndex, debug=False, create_model=True):
    """Initialize an agent."""
    super().__init__(playerIndex, debug=debug)
    self.trainable = True

    # Whether to use small numbers for debugging reasons
    self.use_small_numbers = debug

    # Hyperparameters
    self.alpha = 0.01  # learning rate
    self.gamma = 0.95  # favour future rewards
    self.exploration_decay_rate = 1 / 2000
    self.reward_win_round = 0.005
    self.reward_per_card_played = 0.001
    self.rewards = {
        0: 1.0,    # No other agent finished before
        1: 0.05,   # One other agent finished before
        2: 0.04,   # Two other agents finished before
        3: -1.0,   # Three other agents finished before
    }

    # Training/Batch parameters
    self.sample_batch = 64 if self.use_small_numbers else 512
    self.replay_capacity = 128 if self.use_small_numbers else 1024
    self.train_each_n_steps = 5 if self.use_small_numbers else 50
    self.step_iteration = 0
    self.model_data_spec = (  # TODO adjust to new model
        tf.TensorSpec([4 * 13], tf.int8, "board_state"),
        tf.TensorSpec([1], tf.float32, "q_value"),
    )
    self.replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=self.replay_capacity,
        data_spec=tensor_spec.to_nest_array_spec(self.model_data_spec))

    # Validation parameters
    self.val_replay_capacity = 20 if self.use_small_numbers else 200
    self.validation_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=self.val_replay_capacity,
        data_spec=tensor_spec.to_nest_array_spec(self.model_data_spec))

    # Initialize model
    if create_model:
        self._create_model()
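# Hypothetical usage sketch (store_experience, board_state and q_value are
# illustrative names, not part of the original class): PyUniformReplayBuffer's
# add_batch expects a nest matching the data_spec with a leading batch
# dimension of 1, mirroring the pattern used in the test above.
import numpy as np

def store_experience(agent, board_state, q_value):
    # board_state: int8 array of shape (4 * 13,); q_value: a single float.
    batched_state = np.asarray(board_state, dtype=np.int8)[np.newaxis, ...]
    batched_q = np.asarray([q_value], dtype=np.float32)[np.newaxis, ...]
    agent.replay_buffer.add_batch((batched_state, batched_q))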
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    log_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # Note this is a python environment.
  env = batched_py_environment.BatchedPyEnvironment(
      [suite_gym.load(env_name)])
  eval_py_env = suite_gym.load(env_name)

  # Convert specs to BoundedTensorSpec.
  action_spec = tensor_spec.from_spec(env.action_spec())
  observation_spec = tensor_spec.from_spec(env.observation_spec())
  time_step_spec = ts.time_step_spec(observation_spec)

  q_net = q_network.QNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=fc_layer_params)

  # The agent must be in graph.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = dqn_agent.DqnAgent(
      time_step_spec,
      action_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
  greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
  random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                  env.action_spec())

  # Python replay buffer.
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      capacity=replay_buffer_capacity,
      data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

  time_step = env.reset()

  # Initialize the replay buffer with some transitions. We use the random
  # policy to initialize the replay buffer to make sure we get a good
  # distribution of actions.
  for _ in range(initial_collect_steps):
    time_step = collect_step(env, time_step, random_policy, replay_buffer)

  # TODO(b/112041045) Use global_step as counter.
  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir, agent=agent, global_step=global_step)
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_step)

  ds = replay_buffer.as_dataset(
      sample_batch_size=batch_size, num_steps=n_step_update + 1)
  ds = ds.prefetch(4)
  itr = tf.compat.v1.data.make_initializable_iterator(ds)
  experience = itr.get_next()

  train_op = common.function(agent.train)(experience)

  with eval_summary_writer.as_default(), \
       tf.compat.v2.summary.record_if(True):
    for eval_metric in eval_metrics:
      eval_metric.tf_summaries(train_step=global_step)

  with tf.compat.v1.Session() as session:
    train_checkpointer.initialize_or_restore(session)
    common.initialize_uninitialized_variables(session)
    session.run(itr.initializer)
    # Copy the Q-network weights to the target Q-network.
    session.run(agent.initialize())
    train = session.make_callable(train_op)
    global_step_call = session.make_callable(global_step)
    session.run(train_summary_writer.init())
    session.run(eval_summary_writer.init())

    # Compute initial evaluation metrics.
    global_step_val = global_step_call()
    metric_utils.compute_summaries(
        eval_metrics,
        eval_py_env,
        greedy_policy,
        num_episodes=num_eval_episodes,
        global_step=global_step_val,
        log=True,
        callback=eval_metrics_callback,
    )

    timed_at_step = global_step_val
    collect_time = 0
    train_time = 0
    steps_per_second_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(), name='steps_per_sec_ph')
    steps_per_second_summary = tf.compat.v2.summary.scalar(
        name='global_steps_per_sec',
        data=steps_per_second_ph,
        step=global_step)

    for _ in range(num_iterations):
      start_time = time.time()
      for _ in range(collect_steps_per_iteration):
        time_step = collect_step(env, time_step, collect_policy,
                                 replay_buffer)
      collect_time += time.time() - start_time

      start_time = time.time()
      for _ in range(train_steps_per_iteration):
        loss = train()
      train_time += time.time() - start_time

      global_step_val = global_step_call()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, loss.loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        session.run(
            steps_per_second_summary,
            feed_dict={steps_per_second_ph: steps_per_sec})
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info(
            '%s', 'collect_time = {}, train_time = {}'.format(
                collect_time, train_time))
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % eval_interval == 0:
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )
        # Reset timing to avoid counting eval time.
        timed_at_step = global_step_val
        start_time = time.time()
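# `collect_step` is referenced in train_eval above but not defined in this
# snippet. A minimal sketch, assuming the standard TF-Agents pattern of
# stepping the python environment with a python policy and writing the
# resulting trajectory into the python replay buffer:
from tf_agents.trajectories import trajectory

def collect_step(env, time_step, policy, replay_buffer):
  action_step = policy.action(time_step)
  next_time_step = env.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  replay_buffer.add_batch(traj)
  return next_time_step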