Example 1
  def evaluator(
      self,
      variable_source: acme.VariableSource,
      counter: counting.Counter,
  ):
    """The evaluation process."""

    action_spec = self._environment_spec.actions
    observation_spec = self._environment_spec.observations

    # Create environment and target networks to act with.
    environment = self._environment_factory(True)
    agent_networks = self._network_factory(action_spec)

    # Make sure observation network is defined.
    observation_network = agent_networks.get('observation', tf.identity)

    # Create a stochastic behavior policy.
    evaluator_network = snt.Sequential([
        observation_network,
        agent_networks['policy'],
        networks.StochasticMeanHead(),
    ])

    # Ensure network variables are created.
    tf2_utils.create_variables(evaluator_network, [observation_spec])
    policy_variables = {'policy': evaluator_network.variables}

    # Create the variable client responsible for keeping the actor up-to-date.
    variable_client = tf2_variable_utils.VariableClient(
        variable_source,
        policy_variables,
        update_period=self._variable_update_period)

    # Make sure not to evaluate a random actor by assigning variables before
    # running the environment loop.
    variable_client.update_and_wait()

    # Create the agent.
    evaluator = actors.FeedForwardActor(
        policy_network=evaluator_network, variable_client=variable_client)

    # Create logger and counter.
    counter = counting.Counter(counter, 'evaluator')
    logger = loggers.make_default_logger(
        'evaluator', time_delta=self._log_every, steps_key='evaluator_steps')
    observers = self._make_observers() if self._make_observers else ()

    # Create the run loop and return it.
    return acme.EnvironmentLoop(
        environment,
        evaluator,
        counter,
        logger,
        observers=observers)
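
A minimal usage sketch for the loop returned above (names like `builder`, `variable_source` and `parent_counter` are assumptions, not part of the snippet):

# Hypothetical usage of the evaluator method above.
loop = builder.evaluator(variable_source, parent_counter)
loop.run(num_episodes=10)  # acme.EnvironmentLoop.run also accepts num_steps.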
Example 2
    def evaluator(
        self,
        variable_source: acme.VariableSource,
        counter: counting.Counter,
    ):
        """The evaluation process."""

        # Create environment and target networks to act with.
        environment = self._environment_factory(True)
        agent_networks = self._network_factory(self._environment_spec)

        # Create a stochastic behavior policy.
        evaluator_network = snt.Sequential([
            agent_networks['observation'],
            agent_networks['policy'],
            networks.StochasticMeanHead(),
        ])

        # Create the variable client responsible for keeping the actor up-to-date.
        variable_client = tf2_variable_utils.VariableClient(
            variable_source,
            variables={'policy': evaluator_network.variables},
            update_period=1000)

        # Make sure not to evaluate a random actor by assigning variables before
        # running the environment loop.
        variable_client.update_and_wait()

        # Create the agent.
        evaluator = actors.FeedForwardActor(policy_network=evaluator_network,
                                            variable_client=variable_client)

        # Create logger and counter.
        counter = counting.Counter(counter, 'evaluator')
        logger = loggers.make_default_logger('evaluator',
                                             time_delta=self._log_every)

        # Create the run loop and return it.
        return acme.EnvironmentLoop(environment, evaluator, counter, logger)
Example 3
def make_networks(
        action_spec: specs.BoundedArray,
        policy_layer_sizes: Sequence[int] = (50, 50),
        critic_layer_sizes: Sequence[int] = (50, 50),
):
    """Creates networks used by the agent."""

    num_dimensions = np.prod(action_spec.shape, dtype=int)

    observation_network = tf2_utils.batch_concat
    policy_network = snt.Sequential([
        networks.LayerNormMLP(policy_layer_sizes, activate_final=True),
        networks.MultivariateNormalDiagHead(num_dimensions,
                                            tanh_mean=True,
                                            init_scale=0.3,
                                            fixed_scale=True,
                                            use_tfd_independent=False)
    ])
    evaluator_network = snt.Sequential([
        observation_network,
        policy_network,
        networks.StochasticMeanHead(),
    ])
    # The multiplexer concatenates the (maybe transformed) observations/actions.
    multiplexer = networks.CriticMultiplexer(
        action_network=networks.ClipToSpec(action_spec))
    critic_network = snt.Sequential([
        multiplexer,
        networks.LayerNormMLP(critic_layer_sizes, activate_final=True),
        networks.NearZeroInitializedLinear(1),
    ])

    return {
        'policy': policy_network,
        'critic': critic_network,
        'observation': observation_network,
        'evaluator': evaluator_network,
    }
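
A minimal sketch of wiring these networks up for a concrete environment; `environment` is an assumption here, and `create_variables` is used to build the variables and obtain the embedding spec fed to the downstream modules:

# Hypothetical usage sketch.
from acme import specs
from acme.tf import utils as tf2_utils

environment_spec = specs.make_environment_spec(environment)  # `environment` assumed given.
nets = make_networks(environment_spec.actions)

obs_net = tf2_utils.to_sonnet_module(nets['observation'])
emb_spec = tf2_utils.create_variables(obs_net, [environment_spec.observations])
tf2_utils.create_variables(nets['policy'], [emb_spec])
tf2_utils.create_variables(nets['critic'], [emb_spec, environment_spec.actions])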
Example 4
def make_acme_agent(environment_spec,
                    residual_spec,
                    obs_network_type,
                    crop_frames,
                    full_image_size,
                    crop_margin_size,
                    late_fusion,
                    binary_grip_action=False,
                    input_type=None,
                    counter=None,
                    logdir=None,
                    agent_logger=None):
    """Initialize acme agent based on residual spec and agent flags."""
    # TODO(minttu): Is environment_spec needed or could we use residual_spec?
    del logdir  # Setting logdir for the learner ckpts not currently supported.
    obs_network = None
    if obs_network_type is not None:
        obs_network = agents.ObservationNet(network_type=obs_network_type,
                                            input_type=input_type,
                                            add_linear_layer=False,
                                            crop_frames=crop_frames,
                                            full_image_size=full_image_size,
                                            crop_margin_size=crop_margin_size,
                                            late_fusion=late_fusion)

    eval_policy = None
    if FLAGS.agent == 'MPO':
        agent_networks = networks.make_mpo_networks(
            environment_spec.actions,
            policy_init_std=FLAGS.policy_init_std,
            obs_network=obs_network)

        rl_agent = mpo.MPO(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
    elif FLAGS.agent == 'DMPO':
        agent_networks = networks.make_dmpo_networks(
            environment_spec.actions,
            policy_layer_sizes=FLAGS.rl_policy_layer_sizes,
            critic_layer_sizes=FLAGS.rl_critic_layer_sizes,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_init_std=FLAGS.policy_init_std,
            binary_grip_action=binary_grip_action,
            obs_network=obs_network)

        # spec = residual_spec if obs_network is None else environment_spec
        spec = residual_spec
        rl_agent = dmpo.DistributionalMPO(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            # logdir=logdir,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy'],
                tf_networks.StochasticMeanHead()
            ]))
    elif FLAGS.agent == 'D4PG':
        agent_networks = networks.make_d4pg_networks(
            residual_spec.actions,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_weights_init_scale=FLAGS.policy_weights_init_scale,
            obs_network=obs_network)

        # TODO(minttu): downscale action space to [-1, 1] to match clipped gaussian.
        rl_agent = d4pg.D4PG(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            sigma=FLAGS.policy_init_std,
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )

        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy']
            ]))

    else:
        raise NotImplementedError('Supported agents: MPO, DMPO, D4PG.')
    return rl_agent, eval_policy
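
A hedged sketch of how the returned pair might be driven; everything below, including a plain (non-residual) setup that reuses the environment spec as `residual_spec`, is an assumption:

# Hypothetical usage sketch; assumes `environment`, `environment_spec` and the
# relevant absl flags are already defined.
rl_agent, eval_policy = make_acme_agent(
    environment_spec=environment_spec,
    residual_spec=environment_spec,  # assumed: no separate residual action space
    obs_network_type=None,
    crop_frames=False,
    full_image_size=None,
    crop_margin_size=None,
    late_fusion=False)

train_loop = acme.EnvironmentLoop(environment, rl_agent)
train_loop.run(num_episodes=1000)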
Example 5
    def __init__(self,
                 policy_network: snt.RNNCore,
                 critic_network: networks.CriticDeepRNN,
                 target_policy_network: snt.RNNCore,
                 target_critic_network: networks.CriticDeepRNN,
                 dataset: tf.data.Dataset,
                 accelerator_strategy: Optional[tf.distribute.Strategy] = None,
                 behavior_network: Optional[snt.Module] = None,
                 cwp_network: Optional[snt.Module] = None,
                 policy_optimizer: Optional[snt.Optimizer] = None,
                 critic_optimizer: Optional[snt.Optimizer] = None,
                 discount: float = 0.99,
                 target_update_period: int = 100,
                 num_action_samples_td_learning: int = 1,
                 num_action_samples_policy_weight: int = 4,
                 baseline_reduce_function: str = 'mean',
                 clipping: bool = True,
                 policy_improvement_modes: str = 'exp',
                 ratio_upper_bound: float = 20.,
                 beta: float = 1.0,
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None,
                 checkpoint: bool = False):
        """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_reverb_dataset` documentation).
      accelerator_strategy: the strategy used to distribute computation,
        whether on a single, or multiple, GPU or TPU; as supported by
        tf.distribute.
      behavior_network: The network to snapshot under `policy` name. If None,
        snapshots `policy_network` instead.
      cwp_network: critic-weighted policy (CWP) network to snapshot: it samples
        actions from the policy, weighs them with the critic, and returns an
        action drawn from the softmax distribution that uses the critic values
        as logits. Used only for snapshotting, not training.
      policy_optimizer: the optimizer to be applied to the policy loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      num_action_samples_td_learning: number of action samples to use to
        estimate expected value of the critic loss w.r.t. stochastic policy.
      num_action_samples_policy_weight: number of action samples to use to
        estimate the advantage function for the CRR weighting of the policy
        loss.
      baseline_reduce_function: one of 'mean', 'max', 'min'. Way of aggregating
        values from `num_action_samples` estimates of the value function.
      clipping: whether to clip gradients by global norm.
      policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode which
        determines how the advantage function is processed before being
        multiplied by the policy loss.
      ratio_upper_bound: if policy_improvement_modes is 'exp', determines
        the upper bound of the weight (i.e. the weight is
          min(exp(advantage / beta), upper_bound)
        ).
      beta: if policy_improvement_modes is 'exp', determines the beta (see
        above).
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

        if accelerator_strategy is None:
            accelerator_strategy = snt.distribute.Replicator()
        self._accelerator_strategy = accelerator_strategy
        self._policy_improvement_modes = policy_improvement_modes
        self._ratio_upper_bound = ratio_upper_bound
        self._num_action_samples_td_learning = num_action_samples_td_learning
        self._num_action_samples_policy_weight = num_action_samples_policy_weight
        self._baseline_reduce_function = baseline_reduce_function
        self._beta = beta

        # When running on TPUs we have to know the amount of memory required (and
        # thus the sequence length) at the graph compilation stage. At the moment,
        # the only way to get it is to sample from the dataset, since the dataset
        # does not have any metadata, see b/160672927 to track this upcoming
        # feature.
        sample = next(dataset.as_numpy_iterator())
        self._sequence_length = sample.action.shape[1]

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)
        self._discount = discount
        self._clipping = clipping

        self._target_update_period = target_update_period

        with self._accelerator_strategy.scope():
            # Necessary to track when to update target networks.
            self._num_steps = tf.Variable(0, dtype=tf.int32)

            # (Maybe) distributing the dataset across multiple accelerators.
            distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset(
                dataset)
            self._iterator = iter(distributed_dataset)

            # Create the optimizers.
            self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(
                1e-4)
            self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(
                1e-4)

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Expose the variables.
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': self._target_policy_network.variables,
        }

        # Create a checkpointer object.
        self._checkpointer = None
        self._snapshotter = None
        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                },
                time_delta_minutes=30.)

            raw_policy = snt.DeepRNN(
                [policy_network,
                 networks.StochasticSamplingHead()])
            critic_mean = networks.CriticDeepRNN(
                [critic_network, networks.StochasticMeanHead()])
            objects_to_save = {
                'raw_policy': raw_policy,
                'critic': critic_mean,
            }
            if behavior_network is not None:
                objects_to_save['policy'] = behavior_network
            if cwp_network is not None:
                objects_to_save['cwp_policy'] = cwp_network
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save=objects_to_save, time_delta_minutes=30)
        # Timestamp to keep track of the wall time.
        self._walltime_timestamp = time.time()
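
The `cwp_network` argument refers to critic-weighted policy sampling as described in the docstring. A rough standalone illustration only (not Acme's CWP module; `policy_dist` is assumed to be a batched action distribution and `critic_fn` to return a vector of Q-values per batch element):

import tensorflow as tf


def cwp_select(policy_dist, critic_fn, observation, num_samples=4):
  """Samples actions from the policy, scores them with the critic, then draws
  one action per batch element using the critic values as softmax logits."""
  actions = policy_dist.sample(num_samples)                      # [S, B, A]
  q_values = tf.stack(
      [critic_fn(observation, a) for a in tf.unstack(actions)])  # [S, B]
  idx = tf.random.categorical(tf.transpose(q_values), 1)[:, 0]   # [B]
  batch = tf.range(tf.shape(idx)[0], dtype=idx.dtype)            # [B]
  return tf.gather_nd(actions, tf.stack([idx, batch], axis=-1))  # [B, A]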
Example 6
  def __init__(
      self,
      policy_network: snt.Module,
      critic_network: snt.Module,
      target_policy_network: snt.Module,
      target_critic_network: snt.Module,
      discount: float,
      target_update_period: int,
      dataset_iterator: Iterator[reverb.ReplaySample],
      observation_network: types.TensorTransformation = lambda x: x,
      target_observation_network: types.TensorTransformation = lambda x: x,
      policy_optimizer: snt.Optimizer = None,
      critic_optimizer: snt.Optimizer = None,
      clipping: bool = True,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      checkpoint: bool = True,
  ):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset_iterator: dataset to learn from, whether fixed or from a replay
        buffer (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Module's so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    self._iterator = dataset_iterator

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Create a checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          subdirectory='d4pg_learner',
          objects_to_save={
              'counter': self._counter,
              'policy': self._policy_network,
              'critic': self._critic_network,
              'observation': self._observation_network,
              'target_policy': self._target_policy_network,
              'target_critic': self._target_critic_network,
              'target_observation': self._target_observation_network,
              'policy_optimizer': self._policy_optimizer,
              'critic_optimizer': self._critic_optimizer,
              'num_steps': self._num_steps,
          })
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'policy': self._policy_network,
              'critic': critic_mean,
          })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
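
As the `target_update_period` docstring entry says, the learner periodically copies the online weights into the target networks. A minimal standalone sketch of that pattern (an assumed helper, not the verbatim Acme learner step):

import tensorflow as tf


def maybe_update_targets(online_vars, target_vars,
                         num_steps: tf.Variable, period: int) -> None:
  """Copies online variables into the target variables every `period` steps."""
  if tf.math.mod(num_steps, period) == 0:
    for src, dest in zip(online_vars, target_vars):
      dest.assign(src)
  num_steps.assign_add(1)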
Example 7
    def __init__(self,
                 n_classes=None,
                 last_activation=None,
                 fc_layer_sizes=(),
                 weight_decay=5e-4,
                 bn_axis=3,
                 batch_norm_decay=0.1,
                 init_scheme='v1'):
        super(Resnet18Narrow32, self).__init__(name='')

        if init_scheme == 'v1':
            print('Using v1 weight init')
            conv2d_init = v1_conv2d_init
            # Bias is not used in conv layers.
            linear_init = v1_linear_init
            linear_bias_init = v1_linear_bias_init
        else:
            print('Using v2 weight init')
            conv2d_init = keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_out', distribution='untruncated_normal')
            linear_init = torch_linear_init
            linear_bias_init = torch_linear_bias_init

        # Explicit zero-padding followed by padding='valid' reproduces the
        # original ResNet conv1; padding='same' would pad asymmetrically here
        # because of the stride of 2.
        self.zero_pad = tfl.ZeroPadding2D(padding=(3, 3),
                                          input_shape=(32, 32, 3),
                                          name='conv1_pad')
        self.conv1 = tfl.Conv2D(
            64, (7, 7),
            strides=(2, 2),
            padding='valid',
            kernel_initializer=conv2d_init,
            kernel_regularizer=keras.regularizers.l2(weight_decay),
            use_bias=False,
            name='conv1')
        self.bn1 = tfl.BatchNormalization(axis=bn_axis,
                                          name='bn_conv1',
                                          momentum=batch_norm_decay,
                                          epsilon=BATCH_NORM_EPSILON)
        self.zero_pad2 = tfl.ZeroPadding2D(padding=(1, 1), name='max_pool_pad')
        self.max_pool = tfl.MaxPooling2D(pool_size=(3, 3),
                                         strides=(2, 2),
                                         padding='valid')

        self.resblock1 = Resnet18Block(kernel_size=3,
                                       input_planes=64,
                                       output_planes=32,
                                       stage=2,
                                       strides=(1, 1),
                                       weight_decay=weight_decay,
                                       batch_norm_decay=batch_norm_decay,
                                       conv2d_init=conv2d_init)

        self.resblock2 = Resnet18Block(kernel_size=3,
                                       input_planes=32,
                                       output_planes=64,
                                       stage=3,
                                       strides=(2, 2),
                                       weight_decay=weight_decay,
                                       batch_norm_decay=batch_norm_decay,
                                       conv2d_init=conv2d_init)

        self.resblock3 = Resnet18Block(kernel_size=3,
                                       input_planes=64,
                                       output_planes=128,
                                       stage=4,
                                       strides=(2, 2),
                                       weight_decay=weight_decay,
                                       batch_norm_decay=batch_norm_decay,
                                       conv2d_init=conv2d_init)

        self.resblock4 = Resnet18Block(kernel_size=3,
                                       input_planes=128,
                                       output_planes=256,
                                       stage=5,
                                       strides=(2, 2),
                                       weight_decay=weight_decay,
                                       batch_norm_decay=batch_norm_decay,
                                       conv2d_init=conv2d_init)

        self.pool = tfl.GlobalAveragePooling2D(name='avg_pool')
        self.bn2 = tfl.BatchNormalization(axis=-1,
                                          name='bn_conv2',
                                          momentum=batch_norm_decay,
                                          epsilon=BATCH_NORM_EPSILON)
        self.fcs = []
        if FLAGS.layer_norm_policy:
            self.linear = snt.Sequential([
                networks.LayerNormMLP(fc_layer_sizes),
                networks.MultivariateNormalDiagHead(n_classes),
                networks.StochasticMeanHead()
            ])
        else:
            for size in fc_layer_sizes:
                self.fcs.append(
                    tfl.Dense(
                        size,
                        activation=tf.nn.relu,
                        kernel_initializer=linear_init,
                        bias_initializer=linear_bias_init,
                        kernel_regularizer=keras.regularizers.l2(weight_decay),
                        bias_regularizer=keras.regularizers.l2(weight_decay)))
            if n_classes is not None:
                self.linear = tfl.Dense(
                    n_classes,
                    activation=last_activation,
                    kernel_initializer=linear_init,
                    bias_initializer=linear_bias_init,
                    kernel_regularizer=keras.regularizers.l2(weight_decay),
                    bias_regularizer=keras.regularizers.l2(weight_decay),
                    name='fc%d' % n_classes)
        self.n_classes = n_classes
        if n_classes is not None:
            self.log_std = tf.Variable(tf.zeros(n_classes),
                                       trainable=True,
                                       name='log_std')
        self.first_forward_pass = FLAGS.data_smaller
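
A hedged instantiation sketch (it assumes the module's absl flags have been parsed and that the class defines its forward pass, e.g. a Keras `call`, outside the snippet above):

# Hypothetical usage; the forward pass of Resnet18Narrow32 is assumed to be
# defined elsewhere in the module.
import tensorflow as tf

model = Resnet18Narrow32(n_classes=10, fc_layer_sizes=(256,))
logits = model(tf.zeros([8, 32, 32, 3]))  # batch of 8 CIFAR-sized images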
Example 8
    def __init__(self,
                 policy_network: snt.Module,
                 critic_network: snt.Module,
                 target_policy_network: snt.Module,
                 target_critic_network: snt.Module,
                 discount: float,
                 target_update_period: int,
                 dataset: tf.data.Dataset,
                 observation_network: types.TensorTransformation = lambda x: x,
                 target_observation_network: types.TensorTransformation = lambda x: x,
                 policy_optimizer: snt.Optimizer = None,
                 critic_optimizer: snt.Optimizer = None,
                 clipping: bool = True,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 specified_path: str = None):
        """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
    """

        # Store online and target networks.
        self._policy_network = policy_network
        self._critic_network = critic_network
        self._target_policy_network = target_policy_network
        self._target_critic_network = target_critic_network

        # Make sure observation networks are snt.Module's so they have variables.
        self._observation_network = tf2_utils.to_sonnet_module(
            observation_network)
        self._target_observation_network = tf2_utils.to_sonnet_module(
            target_observation_network)

        # General learner book-keeping and loggers.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger('learner')

        # Other learner parameters.
        self._discount = discount
        self._clipping = clipping

        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)
        self._target_update_period = target_update_period

        # Batch dataset and create iterator.
        # TODO(b/155086959): Fix type stubs and remove.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        # Create optimizers if they aren't given.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

        # Expose the variables.
        policy_network_to_expose = snt.Sequential(
            [self._target_observation_network, self._target_policy_network])
        self._variables = {
            'critic': self._target_critic_network.variables,
            'policy': policy_network_to_expose.variables,
        }

        # Create a checkpointer and snapshotter objects.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                subdirectory='d4pg_learner',
                objects_to_save={
                    'counter': self._counter,
                    'policy': self._policy_network,
                    'critic': self._critic_network,
                    'observation': self._observation_network,
                    'target_policy': self._target_policy_network,
                    'target_critic': self._target_critic_network,
                    'target_observation': self._target_observation_network,
                    'policy_optimizer': self._policy_optimizer,
                    'critic_optimizer': self._critic_optimizer,
                    'num_steps': self._num_steps,
                })
            self.specified_path = specified_path
            if self.specified_path is not None:
                self._checkpointer._checkpoint.restore(self.specified_path)
                print('\033[92mspecified_path: ', str(self.specified_path),
                      '\033[0m')
            critic_mean = snt.Sequential(
                [self._critic_network,
                 acme_nets.StochasticMeanHead()])
            self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
                'policy': self._policy_network,
                'critic': critic_mean,
            })

        # Do not record timestamps until after the first learning step is done.
        # This is to avoid including the time it takes for actors to come online and
        # fill the replay buffer.
        self._timestamp = None
Example 9
  def __init__(self,
               policy_network: snt.Module,
               critic_network: snt.Module,
               target_critic_network: snt.Module,
               discount: float,
               target_update_period: int,
               dataset: tf.data.Dataset,
               critic_optimizer: snt.Optimizer = None,
               critic_lr: float = 1e-4,
               checkpoint_interval_minutes: float = 10.0,
               clipping: bool = True,
               counter: counting.Counter = None,
               logger: loggers.Logger = None,
               checkpoint: bool = True,
               init_observations: Any = None,
               distributional: bool = True,
               vmin: Optional[float] = None,
               vmax: Optional[float] = None,
               ):
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_critic_network = target_critic_network
    self._discount = discount
    self._clipping = clipping
    self._init_observations = init_observations
    self._distributional = distributional
    self._vmin = vmin
    self._vmax = vmax

    if (self._distributional and
        (self._vmin is not None or self._vmax is not None)):
      logging.warning('vmin and vmax arguments to FQELearner are ignored when '
                      'distributional=True. They should be provided when '
                      'creating the critic network.')

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    self._iterator = iter(dataset)

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(critic_lr)

    # Expose the variables.
    self._variables = {
        'critic': self._target_critic_network.variables,
    }
    if distributional:
      critic_mean = snt.Sequential(
          [self._critic_network, acme_nets.StochasticMeanHead()])
    else:
      # We remove the trailing dimension to keep the same output dimension
      # as the existing FQE based on D4PG, i.e. (batch_size,).
      critic_mean = snt.Sequential(
          [self._critic_network, lambda t: tf.squeeze(t, -1)])
    self._critic_mean = critic_mean

    # Create a checkpointer object.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
      self._checkpointer = tf2_savers.Checkpointer(
          objects_to_save=self.state,
          time_delta_minutes=checkpoint_interval_minutes,
          checkpoint_ttl_seconds=_CHECKPOINT_TTL)
      self._snapshotter = tf2_savers.Snapshotter(
          objects_to_save={
              'critic': critic_mean,
          },
          time_delta_minutes=60.)
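
The squeeze in the non-distributional branch only drops the trailing singleton dimension; a tiny illustrative check (values are arbitrary):

import tensorflow as tf

q_values = tf.constant([[1.5], [0.2], [-0.7]])  # critic output, shape (batch_size, 1)
squeezed = tf.squeeze(q_values, -1)             # shape (batch_size,), as in the D4PG-based FQE
assert squeezed.shape == (3,)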