def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-5,
                 critic_lr=1e-3,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
    """
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]

        self.actor = policies.DiagGuassianPolicy(state_dim, action_spec)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(0.3), trainable=True)
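        # Learning log_alpha rather than alpha keeps the temperature
        # exp(log_alpha) positive; it starts here at alpha = 0.3.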

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        self.bc = behavioral_cloning.BehavioralCloning(observation_spec,
                                                       action_spec,
                                                       mixture=True)

        action_dim = action_spec.shape[0]
        self.critic = critic.Critic(state_dim, action_dim)
        self.critic_target = critic.Critic(state_dim, action_dim)
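        # tau=1.0 makes the soft update a hard copy, initializing the target
        # critic to exactly match the online critic.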
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)


def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_lr=3e-4,
                 critic_lr=3e-4,
                 alpha_lr=3e-4,
                 discount=0.99,
                 tau=0.005,
                 target_entropy=0.0,
                 f_reg=1.0,
                 reward_bonus=5.0,
                 num_augmentations=1,
                 env_name='',
                 batch_size=256):
        """Creates networks.

    Args:
      observation_spec: environment observation spec.
      action_spec: Action spec.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_lr: Temperature learning rate.
      discount: MDP discount.
      tau: Soft target update parameter.
      target_entropy: Target entropy.
      f_reg: Critic regularization weight.
      reward_bonus: Bonus added to the rewards.
      num_augmentations: Number of random crops
      env_name: Env name
      batch_size: batch size
    """
        del num_augmentations, env_name
        assert len(observation_spec.shape) == 1
        state_dim = observation_spec.shape[0]
        self.batch_size = batch_size

        hidden_dims = (256, 256, 256)
        self.actor = policies.DiagGuassianPolicy(state_dim,
                                                 action_spec,
                                                 hidden_dims=hidden_dims)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

        self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
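        # Here the temperature starts at alpha = 1.0 and, unlike the
        # constructor above, is trained with its own Adam optimizer.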
        self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

        self.target_entropy = target_entropy
        self.discount = discount
        self.tau = tau

        self.bc = behavioral_cloning.BehavioralCloning(observation_spec,
                                                       action_spec,
                                                       mixture=True)

        action_dim = action_spec.shape[0]
        self.critic = critic.Critic(state_dim,
                                    action_dim,
                                    hidden_dims=hidden_dims)
        self.critic_target = critic.Critic(state_dim,
                                           action_dim,
                                           hidden_dims=hidden_dims)
        critic.soft_update(self.critic, self.critic_target, tau=1.0)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)

        self.f_reg = f_reg
        self.reward_bonus = reward_bonus

        self.model_dict = {
            'critic': self.critic,
            'actor': self.actor,
            'critic_target': self.critic_target,
            'actor_optimizer': self.actor_optimizer,
            'critic_optimizer': self.critic_optimizer,
            'alpha_optimizer': self.alpha_optimizer
        }
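
# Note: critic.soft_update is assumed here to implement the standard Polyak
# (exponential-moving-average) target update; a minimal sketch of that
# assumption:
#
#   def soft_update(net, target_net, tau=0.005):
#       for v, tv in zip(net.variables, target_net.variables):
#           tv.assign((1.0 - tau) * tv + tau * v)
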
def main(_):
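    # Running tf.functions eagerly (FLAGS.eager) simplifies debugging at the
    # cost of speed; newer TF spells this tf.config.run_functions_eagerly.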
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
        task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

    env = gym_wrapper.GymWrapper(gym_env)
    env = tf_py_environment.TFPyEnvironment(env)

    dataset_iter = iter(dataset)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.task_name,
                                          data_name=FLAGS.data_name)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    if FLAGS.algo_name == 'bc':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec())
    elif FLAGS.algo_name == 'bc_mix':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec(),
                                                     mixture=True)
    elif 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(env.observation_spec(), env.action_spec())
    elif 'crr' in FLAGS.algo_name:
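        # CRR is implemented here as AWR with the 'bin_max' advantage
        # transform.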
        model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(env.observation_spec(),
                        env.action_spec(),
                        f='exp_mean')
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(), env.action_spec())
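    # These substring checks are order-sensitive: 'asac' also contains 'sac',
    # so the more specific name must be matched first.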
    elif 'asac' in FLAGS.algo_name:
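        # target_entropy = -|A| (negative action dimensionality) is the usual
        # SAC heuristic.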
        model = asac.ASAC(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'brac' in FLAGS.algo_name:
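        # 'fbrac' also contains 'brac', so both variants land in this branch;
        # the inner check disambiguates them.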
        if 'fbrac' in FLAGS.algo_name:
            model = fisher_brac.FBRAC(
                env.observation_spec(),
                env.action_spec(),
                target_entropy=-env.action_spec().shape[0],
                f_reg=FLAGS.f_reg,
                reward_bonus=FLAGS.reward_bonus)
        else:
            model = brac.BRAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])

        model_folder = os.path.join(
            FLAGS.save_dir, 'models',
            f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')
        if not tf.io.gfile.isdir(model_folder):
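            # No saved model found: pretrain the behavioral cloning model
            # from scratch.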
            bc_pretraining_steps = 1_000_000
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.bc.update_step(dataset_iter)

                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
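                            # Negative steps place the BC pretraining curves
                            # before step 0 of the main training run.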
                            tf.summary.scalar(f'training/{k}',
                                              v,
                                              step=i - bc_pretraining_steps)
            # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
        else:
            model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

    for i in tqdm.tqdm(range(FLAGS.num_updates)):
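        # Keeping the writer active during the update lets any tf.summary
        # calls inside update_step be recorded.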
        with summary_writer.as_default():
            info_dict = model.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            if FLAGS.data_name is None:
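                # Convert to the D4RL normalized score, where 100 corresponds
                # to expert-level performance.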
                average_returns = gym_env.get_normalized_score(
                    average_returns) * 100.0

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length',
                                  average_length,
                                  step=i + 1)
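

# Assuming the usual absl-py entry point at the bottom of the module:
if __name__ == '__main__':
    app.run(main)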