def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  FLAGS.set_default('max_timesteps',
                    FLAGS.max_timesteps // FLAGS.action_repeat)

  if 'pixels-dm' in FLAGS.env_name:
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    print('Loading env')
    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack)
    eval_env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                 FLAGS.action_repeat, FLAGS.frame_stack)
    print('Env loaded')
  else:
    raise Exception('Unsupported env')

  is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and
                  len(env.observation_spec().shape) == 3)

  spec = (
      env.observation_spec(),
      env.action_spec(),
      env.reward_spec(),
      env.reward_spec(),  # discount spec
      env.observation_spec()  # next observation spec
  )

  print('Init replay')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer)
  print('Replay created')

  @tf.function
  def add_to_replay(state, action, reward, discount, next_states):
    replay_buffer.add_batch((state, action, reward, discount, next_states))

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      seed=FLAGS.seed,
      env_name=FLAGS.env_name,
      algo_name=FLAGS.algo_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  results_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  print('Init actor')
  if 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(
        env.observation_spec(),
        env.action_spec(),
        cross_norm='crossnorm' in FLAGS.algo_name)
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(), env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(), env.action_spec(), f='exp_mean')
  elif 'sac_v1' in FLAGS.algo_name:
    model = sac_v1.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        cross_norm='crossnorm' in FLAGS.algo_name,
        pcl_actor_update='pc' in FLAGS.algo_name)
  elif 'pcl' in FLAGS.algo_name:
    model = pcl.PCL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])

  print('Init random policy for warmup')
  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      env.time_step_spec(), env.action_spec())

  print('Init replay buffer')
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=tf.data.AUTOTUNE,
      sample_batch_size=FLAGS.sample_batch_size)
  if is_image_obs:
    dataset = dataset.map(
        image_aug,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False).prefetch(3)
  else:
    dataset = dataset.prefetch(3)

  def repack(*data):
    return data[0]

  dataset = dataset.map(repack)
  replay_buffer_iter = iter(dataset)

  previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0
  step_mult = 1 if FLAGS.action_repeat < 1 else FLAGS.action_repeat

  print('Starting training')
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    if i % FLAGS.deployment_batch_size == 0:
      for _ in range(FLAGS.deployment_batch_size):
        last_timestep = timestep.is_last()
        if last_timestep:
          if episode_timesteps > 0:
            current_time = time.time()
            with summary_writer.as_default():
              tf.summary.scalar(
                  'train/returns', episode_return, step=(i + 1) * step_mult)
              tf.summary.scalar(
                  'train/FPS',
                  episode_timesteps / (current_time - previous_time),
                  step=(i + 1) * step_mult)

          timestep = env.reset()
          episode_return = 0
          episode_timesteps = 0
          previous_time = time.time()

        if (replay_buffer.num_frames() < FLAGS.num_random_actions or
            replay_buffer.num_frames() < FLAGS.deployment_batch_size):
          # Use policy only after the first deployment.
          policy_step = initial_collect_policy.action(timestep)
          action = policy_step.action
        else:
          action = model.actor(timestep.observation, sample=True)

        next_timestep = env.step(action)

        add_to_replay(timestep.observation, action, next_timestep.reward,
                      next_timestep.discount, next_timestep.observation)

        episode_return += next_timestep.reward[0]
        episode_timesteps += 1

        timestep = next_timestep

    if i + 1 >= FLAGS.start_training_timesteps:
      with summary_writer.as_default():
        info_dict = model.update_step(replay_buffer_iter)

      if (i + 1) % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v, step=(i + 1) * step_mult)

    if (i + 1) % FLAGS.eval_interval == 0:
      logging.info('Performing policy eval.')
      average_returns, evaluation_timesteps = evaluation.evaluate(
          eval_env, model)

      with results_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns', average_returns, step=(i + 1) * step_mult)
        tf.summary.scalar(
            'evaluation/length', evaluation_timesteps,
            step=(i + 1) * step_mult)
      logging.info('Eval at %d: ave returns=%f, ave episode length=%f',
                   (i + 1) * step_mult, average_returns, evaluation_timesteps)

    if ((i + 1) * step_mult) % 50_000 == 0:
      model.save_weights(
          os.path.join(
              FLAGS.save_dir, 'results',
              FLAGS.env_name + '__' + str((i + 1) * step_mult)))
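# `image_aug`, referenced in the dataset pipeline above, is not defined in this
# excerpt. The sketch below is an assumption, not the repository's
# implementation: a minimal random-shift augmentation (pad and re-crop, in the
# spirit of DrQ/RAD) applied to the image observations of each sampled batch.
# It assumes the replay dataset yields `(experience, sample_info)` pairs where
# `experience = (state, action, reward, discount, next_state)` with image
# tensors of shape [batch, height, width, channels], matching the `spec` and
# `repack` above, and it reuses the module's existing `import tensorflow as tf`.
def image_aug_sketch(experience, sample_info, pad=4):
  state, action, reward, discount, next_state = experience

  def random_shift(images):
    # Pad the spatial dimensions, then crop back to the original size at a
    # random offset shared across the batch.
    height, width = images.shape[1], images.shape[2]
    padded = tf.pad(images, [[0, 0], [pad, pad], [pad, pad], [0, 0]],
                    mode='SYMMETRIC')
    dy = tf.random.uniform([], 0, 2 * pad + 1, dtype=tf.int32)
    dx = tf.random.uniform([], 0, 2 * pad + 1, dtype=tf.int32)
    return padded[:, dy:dy + height, dx:dx + width, :]

  return ((random_shift(state), action, reward, discount,
           random_shift(next_state)), sample_info)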
def load_model(checkpoint):
  """Loads a pretrained policy checkpoint saved after `checkpoint` env steps.

  Relies on the module-level `env` (and, for procgen, `env_name`) set up by
  the surrounding script.
  """
  checkpoint = int(checkpoint)
  print(checkpoint)
  if FLAGS.env_name.startswith('procgen'):
    env_id = [i for i, name in enumerate(PROCGEN_ENVS)
              if name == env_name][0] + 1
    if checkpoint == 10_000_000:
      ckpt_iter = '0000020480'
    elif checkpoint == 15_000_000:
      ckpt_iter = '0000030720'
    elif checkpoint == 20_000_000:
      ckpt_iter = '0000040960'
    elif checkpoint == 25_000_000:
      ckpt_iter = '0000051200'
    policy_weights_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/checkpoints/'
                          'policy_checkpoint_%s/' % (env_id, ckpt_iter))
    policy_def_dir = ('ppo_darts/'
                      '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
    model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        policy_def_dir,
        time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
        action_spec=env._action_spec,  # pylint: disable=protected-access
        policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
        info_spec=tf.TensorSpec(shape=(None,)),
        load_specs_from_pbtxt=False)
    model.update_from_checkpoint(policy_weights_dir)
    model.actor = model.action
  else:
    if 'ddpg' in FLAGS.algo_name:
      model = ddpg.DDPG(
          env.observation_spec(),
          env.action_spec(),
          cross_norm='crossnorm' in FLAGS.algo_name)
    elif 'crr' in FLAGS.algo_name:
      model = awr.AWR(
          env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
      model = awr.AWR(
          env.observation_spec(), env.action_spec(), f='exp_mean')
    elif 'sac_v1' in FLAGS.algo_name:
      model = sac_v1.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'asac' in FLAGS.algo_name:
      model = asac.ASAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
      model = sac.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          cross_norm='crossnorm' in FLAGS.algo_name,
          pcl_actor_update='pc' in FLAGS.algo_name)
    elif 'pcl' in FLAGS.algo_name:
      model = pcl.PCL(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])

    if 'distractor' in FLAGS.env_name:
      ckpt_path = os.path.join(
          ('experiments/'
           '20210622_2023.policy_weights_sac_1M_dmc_distractor_hard_pixel/'),
          'results', FLAGS.env_name + '__' + str(checkpoint))
    else:
      ckpt_path = os.path.join(
          ('experiments/'
           '20210607_2023.policy_weights_dmc_1M_SAC_pixel'),
          'results', FLAGS.env_name + '__' + str(checkpoint))

    model.load_weights(ckpt_path)
  print('Loaded model weights')
  return model
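# Hypothetical usage of `load_model` (the helper name and checkpoint value
# below are illustrative, not from the original code). For the DMC branch the
# returned model exposes `actor(observation, sample=...)`, matching the online
# training loop above; in the procgen branch `model.actor` is aliased to the
# TF-Agents `action()` method and expects a full TimeStep instead.
def _example_query_behavior_policy():
  behavior_policy = load_model(1_000_000)  # e.g. the 1M-step DMC checkpoint
  timestep = env.reset()
  return behavior_policy.actor(timestep.observation, sample=True)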
def main(_):
  if FLAGS.eager:
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  # np.random.seed(FLAGS.seed)
  # random.seed(FLAGS.seed)

  if 'procgen' in FLAGS.env_name:
    _, env_name, train_levels, _ = FLAGS.env_name.split('-')
    env = procgen_wrappers.TFAgentsParallelProcGenEnv(
        1,
        normalize_rewards=False,
        env_name=env_name,
        num_levels=int(train_levels),
        start_level=0)
  elif FLAGS.env_name.startswith('pixels-dm'):
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack, FLAGS.obs_type)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.env_name,
      ckpt_timesteps=FLAGS.ckpt_timesteps)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  if FLAGS.env_name.startswith('procgen'):
    # map env string to digit [1,16]
    env_id = [i for i, name in enumerate(PROCGEN_ENVS)
              if name == env_name][0] + 1
    if FLAGS.ckpt_timesteps == 10_000_000:
      ckpt_iter = '0000020480'
    elif FLAGS.ckpt_timesteps == 25_000_000:
      ckpt_iter = '0000051200'
    policy_weights_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/checkpoints/'
                          'policy_checkpoint_%s/' % (env_id, ckpt_iter))
    policy_def_dir = ('ppo_darts/'
                      '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
    model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        policy_def_dir,
        time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
        action_spec=env._action_spec,  # pylint: disable=protected-access
        policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
        info_spec=tf.TensorSpec(shape=(None,)),
        load_specs_from_pbtxt=False)
    model.update_from_checkpoint(policy_weights_dir)
    model = TfAgentsPolicy(model)
  else:
    if 'ddpg' in FLAGS.algo_name:
      model = ddpg.DDPG(
          env.observation_spec(),
          env.action_spec(),
          cross_norm='crossnorm' in FLAGS.algo_name)
    elif 'crr' in FLAGS.algo_name:
      model = awr.AWR(
          env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
      model = awr.AWR(
          env.observation_spec(), env.action_spec(), f='exp_mean')
    elif 'sac_v1' in FLAGS.algo_name:
      model = sac_v1.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'asac' in FLAGS.algo_name:
      model = asac.ASAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
      model = sac.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          cross_norm='crossnorm' in FLAGS.algo_name,
          pcl_actor_update='pc' in FLAGS.algo_name)
    elif 'pcl' in FLAGS.algo_name:
      model = pcl.PCL(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])

    if 'distractor' in FLAGS.env_name:
      ckpt_path = os.path.join(
          ('experiments/20210622_2023.policy_weights_sac'
           '_1M_dmc_distractor_hard_pixel/'),
          'results', FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
    else:
      ckpt_path = os.path.join(
          ('experiments/20210607_2023.'
           'policy_weights_dmc_1M_SAC_pixel'),
          'results', FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

    model.load_weights(ckpt_path)
  print('Loaded model weights')

  with summary_writer.as_default():
    env = procgen_wrappers.TFAgentsParallelProcGenEnv(
        1,
        normalize_rewards=False,
        env_name=env_name,
        num_levels=0,
        start_level=0)
    (avg_returns, avg_len) = evaluation.evaluate(
        env, model, num_episodes=100, return_distributions=False)

    tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
    tf.summary.scalar('evaluation/length-all', avg_len, step=0)
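# `TfAgentsPolicy`, used above to wrap the restored PPO policy, is not defined
# in this excerpt. Below is a minimal sketch, not the repository's
# implementation, assuming the rest of the code only needs the
# `actor(observation, sample=...)` interface exposed by the DMC models (the
# procgen branch of `load_model` aliases `actor` to `action` for the same
# reason). The real wrapper may also thread policy state and info through.
from tf_agents.trajectories import time_step as ts


class TfAgentsPolicySketch(object):
  """Adapts a TF-Agents policy to the `actor(obs, sample=...)` interface."""

  def __init__(self, tf_agents_policy):
    self._policy = tf_agents_policy

  def actor(self, observation, sample=False):
    del sample  # The wrapped policy decides how it samples internally.
    timestep = ts.restart(observation, batch_size=observation.shape[0])
    return self._policy.action(timestep).action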
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)

  tf.random.set_seed(FLAGS.seed)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.task_name,
      data_name=FLAGS.data_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  if FLAGS.algo_name == 'bc':
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec(), env.action_spec())
  elif FLAGS.algo_name == 'bc_mix':
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec(), env.action_spec(), mixture=True)
  elif 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(env.observation_spec(), env.action_spec())
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='exp_mean')
  elif 'bcq' in FLAGS.algo_name:
    model = bcq.BCQ(env.observation_spec(), env.action_spec())
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'cql' in FLAGS.algo_name:
    model = cql.CQL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'brac' in FLAGS.algo_name:
    if 'fbrac' in FLAGS.algo_name:
      model = fisher_brac.FBRAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          f_reg=FLAGS.f_reg,
          reward_bonus=FLAGS.reward_bonus)
    else:
      model = brac.BRAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])

    model_folder = os.path.join(
        FLAGS.save_dir, 'models',
        f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')
    if not tf.io.gfile.isdir(model_folder):
      bc_pretraining_steps = 1_000_000
      for i in tqdm.tqdm(range(bc_pretraining_steps)):
        info_dict = model.bc.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in info_dict.items():
              tf.summary.scalar(
                  f'training/{k}', v, step=i - bc_pretraining_steps)
      # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
    else:
      model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    with summary_writer.as_default():
      info_dict = model.update_step(dataset_iter)

    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      if FLAGS.data_name is None:
        average_returns = gym_env.get_normalized_score(average_returns) * 100.0

      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length', average_length, step=i + 1)
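# The absl entry point is not shown in this excerpt; each script above
# presumably ends with the standard boilerplate below (assuming the usual
# `from absl import app` import at the top of the file). The example command is
# hypothetical: the script name and flag values are illustrative, though the
# flag names all appear in the code above.
#
#   python train_eval_offline.py --algo_name=fbrac \
#     --task_name=halfcheetah-medium-v0 --num_updates=1000000 \
#     --eval_interval=10000 --save_dir=/tmp/offline_rl
if __name__ == '__main__':
  app.run(main)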