def bigtable_collect(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    # Params for QNetwork
    fc_layer_params=(100,),
    # Params for QRnnNetwork
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    # Params for collect
    num_episodes=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """Run episode-driven experience collection with a DQN agent.

  Builds a `QNetwork`-backed `DqnAgent` on a gym environment, wires the
  collect policy into a `DynamicEpisodeDriver` whose observers push
  trajectories into a Bigtable-backed replay buffer, then loops
  `num_iterations` times calling `collect_driver.run()`.

  NOTE(review): several names used here are module-level globals defined
  elsewhere in this file (`args`, `SERVICE_ACCOUNT_FILE`, `SCOPES`,
  `VISUAL_OBS_SPEC`, `gcp_load_pipeline`, `BigtableReplayBuffer`) — this
  function will fail at call time if they are not in scope.

  Args:
    root_dir: Root directory for train/eval artifacts; `~` is expanded.
    env_name: Gym environment id to load for both train and eval envs.
    num_iterations: Number of collect-driver runs to perform.
    fc_layer_params: Fully-connected layer sizes for the QNetwork.
    input_fc_layer_params: Unused here; presumably for a QRnnNetwork
      variant — TODO confirm against the rest of the file.
    lstm_size: Unused here (QRnnNetwork param).
    output_fc_layer_params: Unused here (QRnnNetwork param).
    num_episodes: Episodes per driver run.
    epsilon_greedy: Epsilon for the agent's epsilon-greedy exploration.
    replay_buffer_capacity: Max size of the Bigtable replay buffer.
    target_update_tau: Soft target-network update coefficient.
    target_update_period: Steps between target-network updates.
    train_steps_per_iteration: Unused in the visible body (collect only).
    batch_size: Unused in the visible body.
    learning_rate: Adam learning rate for the agent's optimizer.
    n_step_update: N-step return length; also used as the train
      sequence length.
    gamma: Discount factor.
    reward_scale_factor: Multiplier applied to environment rewards.
    gradient_clipping: Optional gradient-clipping norm.
    use_tf_functions: If True, wrap driver.run and agent.train in
      `common.function` for speed.
    num_eval_episodes: Unused in the visible body.
    eval_interval: Unused in the visible body.
    train_checkpoint_interval: Unused (checkpointing is commented out).
    policy_checkpoint_interval: Unused (checkpointing is commented out).
    rb_checkpoint_interval: Unused (checkpointing is commented out).
    log_interval: Unused in the visible body.
    summary_interval: Unused in the visible body.
    summaries_flush_secs: Unused in the visible body.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    eval_metrics_callback: Unused in the visible body.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  global_step = tf.compat.v1.train.get_or_create_global_step()
  tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
  eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

  q_net = q_network.QNetwork(
      tf_env.observation_spec(),
      tf_env.action_spec(),
      fc_layer_params=fc_layer_params)
  train_sequence_length = n_step_update

  # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
  tf_agent = dqn_agent.DqnAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)
  tf_agent.initialize()

  train_metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.EnvironmentSteps(),
      tf_metrics.AverageReturnMetric(),
      tf_metrics.AverageEpisodeLengthMetric(),
  ]

  eval_policy = tf_agent.policy
  collect_policy = tf_agent.collect_policy

  # INSTANTIATE CBT TABLE AND GCS BUCKET
  credentials = service_account.Credentials.from_service_account_file(
      SERVICE_ACCOUNT_FILE, scopes=SCOPES)
  cbt_table, gcs_bucket = gcp_load_pipeline(
      args.gcp_project_id, args.cbt_instance_id, args.cbt_table_name,
      args.bucket_id, credentials)
  # 4 bytes per float in the visual observation, plus per-row overhead.
  # NOTE(review): the 64-byte overhead figure is assumed — confirm against
  # the row schema actually written by the replay buffer.
  max_row_bytes = (4 * np.prod(VISUAL_OBS_SPEC) + 64)
  cbt_batcher = cbt_table.mutations_batcher(
      flush_count=args.num_episodes, max_row_bytes=max_row_bytes)

  bigtable_replay_buffer = BigtableReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      max_size=replay_buffer_capacity)

  # BUG FIX: the original called dynamic_episode_driver.DynamicStepDriver
  # with num_episodes=..., but DynamicStepDriver lives in
  # dynamic_step_driver and takes num_steps. The class that matches both
  # this module and this keyword is DynamicEpisodeDriver.
  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      tf_env,
      collect_policy,
      observers=[bigtable_replay_buffer.add_batch] + train_metrics,
      num_episodes=num_episodes)

  # train_checkpointer = common.Checkpointer(
  #     ckpt_dir=train_dir,
  #     agent=tf_agent,
  #     global_step=global_step,
  #     metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
  # policy_checkpointer = common.Checkpointer(
  #     ckpt_dir=os.path.join(train_dir, 'policy'),
  #     policy=eval_policy,
  #     global_step=global_step)
  # rb_checkpointer = common.Checkpointer(
  #     ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
  #     max_to_keep=1,
  #     replay_buffer=bigtable_replay_buffer)
  # train_checkpointer.initialize_or_restore()
  # rb_checkpointer.initialize_or_restore()

  if use_tf_functions:
    # To speed up collect use common.function.
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

  time_step = None
  policy_state = collect_policy.get_initial_state(tf_env.batch_size)

  timed_at_step = global_step.numpy()
  time_acc = 0

  for _ in range(num_iterations):
    collect_driver.run()
# Remaining CLI flags (parser is created earlier in the file).
parser.add_argument('--bucket-id', type=str, default='rab-rl-bucket')
parser.add_argument('--prefix', type=str, default='breakout')
parser.add_argument(
    '--tmp-weights-filepath', type=str, default='/tmp/model_weights_tmp.h5')
parser.add_argument('--num-cycles', type=int, default=1000000)
parser.add_argument('--num-episodes', type=int, default=3)
parser.add_argument('--max-steps', type=int, default=1000)
parser.add_argument('--log-time', default=False, action='store_true')
args = parser.parse_args()

# INSTANTIATE CBT TABLE AND GCS BUCKET
credentials = service_account.Credentials.from_service_account_file(
    SERVICE_ACCOUNT_FILE, scopes=SCOPES)
cbt_table, gcs_bucket = gcp_load_pipeline(
    args.gcp_project_id, args.cbt_instance_id, args.cbt_table_name,
    args.bucket_id, credentials)
cbt_table_visual = cbt_load_table(
    args.gcp_project_id, args.cbt_instance_id,
    args.cbt_table_name + 'visual', credentials)
cbt_batcher = cbt_table.mutations_batcher(
    flush_count=args.num_episodes, max_row_bytes=10080100)
# cbt_batcher_visual = cbt_table_visual.mutations_batcher(
#     flush_count=args.num_episodes, max_row_bytes=10080100)

# INITIALIZE ENVIRONMENT
print("-> Initializing Gym environement...")

# Custom DQN
# env = gym.make('Breakout-v0')
env_name = 'Breakout-v0'