def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  FLAGS.set_default('max_timesteps', FLAGS.max_timesteps // FLAGS.action_repeat)

  if 'pixels-dm' in FLAGS.env_name:
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    print('Loading env')
    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack)
    eval_env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                 FLAGS.action_repeat, FLAGS.frame_stack)
    print('Env loaded')
  else:
    raise Exception('Unsupported env')

  is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and
                  len(env.observation_spec().shape) == 3)

  spec = (
      env.observation_spec(),
      env.action_spec(),
      env.reward_spec(),
      env.reward_spec(),  # discount spec
      env.observation_spec()  # next observation spec
  )

  print('Init replay')
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer)
  print('Replay created')

  @tf.function
  def add_to_replay(state, action, reward, discount, next_states):
    replay_buffer.add_batch((state, action, reward, discount, next_states))

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      seed=FLAGS.seed,
      env_name=FLAGS.env_name,
      algo_name=FLAGS.algo_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  results_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  print('Init actor')
  if 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(
        env.observation_spec(),
        env.action_spec(),
        cross_norm='crossnorm' in FLAGS.algo_name)
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='exp_mean')
  elif 'sac_v1' in FLAGS.algo_name:
    model = sac_v1.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        cross_norm='crossnorm' in FLAGS.algo_name,
        pcl_actor_update='pc' in FLAGS.algo_name)
  elif 'pcl' in FLAGS.algo_name:
    model = pcl.PCL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])

  print('Init random policy for warmup')
  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      env.time_step_spec(), env.action_spec())

  print('Init replay buffer')
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=tf.data.AUTOTUNE,
      sample_batch_size=FLAGS.sample_batch_size)
  if is_image_obs:
    dataset = dataset.map(
        image_aug,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False).prefetch(3)
  else:
    dataset = dataset.prefetch(3)

  def repack(*data):
    return data[0]

  dataset = dataset.map(repack)
  replay_buffer_iter = iter(dataset)

  previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0
  step_mult = 1 if FLAGS.action_repeat < 1 else FLAGS.action_repeat

  print('Starting training')
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    if i % FLAGS.deployment_batch_size == 0:
      for _ in range(FLAGS.deployment_batch_size):
        last_timestep = timestep.is_last()
        if last_timestep:
          if episode_timesteps > 0:
            current_time = time.time()
            with summary_writer.as_default():
              tf.summary.scalar(
                  'train/returns', episode_return, step=(i + 1) * step_mult)
              tf.summary.scalar(
                  'train/FPS',
                  episode_timesteps / (current_time - previous_time),
                  step=(i + 1) * step_mult)

          timestep = env.reset()
          episode_return = 0
          episode_timesteps = 0
          previous_time = time.time()

        if (replay_buffer.num_frames() < FLAGS.num_random_actions or
            replay_buffer.num_frames() < FLAGS.deployment_batch_size):
          # Use policy only after the first deployment.
          policy_step = initial_collect_policy.action(timestep)
          action = policy_step.action
        else:
          action = model.actor(timestep.observation, sample=True)

        next_timestep = env.step(action)

        add_to_replay(timestep.observation, action, next_timestep.reward,
                      next_timestep.discount, next_timestep.observation)

        episode_return += next_timestep.reward[0]
        episode_timesteps += 1

        timestep = next_timestep

    if i + 1 >= FLAGS.start_training_timesteps:
      with summary_writer.as_default():
        info_dict = model.update_step(replay_buffer_iter)

      if (i + 1) % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v, step=(i + 1) * step_mult)

    if (i + 1) % FLAGS.eval_interval == 0:
      logging.info('Performing policy eval.')
      average_returns, evaluation_timesteps = evaluation.evaluate(
          eval_env, model)

      with results_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns', average_returns, step=(i + 1) * step_mult)
        tf.summary.scalar(
            'evaluation/length', evaluation_timesteps, step=(i + 1) * step_mult)
      logging.info('Eval at %d: ave returns=%f, ave episode length=%f',
                   (i + 1) * step_mult, average_returns, evaluation_timesteps)

    if ((i + 1) * step_mult) % 50_000 == 0:
      model.save_weights(
          os.path.join(
              FLAGS.save_dir, 'results',
              FLAGS.env_name + '__' + str((i + 1) * step_mult)))
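
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): the pipeline above
# maps an `image_aug` function over sampled replay batches when observations
# are images, but `image_aug` is not defined in this excerpt. A plausible
# DrQ-style pad-and-random-shift augmentation for a batch of image
# observations could look like the function below; the padding size, padding
# mode, and applying one shared shift to the whole batch are assumptions, and
# the real `image_aug` presumably operates on the full
# (state, action, reward, discount, next_state) tuples.
import tensorflow as tf


def random_shift_sketch(images, pad=4):
  """Pads a [B, H, W, C] image batch and randomly crops it back to size."""
  batch, height, width, channels = images.shape
  padded = tf.pad(
      images, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='SYMMETRIC')
  # One random (dy, dx) offset shared by the whole batch.
  return tf.image.random_crop(padded, [batch, height, width, channels])
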
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)
  print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))

  if FLAGS.env_name.startswith('procgen'):
    print('Test env: %s' % FLAGS.env_name)
    _, env_name, train_levels, _ = FLAGS.env_name.split('-')
    print('Train env: %s' % FLAGS.env_name)
    env = tf_py_environment.TFPyEnvironment(
        procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,  # no normalization for evaluation
            env_name=env_name,
            num_levels=int(train_levels),
            start_level=0))
    env_all = tf_py_environment.TFPyEnvironment(
        procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,  # no normalization for evaluation
            env_name=env_name,
            num_levels=0,
            start_level=0))
    if int(train_levels) == 0:
      train_levels = '200'
  elif FLAGS.env_name.startswith('pixels-dm'):
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack, FLAGS.obs_type)
    if FLAGS.obs_type == 'pixels':
      env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                              FLAGS.frame_stack, FLAGS.obs_type)
    else:
      _, env = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                              FLAGS.frame_stack, FLAGS.obs_type)

  if FLAGS.obs_type != 'state':
    if FLAGS.env_name.startswith('procgen'):
      bcq = bcq_pixel
      cql = cql_pixel
      fisher_brac = fisher_brac_pixel
      deepmdp = deepmdp_pixel
      vpn = vpn_pixel
      cssc = cssc_pixel
      pse = pse_pixel
  else:
    bcq = bcq_state
    cql = cql_state

  print('Loading dataset')
  # Use load_tfrecord_dataset_sequence to load transitions of size k>=2.
  if FLAGS.numpy_dataset:
    n_shards = 10

    def shard_fn(shard):
      return ('experiments/'
              '20210617_0105.dataset_dmc_50k,100k,'
              '200k_SAC_pixel_numpy/datasets/'
              '%s__%d__%d__%d.npy' % (FLAGS.env_name, FLAGS.ckpt_timesteps,
                                      FLAGS.max_timesteps, shard))

    np_observer = tf_utils.NumpyObserver(shard_fn, env)
    dataset = np_observer.load(n_shards)
  else:
    if FLAGS.env_name.startswith('procgen'):
      if FLAGS.n_step_returns > 0:
        if FLAGS.max_timesteps == 100_000:
          dataset_path = ('experiments/'
                          '20210624_2033.dataset_procgen__ppo_pixel/'
                          'datasets/%s__%d__%d.tfrecord' %
                          (FLAGS.env_name, FLAGS.ckpt_timesteps,
                           FLAGS.max_timesteps))
        elif FLAGS.max_timesteps == 3_000_000:
          if int(train_levels) == 1:
            print('Using dataset with 1 level')
            dataset_path = (
                'experiments/'
                '20210713_1557.dataset_procgen__ppo_pixel_1_level/'
                'datasets/%s__%d__%d.tfrecord' %
                (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))
          elif int(train_levels) == 200:
            print('Using dataset with 200 levels')
            # Mixture dataset between 10M, 15M, 20M and 25M in equal amounts:
            # dataset_path = (
            #     'experiments/'
            #     '20210718_1522.dataset_procgen__ppo_pixel_mixture10,15,20,25M/'
            #     'datasets/%s__%d__%d.tfrecord' %
            #     (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))
            # PPO after 25M steps:
            dataset_path = (
                'experiments/'
                '20210702_2234.dataset_procgen__ppo_pixel/'
                'datasets/%s__%d__%d.tfrecord' %
                (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))
        elif FLAGS.max_timesteps == 5_000_000:
          # epsilon-greedy, eps: 0.1 -> 0.001
          dataset_path = (
              'experiments/'
              '20210805_1958.dataset_procgen__ppo_pixel_'
              'egreedy_levelIDs/datasets/'
              '%s__%d__%d.tfrecord*' %
              (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))
          # Pure greedy (epsilon=0):
          # dataset_path = (
          #     'experiments/'
          #     '20210820_1348.dataset_procgen__ppo_pixel_'
          #     'egreedy_levelIDs/datasets/'
          #     '%s__%d__%d.tfrecord*' %
          #     (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))
    elif FLAGS.env_name.startswith('pixels-dm'):
      if 'distractor' in FLAGS.env_name:
        dataset_path = (
            'experiments/'
            '20210623_1749.dataset_dmc__sac_pixel/datasets/'
            '%s__%d__%d.tfrecord' %
            (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))
      else:
        if FLAGS.obs_type == 'pixels':
          dataset_path = (
              'experiments/'
              '20210612_1644.dataset_dmc_50k,100k,200k_SAC_pixel/'
              'datasets/%s__%d__%d.tfrecord' %
              (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))
        else:
          dataset_path = (
              'experiments/'
              '20210621_1436.dataset_dmc__SAC_pixel/datasets/'
              '%s__%d__%d.tfrecord' %
              (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps))

    shards = tf.io.gfile.glob(dataset_path)
    shards = [s for s in shards if not s.endswith('.spec')]
    print('Found %d shards under path %s' % (len(shards), dataset_path))
    if FLAGS.n_step_returns > 1:
      # Load sequences of length N.
      dataset = load_tfrecord_dataset_sequence(
          shards,
          buffer_size_per_shard=FLAGS.dataset_size // len(shards),
          deterministic=False,
          compress_image=True,
          seq_len=FLAGS.n_step_returns)  # spec=data_spec,
      dataset = dataset.take(FLAGS.dataset_size).shuffle(
          buffer_size=FLAGS.batch_size,
          reshuffle_each_iteration=False).batch(
              FLAGS.batch_size, drop_remainder=True).prefetch(1).repeat()
      dataset_iter = iter(dataset)
    else:
      dataset_iter = tf_utils.create_data_iterator(
          ('experiments/20210805'
           '_1958.dataset_procgen__ppo_pixel_egreedy_'
           'levelIDs/datasets/%s__%d__%d.tfrecord.shard-*-of-*' %
           (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000)),
          FLAGS.batch_size,
          shuffle_buffer_size=FLAGS.batch_size,
          obs_to_float=False)

  tf.random.set_seed(FLAGS.seed)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.env_name,
      ckpt_timesteps=FLAGS.ckpt_timesteps,
      rep_learn_keywords=FLAGS.rep_learn_keywords)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  pretrain = (FLAGS.pretrain > 0)

  if FLAGS.env_name.startswith('procgen'):
    # Disable entropy reg for discrete spaces.
    action_dim = env.action_spec().maximum.item() + 1
  else:
    action_dim = env.action_spec().shape[0]

  if 'cql' in FLAGS.algo_name:
    model = cql.CQL(
        env.observation_spec(),
        env.action_spec(),
        reg=FLAGS.f_reg,
        target_entropy=-action_dim,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size)
  elif 'bcq' in FLAGS.algo_name:
    model = bcq.BCQ(
        env.observation_spec(),
        env.action_spec(),
        num_augmentations=FLAGS.num_data_augs)
  elif 'fbrac' in FLAGS.algo_name:
    model = fisher_brac.FBRAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-action_dim,
        f_reg=FLAGS.f_reg,
        reward_bonus=FLAGS.reward_bonus,
        num_augmentations=FLAGS.num_data_augs,
        env_name=FLAGS.env_name,
        batch_size=FLAGS.batch_size)
  elif 'ours' in FLAGS.algo_name:
    model = ours.OURS(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-action_dim,
        f_reg=FLAGS.f_reg,
        reward_bonus=FLAGS.reward_bonus,
        num_augmentations=FLAGS.num_data_augs,
        env_name=FLAGS.env_name,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size,
        n_quantiles=FLAGS.n_quantiles,
        temp=FLAGS.temp,
        num_training_levels=train_levels)
    bc_pretraining_steps = FLAGS.pretrain
    if pretrain:
      model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
      checkpoint = tf.train.Checkpoint(**model.model_dict)
      tf_step_counter = tf.Variable(0, dtype=tf.int32)
      manager = tf.train.CheckpointManager(
          checkpoint,
          directory=model_save_path,
          max_to_keep=1,
          checkpoint_interval=FLAGS.save_interval,
          step_counter=tf_step_counter)

      # Load the checkpoint in case it exists.
      state = manager.restore_or_initialize()
      if state is not None:
        # Loaded variables from checkpoint folder.
        timesteps_already_done = int(
            re.findall('ckpt-([0-9]*)', state)[0])  # * FLAGS.save_interval
        print('Loaded model from timestep %d' % timesteps_already_done)
      else:
        print('Training from scratch')
        timesteps_already_done = 0

      tf_step_counter.assign(timesteps_already_done)

      print('Pretraining')
      for i in tqdm.tqdm(range(bc_pretraining_steps)):
        info_dict = model.update_step(dataset_iter, train_target='encoder')
        # (quantile_states, quantile_bins)
        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in info_dict.items():
              v = tf.reduce_mean(v)
              tf.summary.scalar(f'pretrain/{k}', v, step=i)
        tf_step_counter.assign(i)
        manager.save(checkpoint_number=i)
  elif 'bc' in FLAGS.algo_name:
    model = bc_pixel.BehavioralCloning(
        env.observation_spec(),
        env.action_spec(),
        mixture=False,
        encoder=None,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        env_name=FLAGS.env_name,
        batch_size=FLAGS.batch_size)
  elif 'deepmdp' in FLAGS.algo_name:
    model = deepmdp.DeepMdpLearner(
        env.observation_spec(),
        env.action_spec(),
        embedding_dim=512,
        num_distributions=1,
        sequence_length=2,
        learning_rate=3e-4,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size)
  elif 'vpn' in FLAGS.algo_name:
    model = vpn.ValuePredictionNetworkLearner(
        env.observation_spec(),
        env.action_spec(),
        embedding_dim=512,
        learning_rate=3e-4,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size)
  elif 'cssc' in FLAGS.algo_name:
    model = cssc.CSSC(
        env.observation_spec(),
        env.action_spec(),
        embedding_dim=512,
        actor_lr=3e-4,
        critic_lr=3e-4,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size)
  elif 'pse' in FLAGS.algo_name:
    model = pse.PSE(
        env.observation_spec(),
        env.action_spec(),
        embedding_dim=512,
        actor_lr=3e-4,
        critic_lr=3e-4,
        num_augmentations=FLAGS.num_data_augs,
        rep_learn_keywords=FLAGS.rep_learn_keywords,
        batch_size=FLAGS.batch_size,
        temperature=FLAGS.temp)

  bc_pretraining_steps = FLAGS.pretrain
  if pretrain:
    print('Pretraining')
    for i in tqdm.tqdm(range(bc_pretraining_steps)):
      info_dict = model.update_step(dataset_iter, train_target='encoder')
      if i % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            v = tf.reduce_mean(v)
            tf.summary.scalar(f'pretrain/{k}', v, step=i)

  if 'fbrac' in FLAGS.algo_name or FLAGS.algo_name == 'bc':
    # Either load the online policy:
    if FLAGS.load_bc and FLAGS.env_name.startswith('procgen'):
      env_id = [i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
               ][0] + 1  # map env string to digit [1,16]
      if FLAGS.ckpt_timesteps == 10_000_000:
        ckpt_iter = '0000020480'
      elif FLAGS.ckpt_timesteps == 25_000_000:
        ckpt_iter = '0000051200'
      policy_weights_dir = ('ppo_darts/'
                            '2021-06-22-16-36-54/%d/policies/checkpoints/'
                            'policy_checkpoint_%s/' % (env_id, ckpt_iter))
      policy_def_dir = ('ppo_darts/'
                        '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
      bc = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
          policy_def_dir,
          time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
          action_spec=env._action_spec,  # pylint: disable=protected-access
          policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
          info_spec=tf.TensorSpec(shape=(None,)),
          load_specs_from_pbtxt=False)
      bc.update_from_checkpoint(policy_weights_dir)
      model.bc.policy = tf_utils.TfAgentsPolicy(bc)
    else:
      if FLAGS.algo_name == 'fbrac':
        bc_pretraining_steps = 100_000
      elif FLAGS.algo_name == 'bc':
        bc_pretraining_steps = 1_000_000

      if 'fbrac' in FLAGS.algo_name:
        bc = model.bc
      else:
        bc = model

      for i in tqdm.tqdm(range(bc_pretraining_steps)):
        info_dict = bc.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in info_dict.items():
              v = tf.reduce_mean(v)
              tf.summary.scalar(f'bc/{k}', v, step=i)

        if FLAGS.algo_name == 'bc':
          if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, bc)
            # (FLAGS.env_name.startswith('procgen'))
            average_returns_all, average_length_all = evaluation.evaluate(
                env_all, bc)

            with result_writer.as_default():
              tf.summary.scalar(
                  'evaluation/returns', average_returns, step=i + 1)
              tf.summary.scalar(
                  'evaluation/length', average_length, step=i + 1)
              tf.summary.scalar(
                  'evaluation/returns-all', average_returns_all, step=i + 1)
              tf.summary.scalar(
                  'evaluation/length-all', average_length_all, step=i + 1)

  if FLAGS.algo_name == 'bc':
    exit()

  if not (FLAGS.algo_name == 'ours' and pretrain):
    model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
    checkpoint = tf.train.Checkpoint(**model.model_dict)
    tf_step_counter = tf.Variable(0, dtype=tf.int32)
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory=model_save_path,
        max_to_keep=1,
        checkpoint_interval=FLAGS.save_interval,
        step_counter=tf_step_counter)

    # Load the checkpoint in case it exists.
    weights_path = tf.io.gfile.glob(model_save_path + '/ckpt-*.index')
    key_fn = lambda x: int(re.findall(r'(\d+)', x)[-1])
    weights_path.sort(key=key_fn)
    if weights_path:
      weights_path = weights_path[-1]  # take most recent
    state = manager.restore_or_initialize()  # restore(weights_path)
    if state is not None:
      # Loaded variables from checkpoint folder.
      timesteps_already_done = int(
          re.findall('ckpt-([0-9]*)', state)[0])  # * FLAGS.save_interval
      print('Loaded model from timestep %d' % timesteps_already_done)
    else:
      print('Training from scratch')
      timesteps_already_done = 0

    tf_step_counter.assign(timesteps_already_done)

  for i in tqdm.tqdm(range(timesteps_already_done, FLAGS.num_updates)):
    with summary_writer.as_default():
      info_dict = model.update_step(
          dataset_iter, train_target='rl' if pretrain else 'both')

    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          v = tf.reduce_mean(v)
          tf.summary.scalar(f'training/{k}', v, step=i)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      average_returns_all, average_length_all = evaluation.evaluate(
          env_all, model)

      with result_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns-200', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length-200', average_length, step=i + 1)
        tf.summary.scalar(
            'evaluation/returns-all', average_returns_all, step=i + 1)
        tf.summary.scalar(
            'evaluation/length-all', average_length_all, step=i + 1)

    tf_step_counter.assign(i)
    manager.save(checkpoint_number=i)
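
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): when
# FLAGS.n_step_returns > 1, the loader above yields length-n transition
# sequences. The helper below shows one standard way to turn the per-step
# rewards and discounts of such a sequence into an n-step discounted return
# plus a bootstrap discount; the [batch, n] layout and the gamma value are
# assumptions, and the agents in this repo may consume the sequences
# differently.
import tensorflow as tf


def n_step_return_sketch(rewards, discounts, gamma=0.99):
  """Returns (n-step return, cumulative bootstrap discount), each [batch]."""
  n = rewards.shape[1]
  returns = tf.zeros_like(rewards[:, 0])
  cum_discount = tf.ones_like(rewards[:, 0])
  for t in range(n):
    returns += cum_discount * rewards[:, t]
    cum_discount *= gamma * discounts[:, t]
  return returns, cum_discount
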
def main(_):
  if FLAGS.eager:
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  # np.random.seed(FLAGS.seed)
  # random.seed(FLAGS.seed)

  if 'procgen' in FLAGS.env_name:
    _, env_name, train_levels, _ = FLAGS.env_name.split('-')
    env = procgen_wrappers.TFAgentsParallelProcGenEnv(
        1,
        normalize_rewards=False,
        env_name=env_name,
        num_levels=int(train_levels),
        start_level=0)
  elif FLAGS.env_name.startswith('pixels-dm'):
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack, FLAGS.obs_type)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.env_name,
      ckpt_timesteps=FLAGS.ckpt_timesteps)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  if FLAGS.env_name.startswith('procgen'):
    # Map env string to digit in [1, 16].
    env_id = [i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
             ][0] + 1
    if FLAGS.ckpt_timesteps == 10_000_000:
      ckpt_iter = '0000020480'
    elif FLAGS.ckpt_timesteps == 25_000_000:
      ckpt_iter = '0000051200'
    policy_weights_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/checkpoints/'
                          'policy_checkpoint_%s/' % (env_id, ckpt_iter))
    policy_def_dir = ('ppo_darts/'
                      '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
    model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        policy_def_dir,
        time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
        action_spec=env._action_spec,  # pylint: disable=protected-access
        policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
        info_spec=tf.TensorSpec(shape=(None,)),
        load_specs_from_pbtxt=False)
    model.update_from_checkpoint(policy_weights_dir)
    model = TfAgentsPolicy(model)
  else:
    if 'ddpg' in FLAGS.algo_name:
      model = ddpg.DDPG(
          env.observation_spec(),
          env.action_spec(),
          cross_norm='crossnorm' in FLAGS.algo_name)
    elif 'crr' in FLAGS.algo_name:
      model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
      model = awr.AWR(env.observation_spec(), env.action_spec(), f='exp_mean')
    elif 'sac_v1' in FLAGS.algo_name:
      model = sac_v1.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'asac' in FLAGS.algo_name:
      model = asac.ASAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
      model = sac.SAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          cross_norm='crossnorm' in FLAGS.algo_name,
          pcl_actor_update='pc' in FLAGS.algo_name)
    elif 'pcl' in FLAGS.algo_name:
      model = pcl.PCL(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])

    if 'distractor' in FLAGS.env_name:
      ckpt_path = os.path.join(
          ('experiments/20210622_2023.policy_weights_sac'
           '_1M_dmc_distractor_hard_pixel/'), 'results',
          FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
    else:
      ckpt_path = os.path.join(
          ('experiments/20210607_2023.'
           'policy_weights_dmc_1M_SAC_pixel'), 'results',
          FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

    model.load_weights(ckpt_path)
    print('Loaded model weights')

  with summary_writer.as_default():
    if FLAGS.env_name.startswith('procgen'):
      # Evaluate on the full level distribution (num_levels=0) rather than
      # only the training levels; for non-procgen envs the env loaded above
      # is used as-is.
      env = procgen_wrappers.TFAgentsParallelProcGenEnv(
          1,
          normalize_rewards=False,
          env_name=env_name,
          num_levels=0,
          start_level=0)
    (avg_returns, avg_len) = evaluation.evaluate(
        env, model, num_episodes=100, return_distributions=False)
    tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
    tf.summary.scalar('evaluation/length-all', avg_len, step=0)
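
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original scripts): `evaluation.evaluate`
# is called throughout these scripts but not shown in this excerpt. A minimal
# greedy-rollout loop of the kind it presumably implements is sketched below,
# assuming a batch-1 environment and a model exposing
# `actor(observation, sample=...)` like the continuous-control agents above;
# the real helper also supports TF-Agents policies and the `num_episodes` /
# `return_distributions` arguments used elsewhere in this code.
def evaluate_sketch(env, model, num_episodes=10):
  """Average undiscounted return and episode length over greedy rollouts."""
  total_return, total_steps = 0.0, 0
  for _ in range(num_episodes):
    timestep = env.reset()
    while not timestep.is_last():
      action = model.actor(timestep.observation, sample=False)
      timestep = env.step(action)
      total_return += float(timestep.reward[0])
      total_steps += 1
  return total_return / num_episodes, total_steps / num_episodes
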
def main(_):
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
      task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

  env = gym_wrapper.GymWrapper(gym_env)
  env = tf_py_environment.TFPyEnvironment(env)

  dataset_iter = iter(dataset)

  tf.random.set_seed(FLAGS.seed)

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters,
      algo_name=FLAGS.algo_name,
      seed=FLAGS.seed,
      task_name=FLAGS.task_name,
      data_name=FLAGS.data_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  result_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  if FLAGS.algo_name == 'bc':
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec(), env.action_spec())
  elif FLAGS.algo_name == 'bc_mix':
    model = behavioral_cloning.BehavioralCloning(
        env.observation_spec(), env.action_spec(), mixture=True)
  elif 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(env.observation_spec(), env.action_spec())
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(env.observation_spec(), env.action_spec(), f='exp_mean')
  elif 'bcq' in FLAGS.algo_name:
    model = bcq.BCQ(env.observation_spec(), env.action_spec())
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'cql' in FLAGS.algo_name:
    model = cql.CQL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'brac' in FLAGS.algo_name:
    if 'fbrac' in FLAGS.algo_name:
      model = fisher_brac.FBRAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0],
          f_reg=FLAGS.f_reg,
          reward_bonus=FLAGS.reward_bonus)
    else:
      model = brac.BRAC(
          env.observation_spec(),
          env.action_spec(),
          target_entropy=-env.action_spec().shape[0])

    model_folder = os.path.join(
        FLAGS.save_dir, 'models',
        f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')

    if not tf.io.gfile.isdir(model_folder):
      bc_pretraining_steps = 1_000_000
      for i in tqdm.tqdm(range(bc_pretraining_steps)):
        info_dict = model.bc.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
          with summary_writer.as_default():
            for k, v in info_dict.items():
              tf.summary.scalar(
                  f'training/{k}', v, step=i - bc_pretraining_steps)
      # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
    else:
      model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

  for i in tqdm.tqdm(range(FLAGS.num_updates)):
    with summary_writer.as_default():
      info_dict = model.update_step(dataset_iter)

    if i % FLAGS.log_interval == 0:
      with summary_writer.as_default():
        for k, v in info_dict.items():
          tf.summary.scalar(f'training/{k}', v, step=i)

    if (i + 1) % FLAGS.eval_interval == 0:
      average_returns, average_length = evaluation.evaluate(env, model)
      if FLAGS.data_name is None:
        average_returns = gym_env.get_normalized_score(average_returns) * 100.0

      with result_writer.as_default():
        tf.summary.scalar('evaluation/returns', average_returns, step=i + 1)
        tf.summary.scalar('evaluation/length', average_length, step=i + 1)
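
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original scripts): each `main(_)` above
# is an absl program driven by module-level flags (seed, env_name, save_dir,
# algo_name, ...). A minimal, hypothetical version of the flag definitions and
# entry point such a script relies on is shown below; the flag names mirror
# the FLAGS accessed above, but the defaults and help strings are placeholders
# rather than the repo's actual values.
from absl import app, flags

flags.DEFINE_string('env_name', 'pixels-dm-cartpole-swingup', 'Environment name.')
flags.DEFINE_string('algo_name', 'sac', 'Algorithm to train.')
flags.DEFINE_string('save_dir', '/tmp/save/', 'Directory for logs and checkpoints.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('batch_size', 256, 'Batch size.')
flags.DEFINE_integer('eval_interval', 10_000, 'Steps between evaluations.')
flags.DEFINE_integer('log_interval', 1_000, 'Steps between logging.')
flags.DEFINE_boolean('eager', False, 'Run tf.functions eagerly for debugging.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)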