Example 1
def main(_):
    # Create an environment and grab the spec.
    environment, environment_spec = _build_environment(
        FLAGS.environment_name, max_steps=FLAGS.max_steps_per_episode)

    if FLAGS.model_name:
        loaded_network = load_wb_model(FLAGS.model_name, FLAGS.model_tag)

        if FLAGS.stochastic:
            head = networks.StochasticSamplingHead()
        else:
            head = lambda q: trfl.epsilon_greedy(
                q, epsilon=FLAGS.epsilon).sample()

        policy_network = snt.Sequential([
            loaded_network,
            head,
        ])
        actor = actors.FeedForwardActor(policy_network)

    else:
        actor = RandomActor(environment_spec)

    recorder = DemonstrationRecorder(environment, actor)

    recorder.collect_n_episodes(FLAGS.n_episodes)
    recorder.make_tf_dataset()
    recorder.save(FLAGS.save_dir)
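
For reference, a minimal sketch (not part of the example) of what the non-stochastic head above produces: trfl.epsilon_greedy wraps the Q-values in a categorical distribution that takes the argmax action with probability 1 - epsilon and a uniformly random action otherwise, and .sample() draws the action. The Q-values and epsilon below are illustrative.

import tensorflow as tf
import trfl

q_values = tf.constant([[0.1, 0.9, 0.3]])  # batch of Q-values for three actions
action = trfl.epsilon_greedy(q_values, epsilon=0.05).sample()  # -> action index of shape [1]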
Example 2
    def actor(
        self,
        replay: reverb.Client,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
    ) -> acme.EnvironmentLoop:
        """The actor process."""

        action_spec = self._environment_spec.actions
        observation_spec = self._environment_spec.observations

        # Create environment and target networks to act with.
        environment = self._environment_factory(False)
        agent_networks = self._network_factory(action_spec,
                                               self._num_critic_heads)

        # Make sure observation network is defined.
        observation_network = agent_networks.get('observation', tf.identity)

        # Create a stochastic behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            agent_networks['policy'],
            networks.StochasticSamplingHead(),
        ])

        # Ensure network variables are created.
        tf2_utils.create_variables(behavior_network, [observation_spec])
        policy_variables = {'policy': behavior_network.variables}

        # Create the variable client responsible for keeping the actor up-to-date.
        variable_client = tf2_variable_utils.VariableClient(variable_source,
                                                            policy_variables,
                                                            update_period=1000)

        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        variable_client.update_and_wait()

        # Component to add things into replay.
        adder = adders.NStepTransitionAdder(
            client=replay,
            n_step=self._n_step,
            max_in_flight_items=self._max_in_flight_items,
            discount=self._additional_discount)

        # Create the agent.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder,
                                        variable_client=variable_client)

        # Create logger and counter; actors will not spam bigtable.
        counter = counting.Counter(counter, 'actor')
        logger = loggers.make_default_logger('actor',
                                             save_data=False,
                                             time_delta=self._log_every,
                                             steps_key='actor_steps')

        # Create the run loop and return it.
        return acme.EnvironmentLoop(environment, actor, counter, logger)
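
The NStepTransitionAdder configured above squashes n consecutive environment steps into one replay item whose reward is the discounted n-step return. A small self-contained sketch of that accumulation (assumed semantics for illustration, not the adder's actual implementation):

def n_step_return(rewards, discount):
    """Discounted sum r_0 + d * r_1 + ... + d**(n-1) * r_{n-1}."""
    total, scale = 0.0, 1.0
    for r in rewards:
        total += scale * r
        scale *= discount
    return total

# e.g. three accumulated rewards with a per-step discount of 0.99:
expected = 1.0 + 0.99 * 0.5 + 0.99 ** 2 * 2.0
assert abs(n_step_return([1.0, 0.5, 2.0], 0.99) - expected) < 1e-9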
Example 3
    def __init__(self,
                 policy_network: snt.RNNCore,
                 critic_network: networks.CriticDeepRNN,
                 target_policy_network: snt.RNNCore,
                 target_critic_network: networks.CriticDeepRNN,
                 dataset: tf.data.Dataset,
                 accelerator_strategy: Optional[tf.distribute.Strategy] = None,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 discount: float = 0.99,
                 target_update_period: int = 100,
                 num_action_samples_td_learning: int = 1,
                 num_action_samples_policy_weight: int = 4,
                 baseline_reduce_function: str = 'mean',
                 clipping: bool = True,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint: bool = False):
        """Initializes the learner.

        Args:
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          target_policy_network: the target policy (which lags behind the online
            policy).
          target_critic_network: the target critic.
          dataset: dataset to learn from, whether fixed or from a replay buffer
            (see `acme.datasets.reverb.make_reverb_dataset` documentation).
          accelerator_strategy: the strategy used to distribute computation,
            whether on a single, or multiple, GPU or TPU; as supported by
            tf.distribute.
          behavior_network: The network to snapshot under `policy` name. If None,
            snapshots `policy_network` instead.
          cwp_network: CWP network to snapshot: samples actions
            from the policy and weighs them with the critic, then returns the action
            by sampling from the softmax distribution using critic values as logits.
            Used only for snapshotting, not training.
          policy_optimizer: the optimizer to be applied to the policy loss.
          critic_optimizer: the optimizer to be applied to the distributional
            Bellman loss.
          discount: discount to use for TD updates.
          target_update_period: number of learner steps to perform before updating
            the target networks.
          num_action_samples_td_learning: number of action samples to use to
            estimate expected value of the critic loss w.r.t. stochastic policy.
          num_action_samples_policy_weight: number of action samples to use to
            estimate the advantage function for the CRR weighting of the policy
            loss.
          baseline_reduce_function: one of 'mean', 'max', 'min'. Way of aggregating
            values from `num_action_samples` estimates of the value function.
          clipping: whether to clip gradients by global norm.
          policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode which
            determines how the advantage function is processed before being
            multiplied by the policy loss.
          ratio_upper_bound: if policy_improvement_modes is 'exp', determines
            the upper bound of the weight (i.e. the weight is
              min(exp(advantage / beta), upper_bound)
            ).
          beta: if policy_improvement_modes is 'exp', determines the beta (see
            above).
          counter: counter object used to keep track of steps.
          logger: logger object to be used by learner.
          checkpoint: boolean indicating whether to checkpoint the learner.
        """

        if accelerator_strategy is None:
            accelerator_strategy = snt.distribute.Replicator()
        self._accelerator_strategy = accelerator_strategy
        self._policy_improvement_modes = policy_improvement_modes
        self._ratio_upper_bound = ratio_upper_bound
        self._num_action_samples_td_learning = num_action_samples_td_learning
        self._num_action_samples_policy_weight = num_action_samples_policy_weight
        self._baseline_reduce_function = baseline_reduce_function
        self._beta = beta

        # When running on TPUs we have to know the amount of memory required (and
        # thus the sequence length) at the graph compilation stage. At the moment,
        # the only way to get it is to sample from the dataset, since the dataset
        # does not have any metadata, see b/160672927 to track this upcoming
        # feature.
        sample = next(dataset.as_numpy_iterator())
        self._sequence_length = sample.action.shape[1]

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)
        self._discount = discount
        self._clipping = clipping

        self._target_update_period = target_update_period

        with self._accelerator_strategy.scope():
            # Necessary to track when to update target networks.
            self._num_steps = tf.Variable(0, dtype=tf.int32)

            # (Maybe) distributing the dataset across multiple accelerators.
            distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset(
                dataset)
            self._iterator = iter(distributed_dataset)

            # Create the optimizers.
            self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(
                1e-4)
            self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(
                1e-4)

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Expose the variables.
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        # Create a checkpointer object.
        self._checkpointer = None
        self._snapshotter = None
        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                },
                time_delta_minutes=30.)

            raw_policy = snt.DeepRNN(
                [policy_network,
                 networks.StochasticSamplingHead()])
            critic_mean = networks.CriticDeepRNN(
                [critic_network, networks.StochasticMeanHead()])
            objects_to_save = {
                'raw_policy': raw_policy,
                'critic': critic_mean,
            }
            if behavior_network is not None:
                objects_to_save['policy'] = behavior_network
            if cwp_network is not None:
                objects_to_save['cwp_policy'] = cwp_network
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=objects_to_save, time_delta_minutes=30)
        # Timestamp to keep track of the wall time.
        self._walltime_timestamp = time.time()
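
The policy_improvement_modes, beta and ratio_upper_bound arguments documented above control how the advantage is turned into a weight on the CRR policy loss. A hedged NumPy sketch of that weighting, written only to illustrate the docstring (the learner computes this inside its update step, and the 'all' branch is an assumed reading of that mode):

import numpy as np

def crr_policy_weight(advantage, mode='exp', beta=1.0, ratio_upper_bound=20.0):
    if mode == 'exp':
        # Clipped exponential weighting: min(exp(advantage / beta), upper bound).
        return np.minimum(np.exp(advantage / beta), ratio_upper_bound)
    if mode == 'binary':
        # Keep only actions whose advantage is positive.
        return (np.asarray(advantage) > 0).astype(np.float32)
    # 'all': every action gets weight 1 (assumed interpretation of this mode).
    return np.ones_like(advantage)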
Example 4
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy_network: snt.Module,
        critic_network: snt.Module,
        observation_network: types.TensorTransformation = tf.identity,
        discount: float = 0.99,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_policy_update_period: int = 100,
        target_critic_update_period: int = 100,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        policy_loss_module: snt.Module = None,
        policy_optimizer: snt.Optimizer = None,
        critic_optimizer: snt.Optimizer = None,
        n_step: int = 5,
        num_samples: int = 20,
        clipping: bool = True,
        logger: loggers.Logger = None,
        counter: counting.Counter = None,
        checkpoint: bool = True,
        replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
    ):
        """Initialize the agent.

        Args:
          environment_spec: description of the actions, observations, etc.
          policy_network: the online (optimized) policy.
          critic_network: the online critic.
          observation_network: optional network to transform the observations before
            they are fed into any network.
          discount: discount to use for TD updates.
          batch_size: batch size for updates.
          prefetch_size: size to prefetch from replay.
          target_policy_update_period: number of updates to perform before updating
            the target policy network.
          target_critic_update_period: number of updates to perform before updating
            the target critic network.
          min_replay_size: minimum replay size before updating.
          max_replay_size: maximum replay size.
          samples_per_insert: number of samples to take from replay for every insert
            that is made.
          policy_loss_module: configured MPO loss function for the policy
            optimization; defaults to sensible values on the control suite.
            See `acme/tf/losses/mpo.py` for more details.
          policy_optimizer: optimizer to be used on the policy.
          critic_optimizer: optimizer to be used on the critic.
          n_step: number of steps to squash into a single transition.
          num_samples: number of actions to sample when doing a Monte Carlo
            integration with respect to the policy.
          clipping: whether to clip gradients by global norm.
          logger: logging object used to write to logs.
          counter: counter object used to keep track of steps.
          checkpoint: boolean indicating whether to checkpoint the learner.
          replay_table_name: string indicating what name to give the replay table.
        """

        # Create a replay server to add data to.
        replay_table = reverb.Table(
            name=replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
            signature=adders.NStepTransitionAdder.signature(environment_spec))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(
            priority_fns={replay_table_name: lambda x: 1.},
            client=reverb.Client(address),
            n_step=n_step,
            discount=discount)

        # The dataset object to learn from.
        dataset = datasets.make_reverb_dataset(
            table=replay_table_name,
            client=reverb.TFClient(address),
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            environment_spec=environment_spec,
            transition_adder=True)

        # Make sure observation network is a Sonnet Module.
        observation_network = tf2_utils.to_sonnet_module(observation_network)

        # Create target networks before creating online/target network variables.
        target_policy_network = copy.deepcopy(policy_network)
        target_critic_network = copy.deepcopy(critic_network)
        target_observation_network = copy.deepcopy(observation_network)

        # Get observation and action specs.
        act_spec = environment_spec.actions
        obs_spec = environment_spec.observations
        emb_spec = tf2_utils.create_variables(observation_network, [obs_spec])

        # Create the behavior policy.
        behavior_network = snt.Sequential([
            observation_network,
            policy_network,
            networks.StochasticSamplingHead(),
        ])

        # Create variables.
        tf2_utils.create_variables(policy_network, [emb_spec])
        tf2_utils.create_variables(critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_policy_network, [emb_spec])
        tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec])
        tf2_utils.create_variables(target_observation_network, [obs_spec])

        # Create the actor which defines how we take actions.
        actor = actors.FeedForwardActor(policy_network=behavior_network,
                                        adder=adder)

        # Create optimizers.
        policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)
        critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)

        # The learner updates the parameters (and initializes them).
        learner = learning.MPOLearner(
            policy_network=policy_network,
            critic_network=critic_network,
            observation_network=observation_network,
            target_policy_network=target_policy_network,
            target_critic_network=target_critic_network,
            target_observation_network=target_observation_network,
            policy_loss_module=policy_loss_module,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=clipping,
            discount=discount,
            num_samples=num_samples,
            target_policy_update_period=target_policy_update_period,
            target_critic_update_period=target_critic_update_period,
            dataset=dataset,
            logger=logger,
            counter=counter,
            checkpoint=checkpoint)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
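
The final super().__init__ call fixes how actor steps and learner steps interleave. With the default arguments above, the agent collects max(batch_size, min_replay_size) observations before the first update and then performs one learner step for every batch_size / samples_per_insert environment steps; a quick sketch of that arithmetic with the defaults:

batch_size = 256
min_replay_size = 1000
samples_per_insert = 32.0

min_observations = max(batch_size, min_replay_size)             # 1000 steps before learning starts
observations_per_step = float(batch_size) / samples_per_insert  # 8.0 env steps per learner step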
Example 5
def main(_):
    # TODO(yutian): Create environment.
    # # Create an environment and grab the spec.
    # raw_environment = bsuite.load_and_record_to_csv(
    #     bsuite_id=FLAGS.bsuite_id,
    #     results_dir=FLAGS.results_dir,
    #     overwrite=FLAGS.overwrite,
    # )
    # environment = single_precision.SinglePrecisionWrapper(raw_environment)
    # environment_spec = specs.make_environment_spec(environment)

    # TODO(yutian): Create dataset.
    # Build the dataset.
    # if hasattr(raw_environment, 'raw_env'):
    #   raw_environment = raw_environment.raw_env
    #
    # batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # # Combine with demonstration dataset.
    # transition = functools.partial(
    #     _n_step_transition_from_episode, n_step=1, additional_discount=1.)
    #
    # dataset = batch_dataset.map(transition)
    #
    # # Batch and prefetch.
    # dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    # dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    networks = make_networks(environment_spec.actions)
    treatment_net = networks['treatment_net']
    instrumental_net = networks['instrumental_net']
    policy_net = networks['policy_net']

    # Create the evaluation network, which samples actions from the policy.
    evaluator_net = snt.Sequential([
        policy_net,
        # Sample actions.
        acme_nets.StochasticSamplingHead()
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_net, [environment_spec.observations])
    # TODO(liyuan): set the proper input spec using environment_spec.observations
    # and environment_spec.actions.
    tf2_utils.create_variables(treatment_net, [environment_spec.observations])
    tf2_utils.create_variables(
        instrumental_net,
        [environment_spec.observations, environment_spec.actions])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the evaluation actor, which defines how we take actions.
    evaluator = actors.FeedForwardActor(evaluator_net)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.DFIVLearner(
        treatment_net=treatment_net,
        instrumental_net=instrumental_net,
        policy_net=policy_net,
        treatment_learning_rate=FLAGS.treatment_learning_rate,
        instrumental_learning_rate=FLAGS.instrumental_learning_rate,
        policy_learning_rate=FLAGS.policy_learning_rate,
        dataset=dataset,
        counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
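
The script is driven by absl flags such as FLAGS.evaluate_every and FLAGS.evaluation_episodes (among others), with app.run dispatching to main. A minimal sketch of how those pieces are typically wired up (flag names taken from the usage above; default values are assumptions):

from absl import app
from absl import flags

flags.DEFINE_integer('evaluate_every', 100, 'Learner steps between evaluation runs.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes to run per evaluation.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)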