def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
    """Builds the evaluation run loop.

    The evaluated policy is the observation network (identity if the network
    factory provides none), then the policy network, then
    `networks.StochasticMeanHead()`. Its weights are pulled from
    `variable_source`.

    Args:
      variable_source: source of the learner's variables; they are requested
        under the 'policy' key.
      counter: parent counter, wrapped as `counting.Counter(counter,
        'evaluator')`.

    Returns:
      An `acme.EnvironmentLoop` over an environment built with
      `self._environment_factory(True)`.
    """
    spec = self._environment_spec
    actions_spec, observations_spec = spec.actions, spec.observations

    env = self._environment_factory(True)
    nets = self._network_factory(actions_spec)

    # Fall back to the identity when the factory supplies no observation net.
    torso = nets.get('observation', tf.identity)
    policy_stack = snt.Sequential(
        [torso, nets['policy'], networks.StochasticMeanHead()])

    # Build the variables up front so the client has something to overwrite.
    tf2_utils.create_variables(policy_stack, [observations_spec])
    client = tf2_variable_utils.VariableClient(
        variable_source,
        variables={'policy': policy_stack.variables},
        update_period=self._variable_update_period)
    # Fetch the learner's weights before the first episode so a randomly
    # initialised policy is never evaluated.
    client.update_and_wait()

    actor = actors.FeedForwardActor(
        policy_network=policy_stack, variable_client=client)

    eval_counter = counting.Counter(counter, 'evaluator')
    eval_logger = loggers.make_default_logger(
        'evaluator', time_delta=self._log_every, steps_key='evaluator_steps')

    observers = ()
    if self._make_observers:
        observers = self._make_observers()

    return acme.EnvironmentLoop(
        env, actor, eval_counter, eval_logger, observers=observers)
def evaluator(
    self,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
):
    """The evaluation process.

    Builds an actor that runs `agent_networks['observation']`, then
    `agent_networks['policy']`, then `networks.StochasticMeanHead()`. The
    actor's weights are synced from `variable_source` every 1000 updates.

    Args:
      variable_source: source of the learner's variables; they are requested
        under the 'policy' key.
      counter: parent counter, wrapped as `counting.Counter(counter,
        'evaluator')`.

    Returns:
      An `acme.EnvironmentLoop` over an environment built with
      `self._environment_factory(True)`.
    """
    # Create environment and target networks to act with.
    environment = self._environment_factory(True)
    agent_networks = self._network_factory(self._environment_spec)

    # Evaluate the mean of the policy distribution (no exploration sampling).
    evaluator_network = snt.Sequential([
        agent_networks['observation'],
        agent_networks['policy'],
        networks.StochasticMeanHead(),
    ])

    # Sonnet modules create their variables lazily on the first call. Build
    # them now so `evaluator_network.variables` is non-empty when it is handed
    # to the variable client. Otherwise `update_and_wait()` below would copy
    # nothing and the actor would run with its own fresh initialisation.
    tf2_utils.create_variables(evaluator_network,
                               [self._environment_spec.observations])

    # Create the variable client responsible for keeping the actor up-to-date.
    variable_client = tf2_variable_utils.VariableClient(
        variable_source,
        variables={'policy': evaluator_network.variables},
        update_period=1000)

    # Make sure not to evaluate a random actor by assigning variables before
    # running the environment loop.
    variable_client.update_and_wait()

    # Create the agent.
    evaluator = actors.FeedForwardActor(
        policy_network=evaluator_network, variable_client=variable_client)

    # Create logger and counter.
    counter = counting.Counter(counter, 'evaluator')
    logger = loggers.make_default_logger('evaluator', time_delta=self._log_every)

    # Create the run loop and return it.
    return acme.EnvironmentLoop(environment, evaluator, counter, logger)
def make_networks(
    action_spec: specs.BoundedArray,
    policy_layer_sizes: Sequence[int] = (50, 50),
    critic_layer_sizes: Sequence[int] = (50, 50),
):
    """Builds the policy, critic, observation and evaluator networks.

    Args:
      action_spec: bounded action spec. Its flattened size sets the output
        dimension of the policy head, and it configures the `ClipToSpec`
        action transform used by the critic.
      policy_layer_sizes: hidden layer sizes of the policy `LayerNormMLP`.
      critic_layer_sizes: hidden layer sizes of the critic `LayerNormMLP`.

    Returns:
      A dict with keys 'policy', 'critic', 'observation' and 'evaluator'.
    """
    action_size = np.prod(action_spec.shape, dtype=int)
    obs_net = tf2_utils.batch_concat

    policy_torso = networks.LayerNormMLP(policy_layer_sizes, activate_final=True)
    policy_head = networks.MultivariateNormalDiagHead(
        action_size,
        tanh_mean=True,
        init_scale=0.3,
        fixed_scale=True,
        use_tfd_independent=False)
    policy_net = snt.Sequential([policy_torso, policy_head])

    # Evaluator: concatenated observations -> policy -> mean of its output.
    mean_head = networks.StochasticMeanHead()
    eval_net = snt.Sequential([obs_net, policy_net, mean_head])

    # The critic multiplexer applies ClipToSpec to actions before combining
    # them with observations.
    action_clipper = networks.ClipToSpec(action_spec)
    multiplexer = networks.CriticMultiplexer(action_network=action_clipper)
    critic_torso = networks.LayerNormMLP(critic_layer_sizes, activate_final=True)
    value_layer = networks.NearZeroInitializedLinear(1)
    critic_net = snt.Sequential([multiplexer, critic_torso, value_layer])

    return dict(
        policy=policy_net,
        critic=critic_net,
        observation=obs_net,
        evaluator=eval_net,
    )
def make_acme_agent(environment_spec,
                    residual_spec,
                    obs_network_type,
                    crop_frames,
                    full_image_size,
                    crop_margin_size,
                    late_fusion,
                    binary_grip_action=False,
                    input_type=None,
                    counter=None,
                    logdir=None,
                    agent_logger=None):
    """Initialize acme agent based on residual spec and agent flags.

    The agent type is chosen by `FLAGS.agent` ('MPO', 'DMPO' or 'D4PG'). Most
    hyperparameters (learning rates, replay sizes, critic support, ...) are
    read from FLAGS.

    Args:
      environment_spec: spec whose `actions` configure the MPO/DMPO networks.
      residual_spec: spec passed to the agent (and, for D4PG, to the
        networks).
      obs_network_type: if not None, an `agents.ObservationNet` of this type
        is built and handed to the network factory.
      crop_frames: forwarded to `agents.ObservationNet`.
      full_image_size: forwarded to `agents.ObservationNet`.
      crop_margin_size: forwarded to `agents.ObservationNet`.
      late_fusion: forwarded to `agents.ObservationNet`.
      binary_grip_action: forwarded to `make_dmpo_networks` (DMPO only).
      input_type: forwarded to `agents.ObservationNet`.
      counter: counter handed to the agent.
      logdir: unused (learner checkpoint logdir is not supported).
      agent_logger: logger handed to the agent.

    Returns:
      A tuple `(rl_agent, eval_policy)`. For DMPO and D4PG, `eval_policy` is a
      `tf.function`-wrapped policy without exploration; for MPO it is None.

    Raises:
      NotImplementedError: if `FLAGS.agent` is not a supported agent.
    """
    # TODO(minttu): Is environment_spec needed or could we use residual_spec?
    del logdir  # Setting logdir for the learner ckpts not currently supported.

    obs_network = None
    if obs_network_type is not None:
        obs_network = agents.ObservationNet(
            network_type=obs_network_type,
            input_type=input_type,
            add_linear_layer=False,
            crop_frames=crop_frames,
            full_image_size=full_image_size,
            crop_margin_size=crop_margin_size,
            late_fusion=late_fusion)

    eval_policy = None
    if FLAGS.agent == 'MPO':
        agent_networks = networks.make_mpo_networks(
            environment_spec.actions,
            policy_init_std=FLAGS.policy_init_std,
            obs_network=obs_network)

        rl_agent = mpo.MPO(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            # Was `FLAGS.policy_rl` (typo); the other agents use `policy_lr`.
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )
    elif FLAGS.agent == 'DMPO':
        agent_networks = networks.make_dmpo_networks(
            environment_spec.actions,
            policy_layer_sizes=FLAGS.rl_policy_layer_sizes,
            critic_layer_sizes=FLAGS.rl_critic_layer_sizes,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_init_std=FLAGS.policy_init_std,
            binary_grip_action=binary_grip_action,
            obs_network=obs_network)

        # residual_spec is used even when obs_network is set (an alternative
        # was: residual_spec if obs_network is None else environment_spec).
        spec = residual_spec
        rl_agent = dmpo.DistributionalMPO(
            environment_spec=spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )

        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy'],
                tf_networks.StochasticMeanHead()
            ]))
    elif FLAGS.agent == 'D4PG':
        agent_networks = networks.make_d4pg_networks(
            residual_spec.actions,
            vmin=FLAGS.critic_vmin,
            vmax=FLAGS.critic_vmax,
            num_atoms=FLAGS.critic_num_atoms,
            policy_weights_init_scale=FLAGS.policy_weights_init_scale,
            obs_network=obs_network)

        # TODO(minttu): downscale action space to [-1, 1] to match clipped
        # gaussian.
        rl_agent = d4pg.D4PG(
            environment_spec=residual_spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            observation_network=agent_networks['observation'],
            discount=FLAGS.discount,
            batch_size=FLAGS.rl_batch_size,
            min_replay_size=FLAGS.min_replay_size,
            max_replay_size=FLAGS.max_replay_size,
            policy_optimizer=snt.optimizers.Adam(FLAGS.policy_lr),
            critic_optimizer=snt.optimizers.Adam(FLAGS.critic_lr),
            sigma=FLAGS.policy_init_std,
            counter=counter,
            logger=agent_logger,
            checkpoint=FLAGS.write_acme_checkpoints,
        )

        # Learned policy without exploration.
        eval_policy = tf.function(
            snt.Sequential([
                tf_utils.to_sonnet_module(agent_networks['observation']),
                agent_networks['policy']
            ]))
    else:
        raise NotImplementedError('Supported agents: MPO, DMPO, D4PG.')
    return rl_agent, eval_policy
def __init__(self,
             policy_network: snt.RNNCore,
             critic_network: networks.CriticDeepRNN,
             target_policy_network: snt.RNNCore,
             target_critic_network: networks.CriticDeepRNN,
             dataset: tf.data.Dataset,
             accelerator_strategy: Optional[tf.distribute.Strategy] = None,
             behavior_network: Optional[snt.Module] = None,
             cwp_network: Optional[snt.Module] = None,
             policy_optimizer: Optional[snt.Optimizer] = None,
             critic_optimizer: Optional[snt.Optimizer] = None,
             discount: float = 0.99,
             target_update_period: int = 100,
             num_action_samples_td_learning: int = 1,
             num_action_samples_policy_weight: int = 4,
             baseline_reduce_function: str = 'mean',
             clipping: bool = True,
             policy_improvement_modes: str = 'exp',
             ratio_upper_bound: float = 20.,
             beta: float = 1.0,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             checkpoint: bool = False):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      dataset: dataset to learn from, whether fixed or from a replay buffer
        (see `acme.datasets.reverb.make_reverb_dataset` documentation). One
        element is consumed here to read the sequence length.
      accelerator_strategy: the strategy used to distribute computation,
        whether on a single, or multiple, GPU or TPU; as supported by
        tf.distribute. Defaults to `snt.distribute.Replicator()`.
      behavior_network: optional network to snapshot under the `policy` name.
        If None, no `policy` snapshot is written. `policy_network` (with a
        `StochasticSamplingHead`) is always snapshotted under `raw_policy`
        when `checkpoint` is True.
      cwp_network: CWP network to snapshot: samples actions from the policy
        and weighs them with the critic, then returns the action by sampling
        from the softmax distribution using critic values as logits. Used only
        for snapshotting, not training.
      policy_optimizer: the optimizer to be applied to the policy loss.
        Defaults to `snt.optimizers.Adam(1e-4)`.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss. Defaults to `snt.optimizers.Adam(1e-4)`.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      num_action_samples_td_learning: number of action samples to use to
        estimate expected value of the critic loss w.r.t. stochastic policy.
      num_action_samples_policy_weight: number of action samples to use to
        estimate the advantage function for the CRR weighting of the policy
        loss.
      baseline_reduce_function: one of 'mean', 'max', 'min'. Way of aggregating
        values from `num_action_samples` estimates of the value function.
      clipping: whether to clip gradients by global norm.
      policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode which
        determines how the advantage function is processed before being
        multiplied by the policy loss.
      ratio_upper_bound: if policy_improvement_modes is 'exp', determines the
        upper bound of the weight (i.e. the weight is
          min(exp(advantage / beta), upper_bound)
        ).
      beta: if policy_improvement_modes is 'exp', determines the beta (see
        above).
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner. Defaults to a
        `loggers.TerminalLogger('learner', time_delta=1.)`.
      checkpoint: boolean indicating whether to checkpoint (every 30 minutes)
        and snapshot the learner.
    """
    if accelerator_strategy is None:
        accelerator_strategy = snt.distribute.Replicator()
    self._accelerator_strategy = accelerator_strategy
    self._policy_improvement_modes = policy_improvement_modes
    self._ratio_upper_bound = ratio_upper_bound
    self._num_action_samples_td_learning = num_action_samples_td_learning
    self._num_action_samples_policy_weight = num_action_samples_policy_weight
    self._baseline_reduce_function = baseline_reduce_function
    self._beta = beta

    # When running on TPUs we have to know the amount of memory required (and
    # thus the sequence length) at the graph compilation stage. At the moment,
    # the only way to get it is to sample from the dataset, since the dataset
    # does not have any metadata, see b/160672927 to track this upcoming
    # feature.
    sample = next(dataset.as_numpy_iterator())
    # NOTE(review): assumes actions are laid out (batch, time, ...), so axis 1
    # is the sequence length -- confirm against the dataset producer.
    self._sequence_length = sample.action.shape[1]

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
    self._discount = discount
    self._clipping = clipping

    self._target_update_period = target_update_period

    # NOTE(review): indentation was lost in this file. The extent of this
    # `with` block is reconstructed as: step counter, distributed iterator and
    # optimizers -- confirm against version control.
    with self._accelerator_strategy.scope():
        # Necessary to track when to update target networks.
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        # (Maybe) distributing the dataset across multiple accelerators.
        distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset(
            dataset)
        self._iterator = iter(distributed_dataset)

        # Create the optimizers.
        self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(
            1e-4)
        self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(
            1e-4)

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Expose the variables.
    # The variables served under 'policy'/'critic' are the TARGET networks'.
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': self._target_policy_network.variables,
    }

    # Create a checkpointer object.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
        self._checkpointer = tf2_savers.Checkpointer(
            objects_to_save={
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'num_steps': self._num_steps,
            },
            time_delta_minutes=30.)

        # Snapshot wrappers: the online policy with a sampling head and the
        # online critic reduced by a mean head.
        raw_policy = snt.DeepRNN(
            [policy_network, networks.StochasticSamplingHead()])
        critic_mean = networks.CriticDeepRNN(
            [critic_network, networks.StochasticMeanHead()])
        objects_to_save = {
            'raw_policy': raw_policy,
            'critic': critic_mean,
        }
        # 'policy' and 'cwp_policy' snapshots exist only when those networks
        # were passed in.
        if behavior_network is not None:
            objects_to_save['policy'] = behavior_network
        if cwp_network is not None:
            objects_to_save['cwp_policy'] = cwp_network
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save=objects_to_save, time_delta_minutes=30)
    # Timestamp to keep track of the wall time.
    self._walltime_timestamp = time.time()
def __init__(
    self,
    policy_network: snt.Module,
    critic_network: snt.Module,
    target_policy_network: snt.Module,
    target_critic_network: snt.Module,
    discount: float,
    target_update_period: int,
    dataset_iterator: Iterator[reverb.ReplaySample],
    observation_network: types.TensorTransformation = lambda x: x,
    target_observation_network: types.TensorTransformation = lambda x: x,
    policy_optimizer: snt.Optimizer = None,
    critic_optimizer: snt.Optimizer = None,
    clipping: bool = True,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    checkpoint: bool = True,
):
    """Sets up the learner's networks, optimizers and checkpointing.

    Args:
      policy_network: online policy that is optimized.
      critic_network: online critic.
      target_policy_network: lagging copy of the policy.
      target_critic_network: lagging copy of the critic.
      discount: discount factor for TD targets.
      target_update_period: learner steps between target network updates.
      dataset_iterator: iterator of replay samples; stored as-is (see
        `acme.datasets.reverb.make_dataset`).
      observation_network: optional online observation transform applied
        before the policy and critic; wrapped into a Sonnet module.
      target_observation_network: target counterpart of
        `observation_network`; wrapped into a Sonnet module.
      policy_optimizer: optimizer for the DPG (policy) loss; defaults to
        `snt.optimizers.Adam(1e-4)`.
      critic_optimizer: optimizer for the distributional Bellman loss;
        defaults to `snt.optimizers.Adam(1e-4)`.
      clipping: whether gradients are clipped by global norm.
      counter: step counter; a fresh `counting.Counter()` if not given.
      logger: learner logger; `loggers.make_default_logger('learner')` if not
        given.
      checkpoint: whether to create a checkpointer (subdirectory
        'd4pg_learner') and a policy/critic snapshotter.
    """
    # Plain references and hyperparameters.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network
    self._discount = discount
    self._clipping = clipping
    self._target_update_period = target_update_period
    self._iterator = dataset_iterator

    # Observation transforms may be plain callables; wrapping them as Sonnet
    # modules gives them a `.variables` attribute.
    to_module = tf2_utils.to_sonnet_module
    self._observation_network = to_module(observation_network)
    self._target_observation_network = to_module(target_observation_network)

    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Step count used to decide when to refresh the target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)

    default_lr = 1e-4
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(default_lr)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(default_lr)

    # Actors receive the target observation+policy stack and the target critic.
    exposed_policy = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': exposed_policy.variables,
    }

    self._checkpointer = None
    self._snapshotter = None
    if checkpoint:
        learner_state = {
            'counter': self._counter,
            'policy': self._policy_network,
            'critic': self._critic_network,
            'observation': self._observation_network,
            'target_policy': self._target_policy_network,
            'target_critic': self._target_critic_network,
            'target_observation': self._target_observation_network,
            'policy_optimizer': self._policy_optimizer,
            'critic_optimizer': self._critic_optimizer,
            'num_steps': self._num_steps,
        }
        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='d4pg_learner', objects_to_save=learner_state)

        critic_mean = snt.Sequential(
            [self._critic_network, acme_nets.StochasticMeanHead()])
        snapshot_targets = {
            'policy': self._policy_network,
            'critic': critic_mean,
        }
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save=snapshot_targets)

    # Left unset until the first learning step finishes, so that time spent
    # waiting for actors to come online and fill the replay buffer is not
    # counted.
    self._timestamp = None
def __init__(self,
             n_classes=None,
             last_activation=None,
             fc_layer_sizes=(),
             weight_decay=5e-4,
             bn_axis=3,
             batch_norm_decay=0.1,
             init_scheme='v1'):
    """Builds the layers of the narrow ResNet-18 model.

    Layout: zero-pad + 7x7/2 conv + BN, pad + 3x3/2 max-pool, four residual
    stages (64->32, 32->64, 64->128, 128->256 planes), global average pooling,
    BN, then either a LayerNorm policy head or a Dense stack.

    Args:
      n_classes: size of the output layer. If None, the Dense path builds no
        output layer and no `log_std` variable is created.
      last_activation: activation of the final Dense layer.
      fc_layer_sizes: hidden sizes of the ReLU Dense layers (or of the
        `LayerNormMLP` when `FLAGS.layer_norm_policy` is set).
      weight_decay: L2 coefficient for conv kernels and Dense kernels/biases.
      bn_axis: feature axis of the first BatchNormalization.
      batch_norm_decay: passed as Keras `momentum` to the BN layers and to the
        residual blocks.
      init_scheme: 'v1' uses the v1_* initializers; any other value uses
        fan-out VarianceScaling for convs and the torch_* Dense initializers.
    """
    super(Resnet18Narrow32, self).__init__(name='')
    if init_scheme == 'v1':
        print('Using v1 weight init')
        conv2d_init = v1_conv2d_init
        # Bias is not used in conv layers.
        linear_init = v1_linear_init
        linear_bias_init = v1_linear_bias_init
    else:
        print('Using v2 weight init')
        conv2d_init = keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_out', distribution='untruncated_normal')
        linear_init = torch_linear_init
        linear_bias_init = torch_linear_bias_init

    # Explicit symmetric padding followed by a 'valid' conv. TF 'same' padding
    # with stride 2 can pad asymmetrically (the extra pixel goes bottom/right),
    # so this presumably mirrors a symmetric-padding reference implementation
    # (e.g. PyTorch padding=3) -- confirm.
    self.zero_pad = tfl.ZeroPadding2D(padding=(3, 3), input_shape=(32, 32, 3),
                                      name='conv1_pad')
    self.conv1 = tfl.Conv2D(
        64, (7, 7),
        strides=(2, 2),
        padding='valid',
        kernel_initializer=conv2d_init,
        kernel_regularizer=keras.regularizers.l2(weight_decay),
        use_bias=False,
        name='conv1')
    # NOTE(review): Keras `momentum` weights the OLD moving average, so 0.1
    # keeps little history. If this default came from PyTorch (momentum=0.1
    # weights the new batch), the Keras equivalent is 0.9 -- confirm.
    self.bn1 = tfl.BatchNormalization(axis=bn_axis, name='bn_conv1',
                                      momentum=batch_norm_decay,
                                      epsilon=BATCH_NORM_EPSILON)
    self.zero_pad2 = tfl.ZeroPadding2D(padding=(1, 1), name='max_pool_pad')
    self.max_pool = tfl.MaxPooling2D(pool_size=(3, 3), strides=(2, 2),
                                     padding='valid')

    self.resblock1 = Resnet18Block(kernel_size=3,
                                   input_planes=64,
                                   output_planes=32,
                                   stage=2,
                                   strides=(1, 1),
                                   weight_decay=weight_decay,
                                   batch_norm_decay=batch_norm_decay,
                                   conv2d_init=conv2d_init)
    self.resblock2 = Resnet18Block(kernel_size=3,
                                   input_planes=32,
                                   output_planes=64,
                                   stage=3,
                                   strides=(2, 2),
                                   weight_decay=weight_decay,
                                   batch_norm_decay=batch_norm_decay,
                                   conv2d_init=conv2d_init)
    self.resblock3 = Resnet18Block(kernel_size=3,
                                   input_planes=64,
                                   output_planes=128,
                                   stage=4,
                                   strides=(2, 2),
                                   weight_decay=weight_decay,
                                   batch_norm_decay=batch_norm_decay,
                                   conv2d_init=conv2d_init)
    # Stages are numbered 2..5; this block previously reused stage=4, which
    # duplicated resblock3's stage label.
    self.resblock4 = Resnet18Block(kernel_size=3,
                                   input_planes=128,
                                   output_planes=256,
                                   stage=5,
                                   strides=(2, 2),
                                   weight_decay=weight_decay,
                                   batch_norm_decay=batch_norm_decay,
                                   conv2d_init=conv2d_init)

    self.pool = tfl.GlobalAveragePooling2D(name='avg_pool')
    self.bn2 = tfl.BatchNormalization(axis=-1, name='bn_conv2',
                                      momentum=batch_norm_decay,
                                      epsilon=BATCH_NORM_EPSILON)
    self.fcs = []
    if FLAGS.layer_norm_policy:
        self.linear = snt.Sequential([
            networks.LayerNormMLP(fc_layer_sizes),
            networks.MultivariateNormalDiagHead(n_classes),
            networks.StochasticMeanHead()
        ])
    else:
        for size in fc_layer_sizes:
            self.fcs.append(
                tfl.Dense(
                    size,
                    activation=tf.nn.relu,
                    kernel_initializer=linear_init,
                    bias_initializer=linear_bias_init,
                    kernel_regularizer=keras.regularizers.l2(weight_decay),
                    bias_regularizer=keras.regularizers.l2(weight_decay)))
        if n_classes is not None:
            self.linear = tfl.Dense(
                n_classes,
                activation=last_activation,
                kernel_initializer=linear_init,
                bias_initializer=linear_bias_init,
                kernel_regularizer=keras.regularizers.l2(weight_decay),
                bias_regularizer=keras.regularizers.l2(weight_decay),
                name='fc%d' % n_classes)
    self.n_classes = n_classes
    if n_classes is not None:
        self.log_std = tf.Variable(tf.zeros(n_classes), trainable=True,
                                   name='log_std')
    self.first_forward_pass = FLAGS.data_smaller
def __init__(self,
             policy_network: snt.Module,
             critic_network: snt.Module,
             target_policy_network: snt.Module,
             target_critic_network: snt.Module,
             discount: float,
             target_update_period: int,
             dataset: tf.data.Dataset,
             observation_network: types.TensorTransformation = lambda x: x,
             target_observation_network: types.TensorTransformation = lambda x: x,
             policy_optimizer: snt.Optimizer = None,
             critic_optimizer: snt.Optimizer = None,
             clipping: bool = True,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             checkpoint: bool = True,
             specified_path: str = None):
    """Initializes the learner.

    Args:
      policy_network: the online (optimized) policy.
      critic_network: the online critic.
      target_policy_network: the target policy (which lags behind the online
        policy).
      target_critic_network: the target critic.
      discount: discount to use for TD updates.
      target_update_period: number of learner steps to perform before updating
        the target networks.
      dataset: dataset to learn from, whether fixed or from a replay buffer (see
        `acme.datasets.reverb.make_dataset` documentation).
      observation_network: an optional online network to process observations
        before the policy and the critic.
      target_observation_network: the target observation network.
      policy_optimizer: the optimizer to be applied to the DPG (policy) loss.
      critic_optimizer: the optimizer to be applied to the distributional
        Bellman loss.
      clipping: whether to clip gradients by global norm.
      counter: counter object used to keep track of steps.
      logger: logger object to be used by learner.
      checkpoint: boolean indicating whether to checkpoint the learner.
      specified_path: optional checkpoint path (e.g. ending in `ckpt-<N>`) to
        restore the learner state from. The restore runs right after the
        checkpointer is created and requires `checkpoint=True`.

    Raises:
      ValueError: if `specified_path` is given while `checkpoint` is False.
    """
    # The restore goes through the learner's checkpointer, which only exists
    # when checkpointing is enabled. Previously the path was silently ignored
    # in that case.
    if specified_path is not None and not checkpoint:
        raise ValueError(
            'specified_path requires checkpoint=True; got specified_path=%r '
            'with checkpoint=False.' % (specified_path,))

    # Store online and target networks.
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_policy_network = target_policy_network
    self._target_critic_network = target_critic_network

    # Make sure observation networks are snt.Module's so they have variables.
    self._observation_network = tf2_utils.to_sonnet_module(
        observation_network)
    self._target_observation_network = tf2_utils.to_sonnet_module(
        target_observation_network)

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')

    # Other learner parameters.
    self._discount = discount
    self._clipping = clipping

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    # TODO(b/155086959): Fix type stubs and remove.
    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4)
    self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4)

    # Expose the variables.
    policy_network_to_expose = snt.Sequential(
        [self._target_observation_network, self._target_policy_network])
    self._variables = {
        'critic': self._target_critic_network.variables,
        'policy': policy_network_to_expose.variables,
    }

    # Always set, so the attribute exists regardless of `checkpoint`.
    self.specified_path = specified_path

    # Create a checkpointer and snapshotter objects.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
        self._checkpointer = tf2_savers.Checkpointer(
            subdirectory='d4pg_learner',
            objects_to_save={
                'counter': self._counter,
                'policy': self._policy_network,
                'critic': self._critic_network,
                'observation': self._observation_network,
                'target_policy': self._target_policy_network,
                'target_critic': self._target_critic_network,
                'target_observation': self._target_observation_network,
                'policy_optimizer': self._policy_optimizer,
                'critic_optimizer': self._critic_optimizer,
                'num_steps': self._num_steps,
            })

        if self.specified_path is not None:
            # Runs after the Checkpointer is built, so this is the last restore
            # applied here. NOTE(review): relies on the private `_checkpoint`
            # attribute of tf2_savers.Checkpointer.
            self._checkpointer._checkpoint.restore(self.specified_path)
            print('\033[92mspecified_path: ', str(self.specified_path),
                  '\033[0m')

        critic_mean = snt.Sequential(
            [self._critic_network, acme_nets.StochasticMeanHead()])
        self._snapshotter = tf2_savers.Snapshotter(objects_to_save={
            'policy': self._policy_network,
            'critic': critic_mean,
        })

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None
def __init__(self,
             policy_network: snt.Module,
             critic_network: snt.Module,
             target_critic_network: snt.Module,
             discount: float,
             target_update_period: int,
             dataset: tf.data.Dataset,
             critic_optimizer: Optional[snt.Optimizer] = None,
             critic_lr: float = 1e-4,
             checkpoint_interval_minutes: float = 10.0,
             clipping: bool = True,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             checkpoint: bool = True,
             init_observations: Any = None,
             distributional: bool = True,
             vmin: Optional[float] = None,
             vmax: Optional[float] = None,
             ):
    """Initializes the FQE learner.

    Args:
      policy_network: the policy; stored as `self._policy_network` (not used
        during construction).
      critic_network: the online critic.
      target_critic_network: the target critic; its variables are exposed
        under the 'critic' key.
      discount: discount factor; stored as `self._discount`.
      target_update_period: number of learner steps between target network
        updates.
      dataset: dataset to learn from; an iterator over it is created here.
      critic_optimizer: optimizer for the critic. Defaults to
        `snt.optimizers.Adam(critic_lr)`.
      critic_lr: learning rate of the default Adam optimizer; unused when
        `critic_optimizer` is given.
      checkpoint_interval_minutes: passed to the Checkpointer as
        `time_delta_minutes`.
      clipping: stored as `self._clipping`; presumably toggles gradient
        clipping in the update step -- confirm.
      counter: step counter; a fresh `counting.Counter()` if not given.
      logger: logger; `loggers.TerminalLogger('learner', time_delta=1.)` if
        not given.
      checkpoint: whether to create a checkpointer and a critic snapshotter.
      init_observations: stored as `self._init_observations`; its use is
        outside this method.
      distributional: if True, the critic output is reduced with
        `StochasticMeanHead`; otherwise its trailing axis is squeezed.
      vmin: stored as `self._vmin`. Ignored (with a warning) when
        `distributional` is True.
      vmax: stored as `self._vmax`. Ignored (with a warning) when
        `distributional` is True.
    """
    self._policy_network = policy_network
    self._critic_network = critic_network
    self._target_critic_network = target_critic_network
    self._discount = discount
    self._clipping = clipping
    self._init_observations = init_observations
    self._distributional = distributional
    self._vmin = vmin
    self._vmax = vmax
    if (self._distributional and
            (self._vmin is not None or self._vmax is not None)):
        logging.warning('vmin and vmax arguments to FQELearner are ignored when '
                        'distributional=True. They should be provided when '
                        'creating the critic network.')

    # General learner book-keeping and loggers.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

    # Necessary to track when to update target networks.
    self._num_steps = tf.Variable(0, dtype=tf.int32)
    self._target_update_period = target_update_period

    # Batch dataset and create iterator.
    self._iterator = iter(dataset)

    # Create optimizers if they aren't given.
    self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(critic_lr)

    # Expose the variables.
    self._variables = {
        'critic': self._target_critic_network.variables,
    }

    if distributional:
        critic_mean = snt.Sequential(
            [self._critic_network, acme_nets.StochasticMeanHead()])
    else:
        # We remove trailing dimensions to keep the same output dimension
        # as the existing FQE based on D4PG, i.e. (batch_size,).
        critic_mean = snt.Sequential(
            [self._critic_network, lambda t: tf.squeeze(t, -1)])
    self._critic_mean = critic_mean

    # Create a checkpointer object.
    self._checkpointer = None
    self._snapshotter = None

    if checkpoint:
        # `self.state` is defined elsewhere on this class.
        self._checkpointer = tf2_savers.Checkpointer(
            objects_to_save=self.state,
            time_delta_minutes=checkpoint_interval_minutes,
            checkpoint_ttl_seconds=_CHECKPOINT_TTL)
        self._snapshotter = tf2_savers.Snapshotter(
            objects_to_save={
                'critic': critic_mean,
            }, time_delta_minutes=60.)