def __init__(self, network: snt.Module, learning_rate: float, dataset: tf.data.Dataset, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True): """Initializes the learner. Args: network: the BC network (the one being optimized) learning_rate: learning rate for the cross-entropy update. dataset: dataset to learn from. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. """ self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # TODO(b/155086959): Fix type stubs and remove. self._network = network self._optimizer = snt.optimizers.Adam(learning_rate) self._variables: List[List[tf.Tensor]] = [network.trainable_variables] self._num_steps = tf.Variable(0, dtype=tf.int32) # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) else: self._snapshotter = None
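# --- Illustrative sketch (assumption), behavioural-cloning update -----------
# The learner above is assumed to apply a softmax cross-entropy update against
# the demonstrated actions. The toy network, batch shapes and `bc_step` helper
# below are hypothetical and only show the shape of such an update.
import sonnet as snt
import tensorflow as tf

network = snt.Sequential([snt.Flatten(), snt.nets.MLP([64, 5])])  # 5 actions.
network(tf.zeros([1, 3]))  # Create the variables once.
optimizer = snt.optimizers.Adam(1e-4)

def bc_step(observations: tf.Tensor, actions: tf.Tensor) -> tf.Tensor:
  """Runs one cross-entropy update and returns the scalar loss."""
  with tf.GradientTape() as tape:
    logits = network(observations)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=actions, logits=logits))
  gradients = tape.gradient(loss, network.trainable_variables)
  optimizer.apply(gradients, network.trainable_variables)
  return loss

# Example batch: 8 observations of size 3 with integer demonstration actions.
loss = bc_step(tf.random.normal([8, 3]), tf.constant([0, 1, 2, 3, 4, 0, 1, 2]))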
def __init__( self, network: discrete_networks.DiscreteFilteredQNetwork, discount: float, importance_sampling_exponent: float, learning_rate: float, target_update_period: int, dataset: tf.data.Dataset, huber_loss_parameter: float = 1., replay_client: reverb.TFClient = None, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = False, ): """Initializes the learner. Args: network: BCQ network discount: discount to use for TD updates. importance_sampling_exponent: power to which importance weights are raised before normalizing. learning_rate: learning rate for the q-network update. target_update_period: number of learner steps to perform before updating the target networks. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). huber_loss_parameter: Quadratic-linear boundary for Huber loss. replay_client: client to replay to allow for updating priorities. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. """ # Internalise agent components (replay buffer, networks, optimizer). # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._network = network self._q_network = network.q_network self._target_q_network = copy.deepcopy(network.q_network) self._optimizer = snt.optimizers.Adam(learning_rate) self._replay_client = replay_client # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period self._importance_sampling_exponent = importance_sampling_exponent self._huber_loss_parameter = huber_loss_parameter # Learner state. self._variables = [self._network.trainable_variables] self._num_steps = tf.Variable(0, dtype=tf.int32) # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner', save_data=False) # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) else: self._snapshotter = None
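# --- Illustrative sketch (assumption about DiscreteFilteredQNetwork) --------
# Discrete BCQ is assumed to act only over actions whose behaviour-cloning
# probability is close to that of the most likely action; Q-values of the
# remaining actions are masked out before the argmax. The threshold `tau` and
# the toy tensors are illustrative.
import tensorflow as tf

def filtered_argmax(q_values: tf.Tensor, bc_logits: tf.Tensor,
                    tau: float = 0.3) -> tf.Tensor:
  """Greedy action over BC-filtered Q-values for a batch of states."""
  bc_probs = tf.nn.softmax(bc_logits, axis=-1)
  ratio = bc_probs / tf.reduce_max(bc_probs, axis=-1, keepdims=True)
  eligible = ratio >= tau
  masked_q = tf.where(eligible, q_values, tf.fill(tf.shape(q_values), -1e9))
  return tf.argmax(masked_q, axis=-1)

actions = filtered_argmax(
    q_values=tf.constant([[1.0, 2.0, 3.0]]),
    bc_logits=tf.constant([[2.0, 0.1, -3.0]]))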
def __init__(self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, target_network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, demonstration_dataset: tf.data.Dataset, demonstration_ratio: float, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, epsilon: float = 0.01, learning_rate: float = 1e-3, log_to_bigtable: bool = False, log_name: str = 'agent', checkpoint: bool = True, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0): extra_spec = { 'core_state': network.initial_state(1), } # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) self._server = reverb.Server([replay_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Component to add things into replay. sequence_kwargs = dict( period=replay_period, sequence_length=sequence_length, ) adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs) # The dataset object to learn from. dataset = datasets.make_reverb_dataset(server_address=address, sequence_length=sequence_length) # Combine with demonstration dataset. transition = functools.partial(_sequence_from_episode, extra_spec=extra_spec, **sequence_kwargs) dataset_demos = demonstration_dataset.map(transition) dataset = tf.data.experimental.sample_from_datasets( [dataset, dataset_demos], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch. dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, dataset=dataset, reverb_client=reverb.TFClient(address), counter=counter, logger=logger, sequence_length=sequence_length, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=False, ) self._checkpointer = tf2_savers.Checkpointer( subdirectory='r2d2_learner', time_delta_minutes=60, objects_to_save=learner.state, enable_checkpointing=checkpoint, ) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) policy_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), ]) actor = actors.RecurrentActor(policy_network, adder) observations_per_step = (float(replay_period * batch_size) / samples_per_insert) super().__init__(actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)
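# --- Illustrative sketch: mixing replay and demonstration data --------------
# Standalone example of the `sample_from_datasets` call above: elements are
# drawn from the replay dataset with probability (1 - demonstration_ratio) and
# from the demonstration dataset otherwise. The toy datasets are made up.
import tensorflow as tf

demonstration_ratio = 0.25
replay = tf.data.Dataset.from_tensor_slices(tf.zeros([100]))
demos = tf.data.Dataset.from_tensor_slices(tf.ones([100]))
mixed = tf.data.experimental.sample_from_datasets(
    [replay, demos], [1 - demonstration_ratio, demonstration_ratio])
# Roughly a quarter of the sampled elements come from the demonstrations.
fraction_demos = sum(mixed.take(200).as_numpy_iterator()) / 200.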
def __init__( self, policy_network: snt.Module, critic_network: snt.Module, target_policy_network: snt.Module, target_critic_network: snt.Module, discount: float, target_update_period: int, dataset_iterator: Iterator[reverb.ReplaySample], observation_network: types.TensorTransformation = lambda x: x, target_observation_network: types.TensorTransformation = lambda x: x, policy_optimizer: snt.Optimizer = None, critic_optimizer: snt.Optimizer = None, clipping: bool = True, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, ): """Initializes the learner. Args: policy_network: the online (optimized) policy. critic_network: the online critic. target_policy_network: the target policy (which lags behind the online policy). target_critic_network: the target critic. discount: discount to use for TD updates. target_update_period: number of learner steps to perform before updating the target networks. dataset_iterator: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). observation_network: an optional online network to process observations before the policy and the critic. target_observation_network: the target observation network. policy_optimizer: the optimizer to be applied to the DPG (policy) loss. critic_optimizer: the optimizer to be applied to the distributional Bellman loss. clipping: whether to clip gradients by global norm. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network # Make sure observation networks are snt.Module's so they have variables. self._observation_network = tf2_utils.to_sonnet_module(observation_network) self._target_observation_network = tf2_utils.to_sonnet_module( target_observation_network) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner') # Other learner parameters. self._discount = discount self._clipping = clipping # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_update_period = target_update_period # Batch dataset and create iterator. self._iterator = dataset_iterator # Create optimizers if they aren't given. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) # Expose the variables. policy_network_to_expose = snt.Sequential( [self._target_observation_network, self._target_policy_network]) self._variables = { 'critic': self._target_critic_network.variables, 'policy': policy_network_to_expose.variables, } # Create a checkpointer and snapshotter objects. 
self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( subdirectory='d4pg_learner', objects_to_save={ 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'observation': self._observation_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'target_observation': self._target_observation_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'num_steps': self._num_steps, }) critic_mean = snt.Sequential( [self._critic_network, acme_nets.StochasticMeanHead()]) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={ 'policy': self._policy_network, 'critic': critic_mean, }) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
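# --- Illustrative sketch: how the `_timestamp` bookkeeping is assumed to work
# Typical pattern inside a learner's step(): the first step only records the
# current time, so the wait for actors to fill the replay buffer is not
# counted as learning time. The class and method names here are illustrative.
import time

class _TimestampDemo:

  def __init__(self):
    self._timestamp = None

  def step_walltime(self) -> float:
    """Returns wall time elapsed since the previous step (0 on the first)."""
    timestamp = time.time()
    elapsed = timestamp - self._timestamp if self._timestamp else 0.
    self._timestamp = timestamp
    return elapsed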
def __init__(self, policy_network: snt.Module, critic_network: snt.Module, target_critic_network: snt.Module, discount: float, target_update_period: int, dataset: tf.data.Dataset, critic_optimizer: snt.Optimizer = None, critic_lr: float = 1e-4, checkpoint_interval_minutes: float = 10.0, clipping: bool = True, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, init_observations: Any = None, distributional: bool = True, vmin: Optional[float] = None, vmax: Optional[float] = None, ): self._policy_network = policy_network self._critic_network = critic_network self._target_critic_network = target_critic_network self._discount = discount self._clipping = clipping self._init_observations = init_observations self._distributional = distributional self._vmin = vmin self._vmax = vmax if (self._distributional and (self._vmin is not None or self._vmax is not None)): logging.warning('vmin and vmax arguments to FQELearner are ignored when ' 'distributional=True. They should be provided when ' 'creating the critic network.') # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_update_period = target_update_period # Batch dataset and create iterator. self._iterator = iter(dataset) # Create optimizers if they aren't given. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(critic_lr) # Expose the variables. self._variables = { 'critic': self._target_critic_network.variables, } if distributional: critic_mean = snt.Sequential( [self._critic_network, acme_nets.StochasticMeanHead()]) else: # We remove the trailing dimension to keep the same output dimension # as the existing FQE based on D4PG, i.e.: (batch_size,). critic_mean = snt.Sequential( [self._critic_network, lambda t: tf.squeeze(t, -1)]) self._critic_mean = critic_mean # Create a checkpointer object. self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( objects_to_save=self.state, time_delta_minutes=checkpoint_interval_minutes, checkpoint_ttl_seconds=_CHECKPOINT_TTL) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={ 'critic': critic_mean, }, time_delta_minutes=60.)
def __init__( self, policy_network: snt.Module, critic_network: snt.Module, target_policy_network: snt.Module, target_critic_network: snt.Module, discount: float, target_update_period: int, dataset_iterator: Iterator[reverb.ReplaySample], prior_network: Optional[snt.Module] = None, target_prior_network: Optional[snt.Module] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, prior_optimizer: Optional[snt.Optimizer] = None, distillation_cost: Optional[float] = 1e-3, entropy_regularizer_cost: Optional[float] = 1e-3, num_action_samples: int = 10, lambda_: float = 1.0, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): """Initializes the learner. Args: policy_network: the online (optimized) policy. critic_network: the online critic. target_policy_network: the target policy (which lags behind the online policy). target_critic_network: the target critic. discount: discount to use for TD updates. target_update_period: number of learner steps to perform before updating the target networks. dataset_iterator: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation). prior_network: the online (optimized) prior. target_prior_network: the target prior (which lags behind the online prior). policy_optimizer: the optimizer to be applied to the SVG-0 (policy) loss. critic_optimizer: the optimizer to be applied to the distributional Bellman loss. prior_optimizer: the optimizer to be applied to the prior (distillation) loss. distillation_cost: a multiplier to be used when adding distillation against the prior to the losses. entropy_regularizer_cost: a multiplier used for per state sample based entropy added to the actor loss. num_action_samples: the number of action samples to use for estimating the value function and sample based entropy. lambda_: the `lambda` value to be used with retrace. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network self._prior_network = prior_network self._target_prior_network = target_prior_network self._lambda = lambda_ self._num_action_samples = num_action_samples self._distillation_cost = distillation_cost self._entropy_regularizer_cost = entropy_regularizer_cost # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner') # Other learner parameters. self._discount = discount # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_update_period = target_update_period # Batch dataset and create iterator. self._iterator = dataset_iterator # Create optimizers if they aren't given. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) # Expose the variables. 
self._variables = { 'critic': self._critic_network.variables, 'policy': self._policy_network.variables, } if self._prior_network is not None: self._variables['prior'] = self._prior_network.variables # Create a checkpointer and snapshotter objects. self._checkpointer = None self._snapshotter = None if checkpoint: objects_to_save = { 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'num_steps': self._num_steps, } if self._prior_network is not None: objects_to_save['prior'] = self._prior_network objects_to_save['target_prior'] = self._target_prior_network objects_to_save['prior_optimizer'] = self._prior_optimizer self._checkpointer = tf2_savers.Checkpointer( subdirectory='svg0_learner', objects_to_save=objects_to_save) objects_to_snapshot = { 'policy': self._policy_network, 'critic': self._critic_network, } if self._prior_network is not None: objects_to_snapshot['prior'] = self._prior_network self._snapshotter = tf2_savers.Snapshotter( objects_to_save=objects_to_snapshot) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__( self, reward_objectives: Sequence[RewardObjective], qvalue_objectives: Sequence[QValueObjective], policy_network: snt.Module, critic_network: snt.Module, target_policy_network: snt.Module, target_critic_network: snt.Module, discount: float, num_samples: int, target_policy_update_period: int, target_critic_update_period: int, dataset: tf.data.Dataset, observation_network: types.TensorTransformation = tf.identity, target_observation_network: types.TensorTransformation = tf.identity, policy_loss_module: Optional[losses.MultiObjectiveMPO] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, dual_optimizer: Optional[snt.Optimizer] = None, clipping: bool = True, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network # Make sure observation networks are snt.Module's so they have variables. self._observation_network = tf2_utils.to_sonnet_module( observation_network) self._target_observation_network = tf2_utils.to_sonnet_module( target_observation_network) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner') # Other learner parameters. self._discount = discount self._num_samples = num_samples self._clipping = clipping # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_policy_update_period = target_policy_update_period self._target_critic_update_period = target_critic_update_period # Batch dataset and create iterator. # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # Store objectives self._reward_objectives = reward_objectives self._qvalue_objectives = qvalue_objectives if self._qvalue_objectives is None: self._qvalue_objectives = [] self._num_critic_heads = len(self._reward_objectives) # C self._objective_names = ([x.name for x in self._reward_objectives] + [x.name for x in self._qvalue_objectives]) self._policy_loss_module = policy_loss_module or losses.MultiObjectiveMPO( epsilons=[ losses.KLConstraint(name, _DEFAULT_EPSILON) for name in self._objective_names ], epsilon_mean=_DEFAULT_EPSILON_MEAN, epsilon_stddev=_DEFAULT_EPSILON_STDDEV, init_log_temperature=_DEFAULT_INIT_LOG_TEMPERATURE, init_log_alpha_mean=_DEFAULT_INIT_LOG_ALPHA_MEAN, init_log_alpha_stddev=_DEFAULT_INIT_LOG_ALPHA_STDDEV) # Check that ordering of objectives matches the policy_loss_module's if self._objective_names != list( self._policy_loss_module.objective_names): raise ValueError("Agent's ordering of objectives doesn't match " "the policy loss module's ordering of epsilons.") # Create the optimizers. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) # Expose the variables. policy_network_to_expose = snt.Sequential( [self._target_observation_network, self._target_policy_network]) self._variables = { 'critic': self._target_critic_network.variables, 'policy': policy_network_to_expose.variables, } # Create a checkpointer and snapshotter object. 
self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( subdirectory='dmpo_learner', objects_to_save={ 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'observation': self._observation_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'target_observation': self._target_observation_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'dual_optimizer': self._dual_optimizer, 'policy_loss_module': self._policy_loss_module, 'num_steps': self._num_steps, }) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={ 'policy': snt.Sequential([ self._target_observation_network, self._target_policy_network ]), }) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp: float = None
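# --- Illustrative sketch: periodic target-network updates -------------------
# The learner above keeps separate update periods for the target policy and
# target critic. A minimal version of the copy it is assumed to perform each
# step, written against plain variable lists rather than the real networks.
import tensorflow as tf

def maybe_update_targets(num_steps, online_policy_vars, target_policy_vars,
                         online_critic_vars, target_critic_vars,
                         target_policy_update_period,
                         target_critic_update_period):
  if tf.math.mod(num_steps, target_policy_update_period) == 0:
    for src, dest in zip(online_policy_vars, target_policy_vars):
      dest.assign(src)
  if tf.math.mod(num_steps, target_critic_update_period) == 0:
    for src, dest in zip(online_critic_vars, target_critic_vars):
      dest.assign(src)
  num_steps.assign_add(1)

steps = tf.Variable(0, dtype=tf.int32)
maybe_update_targets(steps, [tf.Variable(1.)], [tf.Variable(0.)],
                     [tf.Variable(2.)], [tf.Variable(0.)], 100, 50)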
def __init__(self, policy_network: snt.Module, critic_network: snt.Module, dataset: tf.data.Dataset, discount: float, behavior_network: Optional[snt.Module] = None, cwp_network: Optional[snt.Module] = None, policy_optimizer: Optional[ snt.Optimizer] = snt.optimizers.Adam(1e-4), critic_optimizer: Optional[ snt.Optimizer] = snt.optimizers.Adam(1e-4), target_update_period: int = 100, policy_improvement_modes: str = 'exp', ratio_upper_bound: float = 20., beta: float = 1.0, cql_alpha: float = 0.0, translate_lse: float = 100., empirical_policy: Optional[dict] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint_subpath: str = '~/acme/'): """Initializes the learner. Args: policy_network: the online (optimized) policy. critic_network: the online critic. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). discount: discount to use for TD updates. behavior_network: optional behaviour policy network. cwp_network: optional critic-weighted policy (CWP) network. policy_optimizer: the optimizer to be applied to the policy loss. critic_optimizer: the optimizer to be applied to the critic (Bellman) loss. target_update_period: number of learner steps to perform before updating the target networks. policy_improvement_modes: one of 'exp', 'binary', 'all'; determines how the advantage is processed before weighting the policy loss. ratio_upper_bound: if policy_improvement_modes is 'exp', upper bound on the weight, i.e. min(exp(advantage / beta), ratio_upper_bound). beta: temperature used by the 'exp' policy improvement mode. cql_alpha: scale of the CQL regularizer; zero disables it. translate_lse: constant used to translate the log-sum-exp in the CQL term. empirical_policy: empirical behavioural policy; required when cql_alpha is non-zero. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint_subpath: directory used for checkpointing the learner. """ self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network # Create the target networks. self._target_policy_network = copy.deepcopy(policy_network) self._target_critic_network = copy.deepcopy(critic_network) self._critic_optimizer = critic_optimizer self._policy_optimizer = policy_optimizer # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period # crr specific assert policy_improvement_modes in [ 'exp', 'binary', 'all' ], 'Policy imp. mode must be one of {exp, binary, all}' self._policy_improvement_modes = policy_improvement_modes self._beta = beta self._ratio_upper_bound = ratio_upper_bound # cql specific self._alpha = tf.constant(cql_alpha, dtype=tf.float32) self._tr = tf.constant(translate_lse, dtype=tf.float32) if cql_alpha: assert empirical_policy is not None, 'Empirical behavioural policy must be specified with non-zero cql_alpha.' self._emp_policy = empirical_policy # Learner state. # Expose the variables. self._variables = { 'critic': self._target_critic_network.variables, 'policy': self._target_policy_network.variables, } # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._counter.increment(learner_steps=0) self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Create checkpointer and snapshotter objects. self._checkpointer = tf2_savers.Checkpointer( objects_to_save=self.state, time_delta_minutes=10., directory=checkpoint_subpath, subdirectory='crr_learner') objects_to_save = { 'raw_policy': policy_network, 'critic': critic_network, } self._snapshotter = tf2_savers.Snapshotter( objects_to_save=objects_to_save, time_delta_minutes=10) # Timestamp to keep track of the wall time. self._walltime_timestamp = time.time()
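# --- Illustrative sketch: CRR advantage weighting ----------------------------
# How `policy_improvement_modes`, `beta` and `ratio_upper_bound` are assumed
# to combine (matching the description in the recurrent CRR learner further
# below): the advantage becomes a per-transition weight on the policy loss.
import tensorflow as tf

def crr_weights(advantage: tf.Tensor, mode: str = 'exp', beta: float = 1.0,
                ratio_upper_bound: float = 20.) -> tf.Tensor:
  if mode == 'exp':
    return tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
  elif mode == 'binary':
    return tf.cast(advantage > 0., advantage.dtype)
  elif mode == 'all':
    return tf.ones_like(advantage)
  raise ValueError(f'Unknown policy improvement mode: {mode}')

weights = crr_weights(tf.constant([-1.0, 0.5, 4.0]), mode='exp')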
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, counter: counting.Counter = None, logger: loggers.Logger = None, snapshot_dir: Optional[str] = None, n_step_horizon: int = 16, minibatch_size: int = 80, learning_rate: float = 2e-3, discount: float = 0.99, gae_lambda: float = 0.99, decay: float = 0.99, epsilon: float = 1e-5, entropy_cost: float = 0., baseline_cost: float = 1., max_abs_reward: Optional[float] = None, max_gradient_norm: Optional[float] = None, max_queue_size: int = 100000, verbose_level: Optional[int] = 0, ): # Internalize spec and replay buffer. self._environment_spec = environment_spec self._verbose_level = verbose_level self._minibatch_size = minibatch_size self._queue = queue.ReplayBuffer(max_queue_size=max_queue_size, batch_size=n_step_horizon) # Internalize network. self._network = network # Setup optimizer and learning rate scheduler. self._learning_rate = tf.Variable(learning_rate, trainable=False) self._lr_scheduler = schedules.ExponentialDecay( initial_learning_rate=learning_rate, decay_steps=8000, # TODO make a flag decay_rate=0.96 # TODO make a flag ) self._optimizer = snt.optimizers.RMSProp( learning_rate=self._learning_rate, decay=decay, epsilon=epsilon) # Hyperparameters. self._discount = discount self._gae_lambda = gae_lambda self._entropy_cost = entropy_cost self._baseline_cost = baseline_cost # Set up reward/gradient clipping. if max_abs_reward is None: max_abs_reward = np.inf if max_gradient_norm is None: max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. self._max_abs_reward = tf.convert_to_tensor(max_abs_reward) self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) if snapshot_dir is not None: self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': self._network}, directory=snapshot_dir, time_delta_minutes=60.) else: self._snapshotter = None # Logger. self._counter = counter or counting.Counter() self._logger = logger # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None self._pi_old = None
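# --- Illustrative sketch: gradient clipping by global norm ------------------
# How `self._max_gradient_norm` is assumed to be applied during the update:
# the gradient list is clipped to a maximum global norm before the optimizer
# step. The tiny variables and loss below are placeholders.
import sonnet as snt
import tensorflow as tf

variables = [tf.Variable([1.0, 2.0]), tf.Variable(3.0)]
optimizer = snt.optimizers.RMSProp(learning_rate=2e-3)
max_gradient_norm = tf.convert_to_tensor(40.)

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(tf.square(variables[0])) + tf.square(variables[1])
gradients = tape.gradient(loss, variables)
gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)
optimizer.apply(gradients, variables)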
def __init__(self, policy_network: snt.Module, critic_network: snt.Module, target_policy_network: snt.Module, target_critic_network: snt.Module, discount: float, target_update_period: int, dataset: tf.data.Dataset, observation_network: types.TensorTransformation = lambda x: x, target_observation_network: types. TensorTransformation = lambda x: x, policy_optimizer: snt.Optimizer = None, critic_optimizer: snt.Optimizer = None, clipping: bool = True, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, specified_path: str = None): # print('\033[94m I am sub_virtual acme d4pg learning\033[0m') """Initializes the learner. Args: policy_network: the online (optimized) policy. critic_network: the online critic. target_policy_network: the target policy (which lags behind the online policy). target_critic_network: the target critic. discount: discount to use for TD updates. target_update_period: number of learner steps to perform before updating the target networks. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). observation_network: an optional online network to process observations before the policy and the critic. target_observation_network: the target observation network. policy_optimizer: the optimizer to be applied to the DPG (policy) loss. critic_optimizer: the optimizer to be applied to the distributional Bellman loss. clipping: whether to clip gradients by global norm. counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. """ # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network # Make sure observation networks are snt.Module's so they have variables. self._observation_network = tf2_utils.to_sonnet_module( observation_network) self._target_observation_network = tf2_utils.to_sonnet_module( target_observation_network) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner') # Other learner parameters. self._discount = discount self._clipping = clipping # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_update_period = target_update_period # Batch dataset and create iterator. # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # Create optimizers if they aren't given. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) # Expose the variables. policy_network_to_expose = snt.Sequential( [self._target_observation_network, self._target_policy_network]) self._variables = { 'critic': self._target_critic_network.variables, 'policy': policy_network_to_expose.variables, } # Create a checkpointer and snapshotter objects. 
self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( subdirectory='d4pg_learner', objects_to_save={ 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'observation': self._observation_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'target_observation': self._target_observation_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'num_steps': self._num_steps, }) self.specified_path = specified_path if self.specified_path is not None: self._checkpointer._checkpoint.restore(self.specified_path) print('\033[92mspecified_path: ', str(self.specified_path), '\033[0m') critic_mean = snt.Sequential( [self._critic_network, acme_nets.StochasticMeanHead()]) self._snapshotter = tf2_savers.Snapshotter(objects_to_save={ 'policy': self._policy_network, 'critic': critic_mean, }) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__(self, network: snt.Module, discount: float, importance_sampling_exponent: float, learning_rate: float, target_update_period: int, cql_alpha: float, dataset: tf.data.Dataset, huber_loss_parameter: float = 1., empirical_policy: Optional[dict] = None, translate_lse: float = 100., replay_client: reverb.TFClient = None, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint_subpath: str = '~/acme/'): """Initializes the learner. Args: network: the online Q network (the one being optimized). discount: discount to use for TD updates. importance_sampling_exponent: power to which importance weights are raised before normalizing. learning_rate: learning rate for the q-network update. target_update_period: number of learner steps to perform before updating the target network (a copy of `network` created at construction time). cql_alpha: scale of the CQL regularizer; zero disables it. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). huber_loss_parameter: Quadratic-linear boundary for Huber loss. empirical_policy: empirical behavioural policy used by the CQL term. translate_lse: constant used to translate the log-sum-exp in the CQL term. replay_client: client to replay to allow for updating priorities. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint_subpath: directory in which to checkpoint the learner. """ # Internalise agent components (replay buffer, networks, optimizer). self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._network = network self._target_network = copy.deepcopy(network) self._optimizer = snt.optimizers.Adam(learning_rate) self._alpha = tf.constant(cql_alpha, dtype=tf.float32) self._tr = tf.constant(translate_lse, dtype=tf.float32) self._emp_policy = empirical_policy self._replay_client = replay_client # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period self._importance_sampling_exponent = importance_sampling_exponent self._huber_loss_parameter = huber_loss_parameter # Learner state. self._variables: List[List[tf.Tensor]] = [network.trainable_variables] # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._counter.increment(learner_steps=0) self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Create checkpointer and snapshotter objects. self._checkpointer = tf2_savers.Checkpointer( objects_to_save=self.state, time_delta_minutes=10., directory=checkpoint_subpath, subdirectory='cql_learner') self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=30.) # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__( self, network: snt.Module, target_network: snt.Module, discount: float, importance_sampling_exponent: float, learning_rate: float, target_update_period: int, dataset: tf.data.Dataset, huber_loss_parameter: float = 1., replay_client: Optional[Union[reverb.Client, reverb.TFClient]] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, save_directory: str = '~/acme', max_gradient_norm: Optional[float] = None, ): """Initializes the learner. Args: network: the online Q network (the one being optimized) target_network: the target Q critic (which lags behind the online net). discount: discount to use for TD updates. importance_sampling_exponent: power to which importance weights are raised before normalizing. learning_rate: learning rate for the q-network update. target_update_period: number of learner steps to perform before updating the target networks. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation). huber_loss_parameter: Quadratic-linear boundary for Huber loss. replay_client: client to replay to allow for updating priorities. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. save_directory: string indicating where the learner should save checkpoints and snapshots. max_gradient_norm: used for gradient clipping. """ # TODO(mwhoffman): stop allowing replay_client to be passed as a TFClient. # This is just here for backwards compatibility for agents which reuse this # Learner and still pass a TFClient instance. if isinstance(replay_client, reverb.TFClient): # TODO(b/170419518): open source pytype does not understand this # isinstance() check because it does not have a way of getting precise # type information for pip-installed packages. replay_client = reverb.Client(replay_client._server_address) # pytype: disable=attribute-error # Internalise agent components (replay buffer, networks, optimizer). # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._network = network self._target_network = target_network self._optimizer = snt.optimizers.Adam(learning_rate) self._replay_client = replay_client # Make sure to initialize the optimizer so that its variables (e.g. the Adam # moments) are included in the state returned by the learner (which can then # be checkpointed and restored). self._optimizer._initialize(network.trainable_variables) # pylint: disable=protected-access # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period self._importance_sampling_exponent = importance_sampling_exponent self._huber_loss_parameter = huber_loss_parameter if max_gradient_norm is None: max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) # Learner state. self._variables: List[List[tf.Tensor]] = [network.trainable_variables] self._num_steps = tf.Variable(0, dtype=tf.int32) # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, directory=save_directory, time_delta_minutes=60.)
else: self._snapshotter = None # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
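# --- Illustrative sketch: prioritized-replay importance weights -------------
# How `importance_sampling_exponent` is assumed to enter the loss: sampling
# probabilities are inverted, raised to the exponent and normalized by their
# maximum before weighting the per-transition TD errors. Toy values below.
import tensorflow as tf

def importance_weights(probs: tf.Tensor, exponent: float) -> tf.Tensor:
  weights = tf.pow(1.0 / probs, exponent)
  return weights / tf.reduce_max(weights)

weights = importance_weights(tf.constant([0.5, 0.3, 0.2]), exponent=0.2)
# The weighted loss would then be, e.g., tf.reduce_mean(weights * td_error**2).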
def __init__(self, policy_network: snt.Module, critic_network: snt.Module, f_network: snt.Module, discount: float, dataset: tf.data.Dataset, use_tilde_critic: bool, tilde_critic_network: snt.Module = None, tilde_critic_update_period: int = None, critic_optimizer_class: str = 'OAdam', critic_lr: float = 1e-4, critic_beta1: float = 0.5, critic_beta2: float = 0.9, f_optimizer_class: str = 'OAdam', f_lr: float = None, # Either f_lr or f_lr_multiplier must be # None. f_lr_multiplier: Optional[float] = 1.0, f_beta1: float = 0.5, f_beta2: float = 0.9, critic_regularizer: float = 0.0, f_regularizer: float = 1.0, # Ignored if use_tilde_critic = True critic_ortho_regularizer: float = 0.0, f_ortho_regularizer: float = 0.0, critic_l2_regularizer: float = 0.0, f_l2_regularizer: float = 0.0, checkpoint_interval_minutes: int = 10.0, clipping: bool = True, clipping_action: bool = True, bre_check_period: int = 0, # Bellman residual error check. bre_check_num_actions: int = 0, # Number of sampled actions. dev_dataset: tf.data.Dataset = None, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True): self._policy_network = policy_network self._critic_network = critic_network self._f_network = f_network self._discount = discount self._clipping = clipping self._clipping_action = clipping_action self._bre_check_period = bre_check_period self._bre_check_num_actions = bre_check_num_actions # Development dataset for hyper-parameter selection. self._dev_dataset = dev_dataset self._dev_actions_dataset = self._sample_actions() # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Necessary to track when to update tilde critic networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._use_tilde_critic = use_tilde_critic self._tilde_critic_network = tilde_critic_network self._tilde_critic_update_period = tilde_critic_update_period if use_tilde_critic and tilde_critic_update_period is None: raise ValueError('tilde_critic_update_period must be provided if ' 'use_tilde_critic is True.') # Batch dataset and create iterator. self._iterator = iter(dataset) # Create optimizers if they aren't given. self._critic_optimizer = _optimizer_class(critic_optimizer_class)( critic_lr, beta1=critic_beta1, beta2=critic_beta2) if f_lr is not None: if f_lr_multiplier is not None: raise ValueError(f'Either f_lr ({f_lr}) or f_lr_multiplier ' f'({f_lr_multiplier}) must be None.') else: f_lr = f_lr_multiplier * critic_lr # Prevent unreasonable value in hyper-param search. f_lr = max(min(f_lr, 1e-2), critic_lr) self._f_optimizer = _optimizer_class(f_optimizer_class)( f_lr, beta1=f_beta1, beta2=f_beta2) # Regularization on network values. self._critic_regularizer = critic_regularizer self._f_regularizer = f_regularizer # Orthogonal regularization strength. self._critic_ortho_regularizer = critic_ortho_regularizer self._f_ortho_regularizer = f_ortho_regularizer # L2 regularization strength. self._critic_l2_regularizer = critic_l2_regularizer self._f_l2_regularizer = f_l2_regularizer # Expose the variables. self._variables = { 'critic': self._critic_network.variables, } # Create a checkpointer object. 
self._checkpointer = None self._snapshotter = None if checkpoint: objects_to_save = { 'counter': self._counter, 'critic': self._critic_network, 'f': self._f_network, 'tilde_critic': self._tilde_critic_network, 'critic_optimizer': self._critic_optimizer, 'f_optimizer': self._f_optimizer, 'num_steps': self._num_steps, } self._checkpointer = tf2_savers.Checkpointer( objects_to_save={k: v for k, v in objects_to_save.items() if v is not None}, time_delta_minutes=checkpoint_interval_minutes, checkpoint_ttl_seconds=_CHECKPOINT_TTL) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={ 'critic': self._critic_network, 'f': self._f_network, }, time_delta_minutes=60.)
def __init__(self, value_func: snt.Module, instrumental_feature: snt.Module, policy_net: snt.Module, discount: float, value_learning_rate: float, instrumental_learning_rate: float, value_l2_reg: float, instrumental_l2_reg: float, stage1_reg: float, stage2_reg: float, instrumental_iter: int, value_iter: int, dataset: tf.data.Dataset, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True): """Initializes the learner. Args: value_func: value function network. instrumental_feature: dual function network. policy_net: policy network. discount: global discount. value_learning_rate: learning rate for the value function update. instrumental_learning_rate: learning rate for the instrumental feature update. value_l2_reg: L2 regularization for the value feature. instrumental_l2_reg: L2 regularization for the instrumental feature. stage1_reg: ridge regularizer for the stage 1 regression. stage2_reg: ridge regularizer for the stage 2 regression. instrumental_iter: number of iterations for the instrumental net. value_iter: number of iterations for the value function. dataset: dataset to learn from. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. """ self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) self.stage1_reg = stage1_reg self.stage2_reg = stage2_reg self.instrumental_iter = instrumental_iter self.value_iter = value_iter self.discount = discount self.value_l2_reg = value_l2_reg self.instrumental_reg = instrumental_l2_reg # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self.value_func = value_func self.value_feature = value_func._feature self.instrumental_feature = instrumental_feature self.policy = policy_net self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate) self._instrumental_func_optimizer = snt.optimizers.Adam( instrumental_learning_rate) self._variables = [ value_func.trainable_variables, instrumental_feature.trainable_variables, ] self._num_steps = tf.Variable(0, dtype=tf.int32) self.data = None # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter(objects_to_save={ 'value_func': value_func, 'instrumental_feature': instrumental_feature, }, time_delta_minutes=60.) else: self._snapshotter = None
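# --- Illustrative sketch: a ridge-regression stage ---------------------------
# `stage1_reg` and `stage2_reg` above are ridge penalties on two least-squares
# stages; below is a closed-form solve of the assumed form (the exact scaling
# of the penalty is an assumption) with random placeholder features/targets.
import tensorflow as tf

def ridge_solve(features: tf.Tensor, targets: tf.Tensor,
                reg: float) -> tf.Tensor:
  """Weights minimizing ||features @ w - targets||^2 + reg * n * ||w||^2."""
  n = tf.cast(tf.shape(features)[0], features.dtype)
  dim = tf.shape(features)[1]
  a = tf.matmul(features, features, transpose_a=True) + reg * n * tf.eye(dim)
  b = tf.matmul(features, targets, transpose_a=True)
  return tf.linalg.solve(a, b)

stage1_weights = ridge_solve(
    tf.random.normal([32, 8]), tf.random.normal([32, 1]), reg=1e-3)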
def __init__( self, network: networks.IQNNetwork, target_network: snt.Module, discount: float, importance_sampling_exponent: float, learning_rate: float, target_update_period: int, dataset: tf.data.Dataset, huber_loss_parameter: float = 1., replay_client: Optional[reverb.TFClient] = None, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): """Initializes the learner. Args: network: the online Q network (the one being optimized) that outputs (q_values, q_logits, atoms). target_network: the target Q critic (which lags behind the online net). discount: discount to use for TD updates. importance_sampling_exponent: power to which importance weights are raised before normalizing. learning_rate: learning rate for the q-network update. target_update_period: number of learner steps to perform before updating the target networks. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation). huber_loss_parameter: Quadratic-linear boundary for Huber loss. replay_client: client to replay to allow for updating priorities. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner or not. """ # Internalise agent components (replay buffer, networks, optimizer). self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._network = network self._target_network = target_network self._optimizer = snt.optimizers.Adam(learning_rate) self._replay_client = replay_client # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period self._importance_sampling_exponent = importance_sampling_exponent self._huber_loss_parameter = huber_loss_parameter # Learner state. self._variables: List[List[tf.Tensor]] = [network.trainable_variables] self._num_steps = tf.Variable(0, dtype=tf.int32) # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Create a snapshotter object. if checkpoint: self._checkpointer = tf2_savers.Checkpointer( time_delta_minutes=5, objects_to_save={ 'network': self._network, 'target_network': self._target_network, 'optimizer': self._optimizer, 'num_steps': self._num_steps }) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) else: self._checkpointer = None self._snapshotter = None
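# --- Illustrative sketch: the quadratic-linear (Huber) boundary -------------
# `huber_loss_parameter` is the point `d` at which the element-wise loss
# switches from quadratic to linear. A minimal reference implementation:
import tensorflow as tf

def huber(td_errors: tf.Tensor, d: float = 1.) -> tf.Tensor:
  abs_error = tf.abs(td_errors)
  quadratic = tf.minimum(abs_error, d)
  linear = abs_error - quadratic
  return 0.5 * tf.square(quadratic) + d * linear

losses = huber(tf.constant([-2.0, 0.3, 5.0]), d=1.)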
def __init__( self, policy_network: snt.Module, critic_network: snt.Module, target_policy_network: snt.Module, target_critic_network: snt.Module, discount: float, num_samples: int, target_policy_update_period: int, target_critic_update_period: int, dataset: tf.data.Dataset, observation_network: types.TensorTransformation = tf.identity, target_observation_network: types.TensorTransformation = tf.identity, policy_loss_module: Optional[snt.Module] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, dual_optimizer: Optional[snt.Optimizer] = None, clipping: bool = True, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = True, ): # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network # Make sure observation networks are snt.Module's so they have variables. self._observation_network = tf2_utils.to_sonnet_module( observation_network) self._target_observation_network = tf2_utils.to_sonnet_module( target_observation_network) # General learner book-keeping and loggers. self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger('learner') # Other learner parameters. self._discount = discount self._num_samples = num_samples self._clipping = clipping # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) self._target_policy_update_period = target_policy_update_period self._target_critic_update_period = target_critic_update_period # Batch dataset and create iterator. # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._policy_loss_module = policy_loss_module or losses.MPO( epsilon=1e-1, epsilon_penalty=1e-3, epsilon_mean=1e-3, epsilon_stddev=1e-6, init_log_temperature=1., init_log_alpha_mean=1., init_log_alpha_stddev=10.) # Create the optimizers. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) self._dual_optimizer = dual_optimizer or snt.optimizers.Adam(1e-2) # Expose the variables. policy_network_to_expose = snt.Sequential( [self._target_observation_network, self._target_policy_network]) self._variables = { 'critic': self._target_critic_network.variables, 'policy': policy_network_to_expose.variables, } # Create a checkpointer and snapshotter object. self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( subdirectory='dmpo_learner', objects_to_save={ 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'observation': self._observation_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'target_observation': self._target_observation_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'dual_optimizer': self._dual_optimizer, 'policy_loss_module': self._policy_loss_module, 'num_steps': self._num_steps, }) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={ 'policy': snt.Sequential([ self._target_observation_network, self._target_policy_network ]), }) # Do not record timestamps until after the first learning step is done. 
# This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
def __init__(self, value_func: snt.Module, mixture_density: snt.Module, policy_net: snt.Module, discount: float, value_learning_rate: float, density_learning_rate: float, n_sampling: int, density_iter: int, dataset: tf.data.Dataset, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, checkpoint_interval_minutes: float = 10.0): """Initializes the learner. Args: value_func: value function network. mixture_density: mixture density function network. policy_net: policy network. discount: global discount. value_learning_rate: learning rate for the value function update. density_learning_rate: learning rate for the mixture_density update. n_sampling: number of samples generated in stage 2. density_iter: number of iterations for the mixture density function. dataset: dataset to learn from. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. checkpoint_interval_minutes: checkpoint interval in minutes. """ self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) self.density_iter = density_iter self.n_sampling = n_sampling self.discount = discount # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self.value_func = value_func self.mixture_density = mixture_density self.policy = policy_net self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate) self._mixture_density_optimizer = snt.optimizers.Adam( density_learning_rate) self._variables = [ value_func.trainable_variables, mixture_density.trainable_variables, ] self._num_steps = tf.Variable(0, dtype=tf.int32) self._mse = tf.keras.losses.MeanSquaredError() # Create a checkpointer object. self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( objects_to_save=self.state, time_delta_minutes=checkpoint_interval_minutes, checkpoint_ttl_seconds=_CHECKPOINT_TTL) self._snapshotter = tf2_savers.Snapshotter(objects_to_save={ 'value_func': value_func, 'mixture_density': mixture_density, }, time_delta_minutes=60.)
def __init__(self, policy_network: snt.RNNCore, critic_network: networks.CriticDeepRNN, target_policy_network: snt.RNNCore, target_critic_network: networks.CriticDeepRNN, dataset: tf.data.Dataset, accelerator_strategy: Optional[tf.distribute.Strategy] = None, behavior_network: Optional[snt.Module] = None, cwp_network: Optional[snt.Module] = None, policy_optimizer: Optional[snt.Optimizer] = None, critic_optimizer: Optional[snt.Optimizer] = None, discount: float = 0.99, target_update_period: int = 100, num_action_samples_td_learning: int = 1, num_action_samples_policy_weight: int = 4, baseline_reduce_function: str = 'mean', clipping: bool = True, policy_improvement_modes: str = 'exp', ratio_upper_bound: float = 20., beta: float = 1.0, counter: Optional[counting.Counter] = None, logger: Optional[loggers.Logger] = None, checkpoint: bool = False): """Initializes the learner. Args: policy_network: the online (optimized) policy. critic_network: the online critic. target_policy_network: the target policy (which lags behind the online policy). target_critic_network: the target critic. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_reverb_dataset` documentation). accelerator_strategy: the strategy used to distribute computation, whether on a single, or multiple, GPU or TPU; as supported by tf.distribute. behavior_network: The network to snapshot under `policy` name. If None, snapshots `policy_network` instead. cwp_network: CWP network to snapshot: samples actions from the policy and weighs them with the critic, then returns the action by sampling from the softmax distribution using critic values as logits. Used only for snapshotting, not training. policy_optimizer: the optimizer to be applied to the policy loss. critic_optimizer: the optimizer to be applied to the distributional Bellman loss. discount: discount to use for TD updates. target_update_period: number of learner steps to perform before updating the target networks. num_action_samples_td_learning: number of action samples to use to estimate expected value of the critic loss w.r.t. stochastic policy. num_action_samples_policy_weight: number of action samples to use to estimate the advantage function for the CRR weighting of the policy loss. baseline_reduce_function: one of 'mean', 'max', 'min'. Way of aggregating values from `num_action_samples` estimates of the value function. clipping: whether to clip gradients by global norm. policy_improvement_modes: one of 'exp', 'binary', 'all'. CRR mode which determines how the advantage function is processed before being multiplied by the policy loss. ratio_upper_bound: if policy_improvement_modes is 'exp', determines the upper bound of the weight (i.e. the weight is min(exp(advantage / beta), upper_bound) ). beta: if policy_improvement_modes is 'exp', determines the beta (see above). counter: counter object used to keep track of steps. logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. 
""" if accelerator_strategy is None: accelerator_strategy = snt.distribute.Replicator() self._accelerator_strategy = accelerator_strategy self._policy_improvement_modes = policy_improvement_modes self._ratio_upper_bound = ratio_upper_bound self._num_action_samples_td_learning = num_action_samples_td_learning self._num_action_samples_policy_weight = num_action_samples_policy_weight self._baseline_reduce_function = baseline_reduce_function self._beta = beta # When running on TPUs we have to know the amount of memory required (and # thus the sequence length) at the graph compilation stage. At the moment, # the only way to get it is to sample from the dataset, since the dataset # does not have any metadata, see b/160672927 to track this upcoming # feature. sample = next(dataset.as_numpy_iterator()) self._sequence_length = sample.action.shape[1] self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) self._discount = discount self._clipping = clipping self._target_update_period = target_update_period with self._accelerator_strategy.scope(): # Necessary to track when to update target networks. self._num_steps = tf.Variable(0, dtype=tf.int32) # (Maybe) distributing the dataset across multiple accelerators. distributed_dataset = self._accelerator_strategy.experimental_distribute_dataset( dataset) self._iterator = iter(distributed_dataset) # Create the optimizers. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam( 1e-4) self._policy_optimizer = policy_optimizer or snt.optimizers.Adam( 1e-4) # Store online and target networks. self._policy_network = policy_network self._critic_network = critic_network self._target_policy_network = target_policy_network self._target_critic_network = target_critic_network # Expose the variables. self._variables = { 'critic': self._target_critic_network.variables, 'policy': self._target_policy_network.variables, } # Create a checkpointer object. self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( objects_to_save={ 'counter': self._counter, 'policy': self._policy_network, 'critic': self._critic_network, 'target_policy': self._target_policy_network, 'target_critic': self._target_critic_network, 'policy_optimizer': self._policy_optimizer, 'critic_optimizer': self._critic_optimizer, 'num_steps': self._num_steps, }, time_delta_minutes=30.) raw_policy = snt.DeepRNN( [policy_network, networks.StochasticSamplingHead()]) critic_mean = networks.CriticDeepRNN( [critic_network, networks.StochasticMeanHead()]) objects_to_save = { 'raw_policy': raw_policy, 'critic': critic_mean, } if behavior_network is not None: objects_to_save['policy'] = behavior_network if cwp_network is not None: objects_to_save['cwp_policy'] = cwp_network self._snapshotter = tf2_savers.Snapshotter( objects_to_save=objects_to_save, time_delta_minutes=30) # Timestamp to keep track of the wall time. self._walltime_timestamp = time.time()
def __init__( self, network: snt.Module, target_network: snt.Module, discount: float, importance_sampling_exponent: float, learning_rate: float, target_update_period: int, dataset: tf.data.Dataset, huber_loss_parameter: float = 1., replay_client: reverb.TFClient = None, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, max_gradient_norm: float = None, ): """Initializes the learner. Args: network: the online Q network (the one being optimized) target_network: the target Q critic (which lags behind the online net). discount: discount to use for TD updates. importance_sampling_exponent: power to which importance weights are raised before normalizing. learning_rate: learning rate for the q-network update. target_update_period: number of learner steps to perform before updating the target networks. dataset: dataset to learn from, whether fixed or from a replay buffer (see `acme.datasets.reverb.make_dataset` documentation). huber_loss_parameter: Quadratic-linear boundary for Huber loss. replay_client: client to replay to allow for updating priorities. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. max_gradient_norm: used for gradient clipping. """ # Internalise agent components (replay buffer, networks, optimizer). # TODO(b/155086959): Fix type stubs and remove. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self._network = network self._target_network = target_network self._optimizer = snt.optimizers.Adam(learning_rate) self._replay_client = replay_client # Internalise the hyperparameters. self._discount = discount self._target_update_period = target_update_period self._importance_sampling_exponent = importance_sampling_exponent self._huber_loss_parameter = huber_loss_parameter if max_gradient_norm is None: max_gradient_norm = 1e10 # A very large number. Infinity results in NaNs. self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) # Learner state. self._variables: List[List[tf.Tensor]] = [network.trainable_variables] self._num_steps = tf.Variable(0, dtype=tf.int32) # Internalise logging/counting objects. self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) else: self._snapshotter = None # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. self._timestamp = None
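# --- Hedged usage sketch (not part of the source) ---
# Typical construction of the prioritized-replay Q-learner above. `DQNLearner` is a
# stand-in for the enclosing class name; `make_q_network` and `replay_dataset` are
# assumed helpers.
q_network = make_q_network()
learner = DQNLearner(
    network=q_network,
    target_network=copy.deepcopy(q_network),
    discount=0.99,
    importance_sampling_exponent=0.2,
    learning_rate=1e-3,
    target_update_period=100,
    dataset=replay_dataset,
    max_gradient_norm=40.)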
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, target_network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, demonstration_generator: Iterator, demonstration_ratio: float, model_directory: str, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, epsilon: float = 0.01, learning_rate: float = 1e-3, log_to_bigtable: bool = False, log_name: str = 'agent', checkpoint: bool = True, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, ): extra_spec = { 'core_state': network.initial_state(1), } # Replay table. # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # Demonstration table. demonstration_table = reverb.Table( name='demonstration_table', sampler=reverb.selectors.Prioritized(0.8), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) # Launch the server. self._server = reverb.Server([replay_table, demonstration_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Components to add things into the replay and demo tables. sequence_kwargs = dict( period=replay_period, sequence_length=sequence_length, ) adder = adders.SequenceAdder(client=reverb.Client(address), **sequence_kwargs) priority_function = {demonstration_table.name: lambda x: 1.} demo_adder = adders.SequenceAdder(client=reverb.Client(address), priority_fns=priority_function, **sequence_kwargs) # Play demonstrations and write them to the demo table, exhausting the generator. # TODO: MAX REPLAY SIZE _prev_action = 1 # this has to come from the environment spec _add_first = True # include this to make the datasets equivalent numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1)) for ts, action in demonstration_generator: if _add_first: demo_adder.add_first(ts) _add_first = False else: demo_adder.add(_prev_action, ts, extras=(numpy_state,)) _prev_action = action # Reset for a new episode. if ts.last(): _prev_action = None _add_first = True # Replay dataset. max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100 dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=adders.DEFAULT_PRIORITY_TABLE, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator=2, # memory perf improvement attempt https://github.com/deepmind/acme/issues/33 sequence_length=sequence_length, emit_timesteps=sequence_length is None) # Demonstration dataset. d_dataset = reverb.ReplayDataset.from_table_signature( server_address=address, table=demonstration_table.name, max_in_flight_samples_per_worker=max_in_flight_samples_per_worker, num_workers_per_iterator=2, sequence_length=sequence_length, emit_timesteps=sequence_length is None) dataset = tf.data.experimental.sample_from_datasets( [dataset, d_dataset], [1 - demonstration_ratio, demonstration_ratio]) # Batch and prefetch.
dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, dataset=dataset, reverb_client=reverb.TFClient(address), counter=counter, logger=logger, sequence_length=sequence_length, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=False, ) self._checkpointer = tf2_savers.Checkpointer( directory=model_directory, subdirectory='r2d2_learner_v1', time_delta_minutes=15, objects_to_save=learner.state, enable_checkpointing=checkpoint, ) self._snapshotter = tf2_savers.Snapshotter(objects_to_save=None, time_delta_minutes=15000., directory=model_directory) policy_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), ]) actor = actors.RecurrentActor(policy_network, adder) observations_per_step = (float(replay_period * batch_size) / samples_per_insert) super().__init__(actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)
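# --- Hedged usage sketch (not part of the source) ---
# The constructor above consumes `demonstration_generator` via
# `for ts, action in demonstration_generator:`, so the generator is expected to yield
# (timestep, action) pairs, episode by episode. `R2D3Agent` is a stand-in for the enclosing
# class name; `env_spec`, `make_recurrent_q_network` and `demo_generator` are assumed helpers.
network = make_recurrent_q_network(env_spec.actions.num_values)
agent = R2D3Agent(
    environment_spec=env_spec,
    network=network,
    target_network=copy.deepcopy(network),
    burn_in_length=20,
    trace_length=40,
    replay_period=40,
    demonstration_generator=demo_generator,
    demonstration_ratio=0.25,
    model_directory='/tmp/r2d3_checkpoints')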
def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.RNNCore, burn_in_length: int, trace_length: int, replay_period: int, counter: counting.Counter = None, logger: loggers.Logger = None, discount: float = 0.99, batch_size: int = 32, prefetch_size: int = tf.data.experimental.AUTOTUNE, target_update_period: int = 100, importance_sampling_exponent: float = 0.2, priority_exponent: float = 0.6, epsilon: float = 0.01, learning_rate: float = 1e-3, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: float = 32.0, store_lstm_state: bool = True, max_priority_weight: float = 0.9, checkpoint: bool = True, ): extra_spec = { 'core_state': network.initial_state(1), } # Remove batch dimensions. extra_spec = tf2_utils.squeeze_batch_dim(extra_spec) replay_table = reverb.Table( name=adders.DEFAULT_PRIORITY_TABLE, sampler=reverb.selectors.Prioritized(priority_exponent), remover=reverb.selectors.Fifo(), max_size=max_replay_size, rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), signature=adders.SequenceAdder.signature(environment_spec, extra_spec)) self._server = reverb.Server([replay_table], port=None) address = f'localhost:{self._server.port}' sequence_length = burn_in_length + trace_length + 1 # Component to add things into replay. adder = adders.SequenceAdder( client=reverb.Client(address), period=replay_period, sequence_length=sequence_length, ) # The dataset object to learn from. reverb_client = reverb.TFClient(address) dataset = datasets.make_reverb_dataset( client=reverb_client, environment_spec=environment_spec, batch_size=batch_size, prefetch_size=prefetch_size, extra_spec=extra_spec, sequence_length=sequence_length) target_network = copy.deepcopy(network) tf2_utils.create_variables(network, [environment_spec.observations]) tf2_utils.create_variables(target_network, [environment_spec.observations]) learner = learning.R2D2Learner( environment_spec=environment_spec, network=network, target_network=target_network, burn_in_length=burn_in_length, sequence_length=sequence_length, dataset=dataset, reverb_client=reverb_client, counter=counter, logger=logger, discount=discount, target_update_period=target_update_period, importance_sampling_exponent=importance_sampling_exponent, max_replay_size=max_replay_size, learning_rate=learning_rate, store_lstm_state=store_lstm_state, max_priority_weight=max_priority_weight, ) self._checkpointer = tf2_savers.Checkpointer( subdirectory='r2d2_learner', time_delta_minutes=60, objects_to_save=learner.state, enable_checkpointing=checkpoint, ) self._snapshotter = tf2_savers.Snapshotter( objects_to_save={'network': network}, time_delta_minutes=60.) policy_network = snt.DeepRNN([ network, lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(), ]) actor = actors.RecurrentActor(policy_network, adder) observations_per_step = ( float(replay_period * batch_size) / samples_per_insert) super().__init__( actor=actor, learner=learner, min_observations=replay_period * max(batch_size, min_replay_size), observations_per_step=observations_per_step)
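# --- Hedged usage sketch (not part of the source) ---
# The agent above wires replay, learner and actor together, so only a standard Acme
# environment loop is needed on top. `R2D2Agent` is a stand-in for the enclosing class
# name; `environment`, `env_spec` and `make_recurrent_q_network` are assumed helpers.
import acme
network = make_recurrent_q_network(env_spec.actions.num_values)
agent = R2D2Agent(
    environment_spec=env_spec,
    network=network,
    burn_in_length=20,
    trace_length=40,
    replay_period=40,
    batch_size=32)
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=100)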
def __init__(self, value_func: snt.Module, instrumental_feature: snt.Module, policy_net: snt.Module, discount: float, value_learning_rate: float, instrumental_learning_rate: float, value_reg: float, instrumental_reg: float, stage1_reg: float, stage2_reg: float, instrumental_iter: int, value_iter: int, dataset: tf.data.Dataset, d_tm1_weight: float = 1.0, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True, checkpoint_interval_minutes: float = 10.0): """Initializes the learner. Args: value_func: value function network. instrumental_feature: dual function network. policy_net: policy network. discount: global discount. value_learning_rate: learning rate for the value_func update. instrumental_learning_rate: learning rate for the instrumental_net update. value_reg: L2 regularizer for value net. instrumental_reg: L2 regularizer for instrumental net. stage1_reg: ridge regularizer for stage 1 regression. stage2_reg: ridge regularizer for stage 2 regression. instrumental_iter: number of iterations for the instrumental net update. value_iter: number of iterations for the value function update. dataset: dataset to learn from. d_tm1_weight: weights for terminal state transitions. Ignored in this variant. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. checkpoint_interval_minutes: checkpoint interval in minutes. """ self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) self.stage1_reg = stage1_reg self.stage2_reg = stage2_reg self.instrumental_iter = instrumental_iter self.value_iter = value_iter self.discount = discount self.value_reg = value_reg self.instrumental_reg = instrumental_reg del d_tm1_weight # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types self.value_func = value_func self.value_feature = value_func._feature self.instrumental_feature = instrumental_feature self.policy = policy_net self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate, beta1=0.5, beta2=0.9) self._instrumental_func_optimizer = snt.optimizers.Adam( instrumental_learning_rate, beta1=0.5, beta2=0.9) # Define additional variables. self.stage1_weight = tf.Variable( tf.zeros( (instrumental_feature.feature_dim(), value_func.feature_dim()), dtype=tf.float32)) self._num_steps = tf.Variable(0, dtype=tf.int32) self._variables = [ self.value_func.trainable_variables, self.instrumental_feature.trainable_variables, self.stage1_weight, ] # Create a checkpointer object. self._checkpointer = None self._snapshotter = None if checkpoint: self._checkpointer = tf2_savers.Checkpointer( objects_to_save=self.state, time_delta_minutes=checkpoint_interval_minutes, checkpoint_ttl_seconds=_CHECKPOINT_TTL) self._snapshotter = tf2_savers.Snapshotter(objects_to_save={ 'value_func': self.value_func, 'instrumental_feature': self.instrumental_feature, }, time_delta_minutes=60.)
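# --- Hedged usage sketch (not part of the source) ---
# Constructing the DFIV-style learner above. Note the interface it assumes: `value_func`
# must expose `_feature` and `feature_dim()`, and `instrumental_feature` must expose
# `feature_dim()`. `DFIVLearner` is a stand-in for the enclosing class name;
# `make_value_func`, `make_instrumental_feature`, `target_policy` and `offline_dataset`
# are assumed helpers.
learner = DFIVLearner(
    value_func=make_value_func(),
    instrumental_feature=make_instrumental_feature(),
    policy_net=target_policy,
    discount=0.99,
    value_learning_rate=1e-4,
    instrumental_learning_rate=1e-4,
    value_reg=1e-5,
    instrumental_reg=1e-5,
    stage1_reg=1e-5,
    stage2_reg=1e-5,
    instrumental_iter=20,
    value_iter=10,
    dataset=offline_dataset,
    checkpoint=False)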
def __init__(self, treatment_net: snt.Module, instrumental_net: snt.Module, policy_net: snt.Module, treatment_learning_rate: float, instrumental_learning_rate: float, policy_learning_rate: float, dataset: tf.data.Dataset, counter: counting.Counter = None, logger: loggers.Logger = None, checkpoint: bool = True): """Initializes the learner. Args: treatment_net: treatment network. instrumental_net: instrumental network. policy_net: policy network. treatment_learning_rate: learning rate for the treatment_net update. instrumental_learning_rate: learning rate for the instrumental_net update. policy_learning_rate: learning rate for the policy_net update. dataset: dataset to learn from. counter: Counter object for (potentially distributed) counting. logger: Logger object for writing logs to. checkpoint: boolean indicating whether to checkpoint the learner. """ self._counter = counter or counting.Counter() self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) # Get an iterator over the dataset. self._iterator = iter(dataset) # pytype: disable=wrong-arg-types # TODO(b/155086959): Fix type stubs and remove. self._treatment_net = treatment_net self._instrumental_net = instrumental_net self._policy_net = policy_net self._treatment_optimizer = snt.optimizers.Adam( treatment_learning_rate) self._instrumental_optimizer = snt.optimizers.Adam( instrumental_learning_rate) self._policy_optimizer = snt.optimizers.Adam(policy_learning_rate) self._variables = [ treatment_net.trainable_variables, instrumental_net.trainable_variables, policy_net.trainable_variables, ] self._num_steps = tf.Variable(0, dtype=tf.int32) # Create a snapshotter object. if checkpoint: self._snapshotter = tf2_savers.Snapshotter(objects_to_save={ 'treatment_net': treatment_net, 'instrumental_net': instrumental_net, 'policy_net': policy_net, }, time_delta_minutes=60.) else: self._snapshotter = None
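# --- Hedged usage sketch (not part of the source) ---
# Constructing the DeepIV-style learner above. `DeepIVLearner` is a stand-in for the
# enclosing class name; `make_treatment_net`, `make_instrumental_net`, `make_policy_net`
# and `offline_dataset` are assumed helpers.
learner = DeepIVLearner(
    treatment_net=make_treatment_net(),
    instrumental_net=make_instrumental_net(),
    policy_net=make_policy_net(),
    treatment_learning_rate=1e-4,
    instrumental_learning_rate=1e-4,
    policy_learning_rate=1e-4,
    dataset=offline_dataset)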