def tf_optimization(
    self, states, internals, actions, terminal, reward, next_states=None, next_internals=None
):
    """
    Creates the TensorFlow operations for performing an optimization update step based on
    the given input states and actions batch.

    Args:
        states: Dict of state tensors.
        internals: List of prior internal state tensors.
        actions: Dict of action tensors.
        terminal: Terminal boolean tensor.
        reward: Reward tensor.
        next_states: Dict of successor state tensors.
        next_internals: List of posterior internal state tensors.

    Returns:
        The optimization operation.
    """
    # Snapshot distribution parameters before the update
    parameters_before = OrderedDict()
    embedding = self.network.apply(x=states, internals=internals)
    for name, distribution in self.distributions.items():
        parameters_before[name] = distribution.parametrize(x=embedding)

    with tf.control_dependencies(control_inputs=util.flatten(xs=parameters_before)):
        optimized = super().tf_optimization(
            states=states, internals=internals, actions=actions, terminal=terminal,
            reward=reward, next_states=next_states, next_internals=next_internals
        )

    # Re-parametrize after the update and record KL divergence and entropy summaries
    with tf.control_dependencies(control_inputs=(optimized,)):
        summaries = list()
        embedding = self.network.apply(x=states, internals=internals)
        for name, distribution in self.distributions.items():
            parameters = distribution.parametrize(x=embedding)
            kl_divergence = distribution.kl_divergence(
                parameters1=parameters_before[name], parameters2=parameters
            )
            collapsed_size = util.product(xs=util.shape(kl_divergence)[1:])
            kl_divergence = tf.reshape(tensor=kl_divergence, shape=(-1, collapsed_size))
            kl_divergence = tf.reduce_mean(input_tensor=kl_divergence, axis=1)
            kl_divergence = self.add_summary(
                label='kl-divergence', name=(name + '-kldiv'), tensor=kl_divergence
            )
            summaries.append(kl_divergence)

            entropy = distribution.entropy(parameters=parameters)
            entropy = tf.reshape(tensor=entropy, shape=(-1, collapsed_size))
            entropy = tf.reduce_mean(input_tensor=entropy, axis=1)
            entropy = self.add_summary(
                label='entropy', name=(name + '-entropy'), tensor=entropy
            )
            summaries.append(entropy)

    with tf.control_dependencies(control_inputs=summaries):
        return util.no_operation()
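
# Illustrative sketch (not part of the library): the function above snapshots the
# distribution parameters before the update, then measures KL divergence and entropy
# afterwards. A minimal numpy analogue for a categorical distribution, with made-up
# probabilities standing in for the network outputs:
#
#     import numpy as np
#
#     probs_before = np.array([[0.7, 0.2, 0.1], [0.4, 0.4, 0.2]])  # pre-update
#     probs_after = np.array([[0.6, 0.3, 0.1], [0.5, 0.3, 0.2]])   # post-update
#
#     # KL(p_before || p_after), per batch instance
#     kl = np.sum(probs_before * np.log(probs_before / probs_after), axis=1)
#     # Entropy of the updated distribution, per batch instance
#     entropy = -np.sum(probs_after * np.log(probs_after), axis=1)
#
#     print(kl.mean(), entropy.mean())  # scalars as they would appear in summaries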
def tf_core_experience(self, states, internals, auxiliaries, actions, terminal, reward):
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))

    # Enqueue experience for early reward estimation
    any_overwritten, overwritten_values = self.estimator.enqueue(
        baseline=self.baseline_policy, states=states, internals=internals,
        auxiliaries=auxiliaries, actions=actions, terminal=terminal, reward=reward
    )

    # If terminal, store remaining values in memory
    def true_fn():
        reset_values = self.estimator.reset(baseline=self.baseline_policy)

        new_overwritten_values = OrderedDict()
        for name, value1, value2 in util.zip_items(overwritten_values, reset_values):
            if util.is_nested(name=name):
                new_overwritten_values[name] = OrderedDict()
                for inner_name, value1, value2 in util.zip_items(value1, value2):
                    new_overwritten_values[name][inner_name] = tf.concat(
                        values=(value1, value2), axis=0
                    )
            else:
                new_overwritten_values[name] = tf.concat(values=(value1, value2), axis=0)
        return new_overwritten_values

    def false_fn():
        return overwritten_values

    with tf.control_dependencies(control_inputs=util.flatten(xs=overwritten_values)):
        values = self.cond(pred=(terminal[-1] > zero), true_fn=true_fn, false_fn=false_fn)

    # If any, store overwritten values
    def store():
        return self.memory.enqueue(**values)

    terminal = values['terminal']
    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
        num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
    else:
        num_values = tf.dtypes.cast(
            x=tf.shape(input=terminal)[0], dtype=util.tf_dtype(dtype='long')
        )

    stored = self.cond(pred=(num_values > zero), true_fn=store, false_fn=util.no_operation)

    return stored
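
# Illustrative sketch (not part of the library): on a terminal timestep, the values
# overwritten during enqueue and the values flushed by the estimator reset are
# concatenated per name before being stored in memory. A plain-Python analogue with
# made-up flat arrays:
#
#     import numpy as np
#
#     overwritten = {'reward': np.array([1.0, 0.5]), 'terminal': np.array([0, 0])}
#     flushed = {'reward': np.array([0.0, 2.0]), 'terminal': np.array([0, 1])}
#
#     combined = {
#         name: np.concatenate((overwritten[name], flushed[name]), axis=0)
#         for name in overwritten
#     }
#     # combined['terminal'] -> [0 0 0 1]; enqueued only if it contains any values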
def tf_reset(self):
    # Constants
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    if not self.return_overwritten:
        # Reset buffer index
        assignment = self.buffer_index.assign(value=zero, read_value=False)

        # Return no-op
        with tf.control_dependencies(control_inputs=(assignment,)):
            return util.no_operation()

    # Overwritten buffer indices
    num_values = tf.minimum(x=self.buffer_index, y=capacity)
    indices = tf.range(start=(self.buffer_index - num_values), limit=self.buffer_index)
    indices = tf.math.mod(x=indices, y=capacity)

    # Get overwritten values
    values = OrderedDict()
    for name, buffer in self.buffers.items():
        if util.is_nested(name=name):
            values[name] = OrderedDict()
            for inner_name, inner_buffer in buffer.items():
                values[name][inner_name] = tf.gather(params=inner_buffer, indices=indices)
        else:
            values[name] = tf.gather(params=buffer, indices=indices)

    # Reset buffer index
    with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
        assignment = self.buffer_index.assign(value=zero, read_value=False)

    # Return overwritten values
    with tf.control_dependencies(control_inputs=(assignment,)):
        return util.fmap(function=util.identity_operation, xs=values)
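
# Illustrative sketch (not part of the library): reset gathers the values still held
# in the circular buffer, i.e. the last min(buffer_index, capacity) entries, with
# positions wrapped modulo capacity. Plain-Python analogue with assumed sizes:
#
#     import numpy as np
#
#     capacity, buffer_index = 4, 6  # 6 values enqueued into a 4-slot buffer
#     num_values = min(buffer_index, capacity)  # -> 4
#     indices = np.arange(buffer_index - num_values, buffer_index) % capacity
#     print(indices)  # [2 3 0 1]: oldest surviving entry first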
def tf_optimize(self, indices):
    # Baseline optimization
    if self.baseline_optimizer is not None:
        optimized = self.optimize_baseline(indices=indices)
        dependencies = (optimized,)
    else:
        dependencies = (indices,)

    # Reward estimation
    with tf.control_dependencies(control_inputs=dependencies):
        reward = self.memory.retrieve(indices=indices, values='reward')
        reward = self.estimator.complete(
            baseline=self.baseline_policy, memory=self.memory, indices=indices, reward=reward
        )
        reward = self.add_summary(
            label=('empirical-reward', 'rewards'), name='empirical-reward', tensor=reward
        )
        is_baseline_optimized = (
            self.separate_baseline_policy and self.baseline_optimizer is None and
            self.baseline_objective is None
        )
        reward = self.estimator.estimate(
            baseline=self.baseline_policy, memory=self.memory, indices=indices,
            reward=reward, is_baseline_optimized=is_baseline_optimized
        )
        reward = self.add_summary(
            label=('estimated-reward', 'rewards'), name='estimated-reward', tensor=reward
        )

    # Stop gradients of estimated rewards if separate baseline optimization
    if not is_baseline_optimized:
        reward = tf.stop_gradient(input=reward)

    # Retrieve states, internals and actions
    past_horizon = self.policy.past_horizon(is_optimization=True)
    if self.separate_baseline_policy and self.baseline_optimizer is None:
        assertion = tf.debugging.assert_equal(
            x=past_horizon, y=self.baseline_policy.past_horizon(is_optimization=True),
            message="Policy and baseline depend on a different number of previous states."
        )
    else:
        assertion = past_horizon

    with tf.control_dependencies(control_inputs=(assertion,)):
        # horizon change: see timestep-based batch sampling
        starts, lengths, states, internals = self.memory.predecessors(
            indices=indices, horizon=past_horizon, sequence_values='states',
            initial_values='internals'
        )
        Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)
        auxiliaries, actions = self.memory.retrieve(
            indices=indices, values=('auxiliaries', 'actions')
        )

    # Optimizer arguments
    independent = Module.update_tensor(
        name='independent', tensor=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool'))
    )

    variables = self.get_variables(only_trainable=True)

    arguments = dict(
        states=states, internals=internals, auxiliaries=auxiliaries, actions=actions,
        reward=reward
    )

    fn_loss = self.total_loss

    def fn_kl_divergence(states, internals, auxiliaries, actions, reward, other=None):
        kl_divergence = self.policy.kl_divergence(
            states=states, internals=internals, auxiliaries=auxiliaries, other=other
        )
        if self.baseline_optimizer is None and self.baseline_objective is not None:
            kl_divergence += self.baseline_policy.kl_divergence(
                states=states, internals=internals, auxiliaries=auxiliaries, other=other
            )
        return kl_divergence

    if self.global_model is None:
        global_variables = None
    else:
        global_variables = self.global_model.get_variables(only_trainable=True)

    kwargs = self.objective.optimizer_arguments(
        policy=self.policy, baseline=self.baseline_policy
    )
    if self.baseline_optimizer is None and self.baseline_objective is not None:
        util.deep_disjoint_update(
            target=kwargs,
            source=self.baseline_objective.optimizer_arguments(policy=self.baseline_policy)
        )

    dependencies = util.flatten(xs=arguments)

    # KL divergence before
    if self.is_summary_logged(
        label=('kl-divergence', 'action-kl-divergences', 'kl-divergences')
    ):
        with tf.control_dependencies(control_inputs=dependencies):
            kldiv_reference = self.policy.kldiv_reference(
                states=states, internals=internals, auxiliaries=auxiliaries
            )
            dependencies = util.flatten(xs=kldiv_reference)

    # Optimization
    with tf.control_dependencies(control_inputs=dependencies):
        optimized = self.optimizer.minimize(
            variables=variables, arguments=arguments, fn_loss=fn_loss,
            fn_kl_divergence=fn_kl_divergence, global_variables=global_variables, **kwargs
        )

    with tf.control_dependencies(control_inputs=(optimized,)):
        # Loss summaries
        if self.is_summary_logged(label=('loss', 'objective-loss', 'losses')):
            objective_loss = self.objective.loss_per_instance(policy=self.policy, **arguments)
            objective_loss = tf.math.reduce_mean(input_tensor=objective_loss, axis=0)
            if self.is_summary_logged(label=('objective-loss', 'losses')):
                optimized = self.add_summary(
                    label=('objective-loss', 'losses'), name='objective-loss',
                    tensor=objective_loss, pass_tensors=optimized
                )
        if self.is_summary_logged(label=('loss', 'regularization-loss', 'losses')):
            regularization_loss = self.regularize(
                states=states, internals=internals, auxiliaries=auxiliaries
            )
            if self.is_summary_logged(label=('regularization-loss', 'losses')):
                optimized = self.add_summary(
                    label=('regularization-loss', 'losses'), name='regularization-loss',
                    tensor=regularization_loss, pass_tensors=optimized
                )
        if self.is_summary_logged(label=('loss', 'losses')):
            loss = objective_loss + regularization_loss
        if self.baseline_optimizer is None and self.baseline_objective is not None:
            if self.is_summary_logged(label=('loss', 'baseline-objective-loss', 'losses')):
                if self.baseline_objective is None:
                    baseline_objective_loss = self.objective.loss_per_instance(
                        policy=self.baseline_policy, **arguments
                    )
                else:
                    baseline_objective_loss = self.baseline_objective.loss_per_instance(
                        policy=self.baseline_policy, **arguments
                    )
                baseline_objective_loss = tf.math.reduce_mean(
                    input_tensor=baseline_objective_loss, axis=0
                )
                if self.is_summary_logged(label=('baseline-objective-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-objective-loss', 'losses'),
                        name='baseline-objective-loss', tensor=baseline_objective_loss,
                        pass_tensors=optimized
                    )
            if self.is_summary_logged(
                label=('loss', 'baseline-regularization-loss', 'losses')
            ):
                baseline_regularization_loss = self.baseline_policy.regularize()
                if self.is_summary_logged(label=('baseline-regularization-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-regularization-loss', 'losses'),
                        name='baseline-regularization-loss',
                        tensor=baseline_regularization_loss, pass_tensors=optimized
                    )
            if self.is_summary_logged(label=('loss', 'baseline-loss', 'losses')):
                baseline_loss = baseline_objective_loss + baseline_regularization_loss
                if self.is_summary_logged(label=('baseline-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-loss', 'losses'), name='baseline-loss',
                        tensor=baseline_loss, pass_tensors=optimized
                    )
            if self.is_summary_logged(label=('loss', 'losses')):
                loss += self.baseline_loss_weight * baseline_loss
        if self.is_summary_logged(label=('loss', 'losses')):
            optimized = self.add_summary(
                label=('loss', 'losses'), name='loss', tensor=loss, pass_tensors=optimized
            )

        # Entropy summaries
        if self.is_summary_logged(label=('entropy', 'action-entropies', 'entropies')):
            entropies = self.policy.entropy(
                states=states, internals=internals, auxiliaries=auxiliaries,
                include_per_action=(len(self.actions_spec) > 1)
            )
            if self.is_summary_logged(label=('entropy', 'entropies')):
                if len(self.actions_spec) == 1:
                    optimized = self.add_summary(
                        label=('entropy', 'entropies'), name='entropy', tensor=entropies,
                        pass_tensors=optimized
                    )
                else:
                    optimized = self.add_summary(
                        label=('entropy', 'entropies'), name='entropy',
                        tensor=entropies['*'], pass_tensors=optimized
                    )
            if len(self.actions_spec) > 1 and \
                    self.is_summary_logged(label=('action-entropies', 'entropies')):
                for name in self.actions_spec:
                    optimized = self.add_summary(
                        label=('action-entropies', 'entropies'), name=(name + '-entropy'),
                        tensor=entropies[name], pass_tensors=optimized
                    )

        # KL divergence summaries
        if self.is_summary_logged(
            label=('kl-divergence', 'action-kl-divergences', 'kl-divergences')
        ):
            kl_divergences = self.policy.kl_divergence(
                states=states, internals=internals, auxiliaries=auxiliaries,
                other=kldiv_reference, include_per_action=(len(self.actions_spec) > 1)
            )
            if self.is_summary_logged(label=('kl-divergence', 'kl-divergences')):
                if len(self.actions_spec) == 1:
                    optimized = self.add_summary(
                        label=('kl-divergence', 'kl-divergences'), name='kl-divergence',
                        tensor=kl_divergences, pass_tensors=optimized
                    )
                else:
                    optimized = self.add_summary(
                        label=('kl-divergence', 'kl-divergences'), name='kl-divergence',
                        tensor=kl_divergences['*'], pass_tensors=optimized
                    )
            if len(self.actions_spec) > 1 and \
                    self.is_summary_logged(label=('action-kl-divergences', 'kl-divergences')):
                for name in self.actions_spec:
                    optimized = self.add_summary(
                        label=('action-kl-divergences', 'kl-divergences'),
                        name=(name + '-kl-divergence'), tensor=kl_divergences[name],
                        pass_tensors=optimized
                    )

    Module.update_tensor(name='independent', tensor=independent)

    return optimized
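
# Illustrative sketch (not part of the library): when the baseline is trained by its
# own optimizer or objective, the estimated reward must act as a constant in the
# policy loss, which is what the tf.stop_gradient call above achieves. A minimal
# TF2-eager analogue with made-up tensors:
#
#     import tensorflow as tf
#
#     value = tf.Variable([1.0, 2.0])       # stands in for the baseline value output
#     log_prob = tf.Variable([-0.5, -0.7])  # stands in for policy log-probabilities
#     reward = tf.constant([1.5, 2.5])
#
#     with tf.GradientTape() as tape:
#         advantage = tf.stop_gradient(reward - value)  # treated as a constant
#         loss = -tf.reduce_mean(log_prob * advantage)  # policy-gradient surrogate
#
#     grads = tape.gradient(loss, [log_prob, value])
#     print(grads[1])  # None: no gradient flows into the baseline through this loss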
def tf_enqueue(self, **values):
    # Constants
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    # Get number of values
    for value in values.values():
        if not isinstance(value, dict):
            break
        elif len(value) > 0:
            value = next(iter(value.values()))
            break
    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
        num_values = tf.shape(input=value, out_type=util.tf_dtype(dtype='long'))[0]
    else:
        num_values = tf.dtypes.cast(
            x=tf.shape(input=value)[0], dtype=util.tf_dtype(dtype='long')
        )

    # Check whether instances fit into buffer
    assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

    if self.return_overwritten:
        # Overwritten buffer indices
        with tf.control_dependencies(control_inputs=(assertion,)):
            start = tf.maximum(x=self.buffer_index, y=capacity)
            limit = tf.maximum(x=(self.buffer_index + num_values), y=capacity)
            num_overwritten = limit - start
            indices = tf.range(start=start, limit=limit)
            indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        with tf.control_dependencies(control_inputs=(indices,)):
            overwritten_values = OrderedDict()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    overwritten_values[name] = OrderedDict()
                    for inner_name, inner_buffer in buffer.items():
                        overwritten_values[name][inner_name] = tf.gather(
                            params=inner_buffer, indices=indices
                        )
                else:
                    overwritten_values[name] = tf.gather(params=buffer, indices=indices)

    else:
        overwritten_values = (assertion,)

    # Buffer indices to (over)write
    with tf.control_dependencies(control_inputs=util.flatten(xs=overwritten_values)):
        indices = tf.range(start=self.buffer_index, limit=(self.buffer_index + num_values))
        indices = tf.math.mod(x=indices, y=capacity)
        indices = tf.expand_dims(input=indices, axis=1)

    # Write new values
    with tf.control_dependencies(control_inputs=(indices,)):
        assignments = list()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                for inner_name, inner_buffer in buffer.items():
                    assignment = inner_buffer.scatter_nd_update(
                        indices=indices, updates=values[name][inner_name]
                    )
                    assignments.append(assignment)
            else:
                assignment = buffer.scatter_nd_update(indices=indices, updates=values[name])
                assignments.append(assignment)

    # Increment buffer index
    with tf.control_dependencies(control_inputs=assignments):
        assignment = self.buffer_index.assign_add(delta=num_values, read_value=False)

    # Return overwritten values or no-op
    with tf.control_dependencies(control_inputs=(assignment,)):
        if self.return_overwritten:
            any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
            overwritten_values = util.fmap(
                function=util.identity_operation, xs=overwritten_values
            )
            return any_overwritten, overwritten_values
        else:
            return util.no_operation()
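
# Illustrative sketch (not part of the library): the overwritten slice is the part of
# [buffer_index, buffer_index + num_values) that lies beyond the first `capacity`
# writes, i.e. slots that already hold data. Plain-Python analogue with assumed sizes:
#
#     capacity, buffer_index, num_values = 4, 3, 3
#     start = max(buffer_index, capacity)               # -> 4
#     limit = max(buffer_index + num_values, capacity)  # -> 6
#     num_overwritten = limit - start                   # -> 2 values overwritten
#     overwritten = [i % capacity for i in range(start, limit)]  # [0, 1]
#     writes = [i % capacity for i in range(buffer_index, buffer_index + num_values)]
#     # writes -> [3, 0, 1]: the new values land on slots 3, 0 and 1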
def create_api_function(self, name, api_function):
    # Call API TensorFlow function
    Module.global_scope = list()
    Module.scope_stack = list()
    Module.while_counter = 0
    Module.cond_counter = 0
    Module.global_tensors = OrderedDict()
    Module.queryable_tensors = OrderedDict()

    if self.device is not None:
        self.device.__enter__()
    scope = tf.name_scope(name=name)
    Module.scope_stack.append(scope)
    scope.__enter__()

    results = api_function()
    self.output_tensors[name[name.index('.') + 1:]] = sorted(
        x.name[len(name) + 1: -9] for x in util.flatten(xs=results)
    )

    # Function-level identity operation for retrieval
    query_tensors = set()
    for scoped_name, tensor in Module.queryable_tensors.items():
        util.identity_operation(x=tensor, operation_name=(scoped_name + '-output'))
        assert scoped_name not in query_tensors
        query_tensors.add(scoped_name)
    self.query_tensors[name[name.index('.') + 1:]] = sorted(query_tensors)

    scope.__exit__(None, None, None)
    Module.scope_stack.pop()
    if self.device is not None:
        self.device.__exit__(None, None, None)

    assert len(Module.global_scope) == 0
    Module.global_scope = None
    assert len(Module.scope_stack) == 0
    Module.scope_stack = None
    Module.while_counter = None
    Module.cond_counter = None
    Module.global_tensors = None
    Module.queryable_tensors = None

    def fn(query=None, **kwargs):
        # Feed_dict dictionary
        feed_dict = dict()
        for key, arg in kwargs.items():
            if arg is None:
                continue
            elif isinstance(arg, dict):
                # Support single nesting (for states, internals, actions)
                for inner_key, inner_arg in arg.items():
                    feed_dict[util.join_scopes(self.name, inner_key) + '-input:0'] = inner_arg
            else:
                feed_dict[util.join_scopes(self.name, key) + '-input:0'] = arg
        if not all(isinstance(x, str) and x.endswith('-input:0') for x in feed_dict):
            raise TensorforceError.value(
                name=api_function, argument='inputs', value=list(feed_dict)
            )

        # Fetches value/tuple
        fetches = util.fmap(function=(lambda x: x.name), xs=results)
        if query is not None:
            # If additional tensors are to be fetched
            query = util.fmap(
                function=(lambda x: util.join_scopes(name, x) + '-output:0'), xs=query
            )
            if util.is_iterable(x=fetches):
                fetches = tuple(fetches) + (query,)
            else:
                fetches = (fetches, query)
        if not util.reduce_all(
            predicate=(lambda x: isinstance(x, str) and x.endswith('-output:0')), xs=fetches
        ):
            raise TensorforceError.value(
                name=api_function, argument='outputs', value=list(fetches)
            )

        # TensorFlow session call
        fetched = self.monitored_session.run(fetches=fetches, feed_dict=feed_dict)

        return fetched

    return fn
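
# Illustrative sketch (not part of the library): the returned closure maps keyword
# arguments to placeholder names of the form '<module>/<key>-input:0', with one level
# of nesting flattened for dict-valued arguments such as states. A plain-Python
# analogue of the feed_dict construction, assuming a module named 'agent':
#
#     def build_feed_dict(module_name, **kwargs):
#         feed_dict = {}
#         for key, arg in kwargs.items():
#             if isinstance(arg, dict):  # single nesting, e.g. states={'obs': ...}
#                 for inner_key, inner_arg in arg.items():
#                     feed_dict[module_name + '/' + inner_key + '-input:0'] = inner_arg
#             else:
#                 feed_dict[module_name + '/' + key + '-input:0'] = arg
#         return feed_dict
#
#     print(build_feed_dict('agent', states={'obs': [0.1]}, deterministic=True))
#     # {'agent/obs-input:0': [0.1], 'agent/deterministic-input:0': True}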
def tf_enqueue(self, states, internals, auxiliaries, actions, terminal, reward, baseline=None):
    # Constants and parameters
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
    horizon = self.horizon.value()
    discount = self.discount.value()

    assertions = list()
    # Check whether horizon at most capacity
    assertions.append(
        tf.debugging.assert_less_equal(
            x=horizon, y=capacity,
            message="Estimator capacity has to be at least the same as the estimation horizon."
        )
    )
    # Check whether at most one terminal
    assertions.append(
        tf.debugging.assert_less_equal(
            x=tf.math.count_nonzero(input=terminal, dtype=util.tf_dtype(dtype='long')),
            y=one, message="Timesteps contain more than one terminal."
        )
    )
    # Check whether, if any, last value is terminal
    assertions.append(
        tf.debugging.assert_equal(
            x=tf.reduce_any(input_tensor=tf.math.greater(x=terminal, y=zero)),
            y=tf.math.greater(x=terminal[-1], y=zero),
            message="Terminal is not the last timestep."
        )
    )

    # Get number of overwritten values
    with tf.control_dependencies(control_inputs=assertions):
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(
                x=tf.shape(input=terminal)[0], dtype=util.tf_dtype(dtype='long')
            )
        overwritten_start = tf.maximum(x=self.buffer_index, y=capacity)
        overwritten_limit = tf.maximum(x=(self.buffer_index + num_values), y=capacity)
        num_overwritten = overwritten_limit - overwritten_start

    def update_overwritten_rewards():
        # Get relevant buffer rewards
        buffer_limit = self.buffer_index + tf.minimum(
            x=(num_overwritten + horizon), y=capacity
        )
        buffer_indices = tf.range(start=self.buffer_index, limit=buffer_limit)
        buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
        rewards = tf.gather(params=self.buffers['reward'], indices=buffer_indices)

        # Get relevant values rewards
        values_limit = tf.maximum(x=(num_overwritten + horizon - capacity), y=zero)
        rewards = tf.concat(values=(rewards, reward[:values_limit]), axis=0)

        # Horizon baseline value
        if self.estimate_horizon == 'early':
            assert baseline is not None
            # Baseline estimate
            buffer_indices = buffer_indices[horizon + one:]
            _states = OrderedDict()
            for name, buffer in self.buffers['states'].items():
                state = tf.gather(params=buffer, indices=buffer_indices)
                _states[name] = tf.concat(
                    values=(state, states[name][:values_limit + one]), axis=0
                )
            _internals = OrderedDict()
            for name, buffer in self.buffers['internals'].items():
                internal = tf.gather(params=buffer, indices=buffer_indices)
                _internals[name] = tf.concat(
                    values=(internal, internals[name][:values_limit + one]), axis=0
                )
            _auxiliaries = OrderedDict()
            for name, buffer in self.buffers['auxiliaries'].items():
                auxiliary = tf.gather(params=buffer, indices=buffer_indices)
                _auxiliaries[name] = tf.concat(
                    values=(auxiliary, auxiliaries[name][:values_limit + one]), axis=0
                )

            # Dependency horizon
            # TODO: handle arbitrary non-optimization horizons!
            past_horizon = baseline.past_horizon(is_optimization=False)
            assertion = tf.debugging.assert_equal(
                x=past_horizon, y=zero,
                message="Temporary: baseline cannot depend on previous states."
            )
            with tf.control_dependencies(control_inputs=(assertion,)):
                some_state = next(iter(_states.values()))
                if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                    batch_size = tf.shape(
                        input=some_state, out_type=util.tf_dtype(dtype='long')
                    )[0]
                else:
                    batch_size = tf.dtypes.cast(
                        x=tf.shape(input=some_state)[0], dtype=util.tf_dtype(dtype='long')
                    )
                starts = tf.range(start=batch_size, dtype=util.tf_dtype(dtype='long'))
                lengths = tf.ones(shape=(batch_size,), dtype=util.tf_dtype(dtype='long'))
                Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)

            if self.estimate_actions:
                _actions = OrderedDict()
                for name, buffer in self.buffers['actions'].items():
                    action = tf.gather(params=buffer, indices=buffer_indices)
                    _actions[name] = tf.concat(
                        values=(action, actions[name][:values_limit]), axis=0
                    )
                horizon_estimate = baseline.actions_value(
                    states=_states, internals=_internals, auxiliaries=_auxiliaries,
                    actions=_actions
                )
            else:
                horizon_estimate = baseline.states_value(
                    states=_states, internals=_internals, auxiliaries=_auxiliaries
                )

        else:
            # Zero estimate
            horizon_estimate = tf.zeros(
                shape=(num_overwritten,), dtype=util.tf_dtype(dtype='float')
            )

        # Calculate discounted sum
        def cond(discounted_sum, horizon):
            return tf.math.greater_equal(x=horizon, y=zero)

        def body(discounted_sum, horizon):
            discounted_sum = discount * discounted_sum
            discounted_sum = discounted_sum + rewards[horizon: horizon + num_overwritten]
            return discounted_sum, horizon - one

        discounted_sum, _ = self.while_loop(
            cond=cond, body=body, loop_vars=(horizon_estimate, horizon), back_prop=False
        )

        assertions = [
            tf.debugging.assert_equal(
                x=tf.shape(input=horizon_estimate), y=tf.shape(input=discounted_sum),
                message="Estimation check."
            ),
            tf.debugging.assert_equal(
                x=tf.shape(input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                y=(horizon + num_overwritten), message="Estimation check."
            )
        ]

        # Overwrite buffer rewards
        with tf.control_dependencies(control_inputs=assertions):
            indices = tf.range(
                start=self.buffer_index, limit=(self.buffer_index + num_overwritten)
            )
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)
            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum
            )

        with tf.control_dependencies(control_inputs=(assignment,)):
            return util.no_operation()

    any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
    updated_rewards = self.cond(
        pred=any_overwritten, true_fn=update_overwritten_rewards, false_fn=util.no_operation
    )

    # Overwritten buffer indices
    with tf.control_dependencies(control_inputs=(updated_rewards,)):
        indices = tf.range(start=overwritten_start, limit=overwritten_limit)
        indices = tf.math.mod(x=indices, y=capacity)

    # Get overwritten values
    with tf.control_dependencies(control_inputs=(indices,)):
        overwritten_values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                overwritten_values[name] = OrderedDict()
                for inner_name, inner_buffer in buffer.items():
                    overwritten_values[name][inner_name] = tf.gather(
                        params=inner_buffer, indices=indices
                    )
            else:
                overwritten_values[name] = tf.gather(params=buffer, indices=indices)

    # Buffer indices to (over)write
    with tf.control_dependencies(control_inputs=util.flatten(xs=overwritten_values)):
        indices = tf.range(start=self.buffer_index, limit=(self.buffer_index + num_values))
        indices = tf.math.mod(x=indices, y=capacity)
        indices = tf.expand_dims(input=indices, axis=1)

    # Write new values
    with tf.control_dependencies(control_inputs=(indices,)):
        values = dict(
            states=states, internals=internals, auxiliaries=auxiliaries, actions=actions,
            terminal=terminal, reward=reward
        )
        assignments = list()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                for inner_name, inner_buffer in buffer.items():
                    assignment = inner_buffer.scatter_nd_update(
                        indices=indices, updates=values[name][inner_name]
                    )
                    assignments.append(assignment)
            else:
                assignment = buffer.scatter_nd_update(indices=indices, updates=values[name])
                assignments.append(assignment)

    # Increment buffer index
    with tf.control_dependencies(control_inputs=assignments):
        assignment = self.buffer_index.assign_add(delta=num_values, read_value=False)

    # Return overwritten values or no-op
    with tf.control_dependencies(control_inputs=(assignment,)):
        any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
        overwritten_values = util.fmap(
            function=util.identity_operation, xs=overwritten_values
        )
        return any_overwritten, overwritten_values
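
# Illustrative sketch (not part of the library): the while-loop above computes an
# n-step discounted return per position by sweeping the horizon backwards, starting
# from the baseline horizon estimate and repeatedly discounting and adding a shifted
# reward slice. numpy analogue with assumed values:
#
#     import numpy as np
#
#     discount, horizon, num_values = 0.9, 2, 3
#     rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])    # length horizon + num_values
#     horizon_estimate = np.array([10.0, 20.0, 30.0])  # baseline value at horizon
#
#     discounted_sum = horizon_estimate
#     for h in range(horizon, -1, -1):
#         discounted_sum = discount * discounted_sum + rewards[h: h + num_values]
#     # discounted_sum[i] == sum_{k=0..horizon} discount**k * rewards[i+k]
#     #                      + discount**(horizon+1) * horizon_estimate[i]
#     print(discounted_sum)  # [12.52 22.52 32.52]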
def tf_reset(self, baseline=None):
    # Constants and parameters
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
    horizon = self.horizon.value()
    discount = self.discount.value()

    # Overwritten buffer indices
    num_overwritten = tf.minimum(x=self.buffer_index, y=capacity)
    indices = tf.range(start=(self.buffer_index - num_overwritten), limit=self.buffer_index)
    indices = tf.math.mod(x=indices, y=capacity)

    # Get overwritten values
    values = OrderedDict()
    for name, buffer in self.buffers.items():
        if util.is_nested(name=name):
            values[name] = OrderedDict()
            for inner_name, inner_buffer in buffer.items():
                values[name][inner_name] = tf.gather(params=inner_buffer, indices=indices)
        else:
            values[name] = tf.gather(params=buffer, indices=indices)

    states = values['states']
    internals = values['internals']
    auxiliaries = values['auxiliaries']
    actions = values['actions']
    terminal = values['terminal']
    reward = values['reward']

    # Reset buffer index
    with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
        assignment = self.buffer_index.assign(value=zero, read_value=False)

    with tf.control_dependencies(control_inputs=(assignment,)):
        assertions = list()
        # Check whether exactly one terminal (, unless empty?)
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.count_nonzero(input=terminal, dtype=util.tf_dtype(dtype='long')),
                y=one, message="Timesteps do not contain exactly one terminal."
            )
        )
        # Check whether last value is terminal
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.greater(x=terminal[-1], y=zero),
                y=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool')),
                message="Terminal is not the last timestep."
            )
        )

    # Get number of values
    with tf.control_dependencies(control_inputs=assertions):
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(
                x=tf.shape(input=terminal)[0], dtype=util.tf_dtype(dtype='long')
            )

    # Horizon baseline value
    if self.estimate_horizon == 'early' and baseline is not None:
        # Dependency horizon
        # TODO: handle arbitrary non-optimization horizons!
        past_horizon = baseline.past_horizon(is_optimization=False)
        assertion = tf.debugging.assert_equal(
            x=past_horizon, y=zero,
            message="Temporary: baseline cannot depend on previous states."
        )

        # Baseline estimate
        horizon_start = num_values - tf.maximum(x=(num_values - horizon), y=one)
        _states = OrderedDict()
        for name, state in states.items():
            _states[name] = state[horizon_start:]
        _internals = OrderedDict()
        for name, internal in internals.items():
            _internals[name] = internal[horizon_start:]
        _auxiliaries = OrderedDict()
        for name, auxiliary in auxiliaries.items():
            _auxiliaries[name] = auxiliary[horizon_start:]

        with tf.control_dependencies(control_inputs=(assertion,)):
            batch_size = num_values - horizon_start
            starts = tf.range(start=batch_size, dtype=util.tf_dtype(dtype='long'))
            lengths = tf.ones(shape=(batch_size,), dtype=util.tf_dtype(dtype='long'))
            Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)

        if self.estimate_actions:
            _actions = OrderedDict()
            for name, action in actions.items():
                _actions[name] = action[horizon_start:]
            horizon_estimate = baseline.actions_value(
                states=_states, internals=_internals, auxiliaries=_auxiliaries,
                actions=_actions
            )
        else:
            horizon_estimate = baseline.states_value(
                states=_states, internals=_internals, auxiliaries=_auxiliaries
            )

        # Expand rewards beyond terminal
        terminal_zeros = tf.zeros(shape=(horizon,), dtype=util.tf_dtype(dtype='float'))
        if self.estimate_terminal:
            rewards = tf.concat(
                values=(reward[:-1], horizon_estimate[-1:], terminal_zeros), axis=0
            )
        else:
            with tf.control_dependencies(control_inputs=(assertion,)):
                last_reward = tf.where(
                    condition=tf.math.greater(x=terminal[-1], y=one),
                    x=horizon_estimate[-1], y=reward[-1]
                )
                rewards = tf.concat(
                    values=(reward[:-1], (last_reward,), terminal_zeros), axis=0
                )

        # Remove last if necessary
        horizon_end = tf.where(
            condition=tf.math.less_equal(x=num_values, y=horizon), x=zero,
            y=(num_values - horizon)
        )
        horizon_estimate = horizon_estimate[:horizon_end]

        # Expand missing estimates with zeros
        terminal_size = tf.minimum(x=horizon, y=num_values)
        terminal_estimate = tf.zeros(
            shape=(terminal_size,), dtype=util.tf_dtype(dtype='float')
        )
        horizon_estimate = tf.concat(values=(horizon_estimate, terminal_estimate), axis=0)

    else:
        # Expand rewards beyond terminal
        terminal_zeros = tf.zeros(shape=(horizon,), dtype=util.tf_dtype(dtype='float'))
        rewards = tf.concat(values=(reward, terminal_zeros), axis=0)

        # Zero estimate
        horizon_estimate = tf.zeros(shape=(num_values,), dtype=util.tf_dtype(dtype='float'))

    # Calculate discounted sum
    def cond(discounted_sum, horizon):
        return tf.math.greater_equal(x=horizon, y=zero)

    def body(discounted_sum, horizon):
        discounted_sum = discount * discounted_sum
        discounted_sum = discounted_sum + rewards[horizon: horizon + num_values]
        return discounted_sum, horizon - one

    values['reward'], _ = self.while_loop(
        cond=cond, body=body, loop_vars=(horizon_estimate, horizon), back_prop=False
    )

    return values
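
# Illustrative sketch (not part of the library): at episode reset the same backward
# sweep runs over the flushed rewards padded with `horizon` zeros, so positions near
# the terminal simply accumulate shorter sums. numpy analogue of the zero-estimate
# branch, with made-up episode rewards:
#
#     import numpy as np
#
#     discount, horizon = 0.9, 2
#     reward = np.array([1.0, 1.0, 1.0, 5.0])  # flushed episode rewards
#     num_values = len(reward)
#     rewards = np.concatenate((reward, np.zeros(horizon)))
#     returns = np.zeros(num_values)            # 'zero estimate' branch above
#     for h in range(horizon, -1, -1):
#         returns = discount * returns + rewards[h: h + num_values]
#     print(returns)  # [2.71 5.95 5.5 5.0]: n-step returns truncated at the terminal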