def tf_retrieve(self, indices, values):
    """Gather buffer entries at the given indices.

    Args:
        indices: Index tensor passed to `tf.gather` for every requested buffer.
        values: Either a single value name (str) or an iterable of value names.

    Returns:
        For a single name, the gathered value; otherwise a list with one gathered value per
        requested name, in the given order. Nested names (as reported by `util.is_nested`)
        yield an `OrderedDict` keyed by the inner names of `self.values_spec[name]`.
    """
    is_single_value = isinstance(values, str)
    names = [values] if is_single_value else list(values)

    # Gather each requested value from its buffer(s)
    retrieved = []
    for name in names:
        if util.is_nested(name=name):
            gathered = OrderedDict(
                (inner_name, tf.gather(params=self.buffers[name][inner_name], indices=indices))
                for inner_name in self.values_spec[name]
            )
        else:
            gathered = tf.gather(params=self.buffers[name], indices=indices)
        retrieved.append(gathered)

    # Unwrap if only a single value name was requested
    return retrieved[0] if is_single_value else retrieved
def tf_initialize(self):
    """Create the non-trainable buffer variables and the buffer index.

    One variable of shape `(capacity,) + spec['shape']` is created per entry of
    `self.values_spec` (per inner entry for nested names). The initializer is looked up in
    `self.initializers` by value name and defaults to 'zeros'.
    """
    super().tf_initialize()

    def create_buffer(value_name, value_spec):
        # One row per buffer slot, followed by the value's own shape
        shape = (self.capacity, ) + value_spec['shape']
        initializer = self.initializers.get(value_name, 'zeros')
        return self.add_variable(
            name=(value_name + '-buffer'), dtype=value_spec['type'], shape=shape,
            is_trainable=False, initializer=initializer
        )

    # Value buffers
    self.buffers = OrderedDict()
    for name, spec in self.values_spec.items():
        if not util.is_nested(name=name):
            self.buffers[name] = create_buffer(name, spec)
            continue
        nested_buffers = self.buffers[name] = OrderedDict()
        for inner_name, inner_spec in spec.items():
            nested_buffers[inner_name] = create_buffer(inner_name, inner_spec)

    # Buffer index (modulo capacity, next index to write to)
    self.buffer_index = self.add_variable(
        name='buffer-index', dtype='long', shape=(), is_trainable=False, initializer='zeros'
    )
def tf_initialize(self):
    """Create buffer variables plus the episode bookkeeping variables.

    Besides one non-trainable buffer per value (per inner value for nested names), this
    creates `buffer_index`, `terminal_indices` (capacity + 1 entries) and `episode_count`.
    The 'terminal' buffer starts with its last entry set to 1, matching the initial
    `terminal_indices[0] = capacity - 1`.
    """
    super().tf_initialize()

    # Value buffers
    self.buffers = OrderedDict()
    for name, spec in self.values_spec.items():
        if util.is_nested(name=name):
            nested_buffers = self.buffers[name] = OrderedDict()
            for inner_name, inner_spec in spec.items():
                nested_buffers[inner_name] = self.add_variable(
                    name=(inner_name + '-buffer'), dtype=inner_spec['type'],
                    shape=((self.capacity, ) + inner_spec['shape']), is_trainable=False
                )
            continue

        variable_args = dict(
            name=(name + '-buffer'), dtype=spec['type'],
            shape=((self.capacity, ) + spec['shape']), is_trainable=False
        )
        if name == 'terminal':
            # Terminal initialization has to agree with terminal_indices
            terminal_init = np.zeros(shape=(self.capacity, ), dtype=util.np_dtype(dtype='long'))
            terminal_init[-1] = 1
            variable_args['initializer'] = terminal_init
        self.buffers[name] = self.add_variable(**variable_args)

    # Buffer index (modulo capacity, next index to write to)
    self.buffer_index = self.add_variable(
        name='buffer-index', dtype='long', shape=(), is_trainable=False, initializer='zeros'
    )

    # Terminal indices
    # (oldest episode terminals first, initially the only terminal is last index)
    indices_shape = (self.capacity + 1, )
    indices_init = np.zeros(shape=indices_shape, dtype=util.np_dtype(dtype='long'))
    indices_init[0] = self.capacity - 1
    self.terminal_indices = self.add_variable(
        name='terminal-indices', dtype='long', shape=indices_shape, is_trainable=False,
        initializer=indices_init
    )

    # Episode count
    self.episode_count = self.add_variable(
        name='episode-count', dtype='long', shape=(), is_trainable=False, initializer='zeros'
    )
def true_fn():
    """Reset the estimator and append its values to the already overwritten ones.

    Closure: reads `self` and `overwritten_values` from the enclosing scope.

    Returns:
        An `OrderedDict` with the same (possibly nested) structure, where every tensor is the
        concatenation along axis 0 of the overwritten value and the reset value.
    """
    reset_values = self.estimator.reset(baseline=self.baseline_policy)

    combined = OrderedDict()
    for name, previous, current in util.zip_items(overwritten_values, reset_values):
        if util.is_nested(name=name):
            combined[name] = OrderedDict(
                (inner_name, tf.concat(values=(inner_previous, inner_current), axis=0))
                for inner_name, inner_previous, inner_current
                in util.zip_items(previous, current)
            )
        else:
            combined[name] = tf.concat(values=(previous, current), axis=0)
    return combined
def tf_reset(self):
    """Reset the buffer index to zero, optionally returning the buffered values.

    Returns:
        If `self.return_overwritten` is false, a no-op that depends on the index reset.
        Otherwise the (possibly nested) `OrderedDict` of values gathered from the
        `min(buffer_index, capacity)` positions preceding `buffer_index` (modulo capacity),
        each wrapped in an identity op that depends on the index reset.
    """
    # Constants
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    if self.return_overwritten:
        # Buffer positions to hand back
        num_stored = tf.minimum(x=self.buffer_index, y=capacity)
        stored_indices = tf.range(
            start=(self.buffer_index - num_stored), limit=self.buffer_index
        )
        stored_indices = tf.math.mod(x=stored_indices, y=capacity)

        # Read the values before the index is reset
        contents = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                contents[name] = OrderedDict(
                    (inner_name, tf.gather(params=inner_buffer, indices=stored_indices))
                    for inner_name, inner_buffer in buffer.items()
                )
            else:
                contents[name] = tf.gather(params=buffer, indices=stored_indices)

        # Reset only after every value has been read
        with tf.control_dependencies(control_inputs=util.flatten(xs=contents)):
            reset_index = self.buffer_index.assign(value=zero, read_value=False)

        with tf.control_dependencies(control_inputs=(reset_index, )):
            return util.fmap(function=util.identity_operation, xs=contents)

    # Nothing to return: just reset the buffer index
    reset_index = self.buffer_index.assign(value=zero, read_value=False)
    with tf.control_dependencies(control_inputs=(reset_index, )):
        return util.no_operation()
def tf_successors(self, indices, horizon, sequence_values=(), final_values=()):
    """Gather values along the successor sequences of the given buffer indices.

    Starting from each entry of `indices`, the buffer is followed forward (modulo capacity)
    for at most `horizon` steps; a sequence stops growing once the entry at its current
    position has a terminal value greater than zero.

    Args:
        indices: Start indices; expanded along axis 1, so a rank-1 tensor is expected.
        horizon: Passed as `maximum_iterations` to the while loop.
        sequence_values: Name or iterable of names to gather for every sequence position.
        final_values: Name or iterable of names to gather for the last position of each
            sequence only.

    Returns:
        - `(lengths, final_values)` if no sequence values were requested,
        - `(starts, lengths, sequence_values)` if no final values were requested,
        - `(starts, lengths, sequence_values, final_values)` otherwise.
        `starts`/`lengths` locate each sequence within the flattened sequence values. A
        value requested via a single string is returned unwrapped rather than as a list.

    Raises:
        TensorforceError: If both `sequence_values` and `final_values` are the empty tuple.
    """
    # Only the default empty tuples are rejected here; other empty iterables pass this check
    if sequence_values == () and final_values == ():
        raise TensorforceError.unexpected()

    # Normalize both arguments to lists, remembering whether a single name was given
    if isinstance(sequence_values, str):
        is_single_sequence_value = True
        sequence_values = [sequence_values]
    else:
        is_single_sequence_value = False
        sequence_values = list(sequence_values)
    if isinstance(final_values, str):
        is_single_final_value = True
        final_values = [final_values]
    else:
        is_single_final_value = False
        final_values = list(final_values)

    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    def body(lengths, successor_indices, mask):
        # Last column holds the most recently appended index of every sequence
        current_index = successor_indices[:, -1:]
        current_terminal = self.retrieve(indices=current_index, values='terminal')
        # The next position is valid only if the current one is valid and not terminal
        is_not_terminal = tf.math.logical_and(
            x=tf.math.logical_not(x=tf.math.greater(x=current_terminal, y=zero)),
            y=mask[:, -1:]
        )
        next_index = tf.math.mod(x=(current_index + one), y=capacity)
        # Append one column per iteration; invalid positions are masked out afterwards
        successor_indices = tf.concat(values=(successor_indices, next_index), axis=1)
        mask = tf.concat(values=(mask, is_not_terminal), axis=1)
        is_not_terminal = tf.squeeze(input=is_not_terminal, axis=1)
        zeros = tf.zeros_like(tensor=is_not_terminal, dtype=util.tf_dtype(dtype='long'))
        ones = tf.ones_like(tensor=is_not_terminal, dtype=util.tf_dtype(dtype='long'))
        # Sequence length grows by one wherever the appended position is valid
        lengths += tf.where(condition=is_not_terminal, x=ones, y=zeros)
        return lengths, successor_indices, mask

    # Every sequence initially consists of its start index only
    lengths = tf.ones_like(tensor=indices, dtype=util.tf_dtype(dtype='long'))
    successor_indices = tf.expand_dims(input=indices, axis=1)
    mask = tf.ones_like(tensor=successor_indices, dtype=util.tf_dtype(dtype='bool'))
    # Both dimensions left unknown, since the second one grows in every iteration
    shape = tf.TensorShape(dims=((None, None)))

    lengths, successor_indices, mask = self.while_loop(
        cond=util.tf_always_true, body=body, loop_vars=(lengths, successor_indices, mask),
        shape_invariants=(lengths.get_shape(), shape, shape), back_prop=False,
        maximum_iterations=horizon
    )

    # Flatten row by row and drop masked positions, yielding the concatenated sequences
    successor_indices = tf.reshape(tensor=successor_indices, shape=(-1,))
    mask = tf.reshape(tensor=mask, shape=(-1,))
    successor_indices = tf.boolean_mask(tensor=successor_indices, mask=mask, axis=0)

    # NOTE(review): a modulo result is compared against zero here, which presumably always
    # holds for a positive capacity -- confirm the intended consistency check.
    assertion = tf.compat.v1.debugging.assert_greater_equal(
        x=tf.math.mod(x=(self.buffer_index - one - successor_indices), y=capacity), y=zero
    )

    with tf.control_dependencies(control_inputs=(assertion,)):
        # Offsets of each sequence within the flattened successor indices
        starts = tf.math.cumsum(x=lengths, exclusive=True)
        ends = tf.math.cumsum(x=lengths) - one
        final_indices = tf.gather(params=successor_indices, indices=ends)

        # Replace each requested name in-place by its gathered value
        for n, name in enumerate(sequence_values):
            if util.is_nested(name=name):
                sequence_value = OrderedDict()
                for inner_name, spec in self.values_spec[name].items():
                    sequence_value[inner_name] = tf.gather(
                        params=self.buffers[name][inner_name], indices=successor_indices
                    )
            else:
                sequence_value = tf.gather(
                    params=self.buffers[name], indices=successor_indices
                )
            sequence_values[n] = sequence_value

        for n, name in enumerate(final_values):
            if util.is_nested(name=name):
                final_value = OrderedDict()
                for inner_name, spec in self.values_spec[name].items():
                    final_value[inner_name] = tf.gather(
                        params=self.buffers[name][inner_name], indices=final_indices
                    )
            else:
                final_value = tf.gather(
                    params=self.buffers[name], indices=final_indices
                )
            final_values[n] = final_value

    # NOTE: an earlier, commented-out implementation that concatenated values step by step
    # inside the while loop was kept here and has been removed as dead code.

    # Return shape depends on which kinds of values were requested
    if len(sequence_values) == 0:
        if is_single_final_value:
            final_values = final_values[0]
        return lengths, final_values

    elif len(final_values) == 0:
        if is_single_sequence_value:
            sequence_values = sequence_values[0]
        return starts, lengths, sequence_values

    else:
        if is_single_sequence_value:
            sequence_values = sequence_values[0]
        if is_single_final_value:
            final_values = final_values[0]
        return starts, lengths, sequence_values, final_values
def tf_enqueue(self, **values):
    """Write a batch of values into the buffers, starting at the current buffer index.

    Args:
        **values: One entry per buffer name. Nested names (as reported by `util.is_nested`)
            map to dicts of tensors, all other names directly to a tensor. The first
            dimension of the first tensor found determines the number of instances.

    Returns:
        If `self.return_overwritten` is set, a pair `(any_overwritten, overwritten_values)`,
        where `overwritten_values` mirrors the buffer structure and holds the entries gathered
        before the write; otherwise a no-op. Both depend on the buffer index increment.

    Raises:
        ValueError: If no tensor is given, i.e. `values` is empty or consists only of empty
            dicts, so the number of instances cannot be determined.
    """
    # Find the first tensor to read the batch size from. This is plain Python and creates
    # no graph ops, so failing here leaves the graph untouched.
    value = None
    for candidate in values.values():
        if not isinstance(candidate, dict):
            value = candidate
            break
        elif len(candidate) > 0:
            value = next(iter(candidate.values()))
            break
    if value is None:
        raise ValueError("Enqueue requires at least one value to infer the number of instances.")

    # Constants
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    # Get number of values
    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
        num_values = tf.shape(input=value, out_type=util.tf_dtype(dtype='long'))[0]
    else:
        num_values = tf.dtypes.cast(
            x=tf.shape(input=value)[0], dtype=util.tf_dtype(dtype='long')
        )

    # Check whether instances fit into buffer
    assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

    if self.return_overwritten:
        # Overwritten buffer indices: the part of the write range beyond capacity
        with tf.control_dependencies(control_inputs=(assertion, )):
            start = tf.maximum(x=self.buffer_index, y=capacity)
            limit = tf.maximum(x=(self.buffer_index + num_values), y=capacity)
            num_overwritten = limit - start
            indices = tf.range(start=start, limit=limit)
            indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        with tf.control_dependencies(control_inputs=(indices, )):
            overwritten_values = OrderedDict()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    overwritten_values[name] = OrderedDict()
                    for inner_name, inner_buffer in buffer.items():
                        overwritten_values[name][inner_name] = tf.gather(
                            params=inner_buffer, indices=indices
                        )
                else:
                    overwritten_values[name] = tf.gather(params=buffer, indices=indices)

    else:
        # Only the capacity check has to precede the write
        overwritten_values = (assertion, )

    # Buffer indices to (over)write
    with tf.control_dependencies(control_inputs=util.flatten(xs=overwritten_values)):
        indices = tf.range(start=self.buffer_index, limit=(self.buffer_index + num_values))
        indices = tf.math.mod(x=indices, y=capacity)
        indices = tf.expand_dims(input=indices, axis=1)

    # Write new values
    with tf.control_dependencies(control_inputs=(indices, )):
        assignments = list()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                for inner_name, inner_buffer in buffer.items():
                    assignment = inner_buffer.scatter_nd_update(
                        indices=indices, updates=values[name][inner_name]
                    )
                    assignments.append(assignment)
            else:
                assignment = buffer.scatter_nd_update(indices=indices, updates=values[name])
                assignments.append(assignment)

    # Increment buffer index
    with tf.control_dependencies(control_inputs=assignments):
        assignment = self.buffer_index.assign_add(delta=num_values, read_value=False)

    # Return overwritten values or no-op
    with tf.control_dependencies(control_inputs=(assignment, )):
        if self.return_overwritten:
            any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
            overwritten_values = util.fmap(
                function=util.identity_operation, xs=overwritten_values
            )
            return any_overwritten, overwritten_values
        else:
            return util.no_operation()
def tf_enqueue(self, states, internals, auxiliaries, actions, terminal, reward, baseline=None):
    """Write a batch of timesteps into the buffers and return the entries pushed out.

    Before the write, the rewards at the first `num_overwritten` positions from the current
    buffer index are replaced by discounted sums over `horizon` following rewards,
    optionally bootstrapped with a baseline estimate.

    Args:
        states, internals, auxiliaries, actions: Dicts of tensors, indexed by name.
        terminal: Terminal tensor; its first dimension defines the number of timesteps.
        reward: Reward tensor.
        baseline: Required (asserted) if `self.estimate_horizon == 'early'`; provides
            `past_horizon`, `actions_value` and `states_value`.

    Returns:
        A pair `(any_overwritten, overwritten_values)`, where `overwritten_values` mirrors the
        buffer structure; both depend on the buffer index increment.
    """
    # Constants and parameters
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
    horizon = self.horizon.value()
    discount = self.discount.value()

    assertions = list()
    # Check whether horizon at most capacity
    assertions.append(
        tf.debugging.assert_less_equal(
            x=horizon, y=capacity,
            message="Estimator capacity has to be at least the same as the estimation horizon."
        ))
    # Check whether at most one terminal
    assertions.append(
        tf.debugging.assert_less_equal(
            x=tf.math.count_nonzero(input=terminal, dtype=util.tf_dtype(dtype='long')),
            y=one, message="Timesteps contain more than one terminal."))
    # Check whether, if any, last value is terminal
    assertions.append(
        tf.debugging.assert_equal(
            x=tf.reduce_any(input_tensor=tf.math.greater(x=terminal, y=zero)),
            y=tf.math.greater(x=terminal[-1], y=zero),
            message="Terminal is not the last timestep."))

    # Get number of overwritten values
    with tf.control_dependencies(control_inputs=assertions):
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                        dtype=util.tf_dtype(dtype='long'))
        # Clamping both bounds at capacity means nothing counts as overwritten while the
        # write range stays below capacity
        overwritten_start = tf.maximum(x=self.buffer_index, y=capacity)
        overwritten_limit = tf.maximum(x=(self.buffer_index + num_values), y=capacity)
        num_overwritten = overwritten_limit - overwritten_start

    def update_overwritten_rewards():
        # Branch of the cond below, built for the case num_overwritten > 0

        # Get relevant buffer rewards
        buffer_limit = self.buffer_index + tf.minimum(
            x=(num_overwritten + horizon), y=capacity)
        buffer_indices = tf.range(start=self.buffer_index, limit=buffer_limit)
        buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
        rewards = tf.gather(params=self.buffers['reward'], indices=buffer_indices)

        # Get relevant values rewards
        # (rewards from the incoming batch, needed once the buffer alone is too short)
        values_limit = tf.maximum(x=(num_overwritten + horizon - capacity), y=zero)
        rewards = tf.concat(values=(rewards, reward[:values_limit]), axis=0)

        # Horizon baseline value
        if self.estimate_horizon == 'early':
            assert baseline is not None
            # Baseline estimate
            buffer_indices = buffer_indices[horizon + one:]
            _states = OrderedDict()
            for name, buffer in self.buffers['states'].items():
                state = tf.gather(params=buffer, indices=buffer_indices)
                _states[name] = tf.concat(
                    values=(state, states[name][:values_limit + one]), axis=0)
            _internals = OrderedDict()
            for name, buffer in self.buffers['internals'].items():
                internal = tf.gather(params=buffer, indices=buffer_indices)
                _internals[name] = tf.concat(
                    values=(internal, internals[name][:values_limit + one]), axis=0)
            _auxiliaries = OrderedDict()
            for name, buffer in self.buffers['auxiliaries'].items():
                auxiliary = tf.gather(params=buffer, indices=buffer_indices)
                _auxiliaries[name] = tf.concat(
                    values=(auxiliary, auxiliaries[name][:values_limit + one]), axis=0)

            # Dependency horizon
            # TODO: handle arbitrary non-optimization horizons!
            past_horizon = baseline.past_horizon(is_optimization=False)
            assertion = tf.debugging.assert_equal(
                x=past_horizon, y=zero,
                message="Temporary: baseline cannot depend on previous states.")
            with tf.control_dependencies(control_inputs=(assertion, )):
                some_state = next(iter(_states.values()))
                if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                    batch_size = tf.shape(
                        input=some_state, out_type=util.tf_dtype(dtype='long'))[0]
                else:
                    batch_size = tf.dtypes.cast(
                        x=tf.shape(input=some_state)[0], dtype=util.tf_dtype(dtype='long'))
                # Every instance is registered as its own dependency sequence of length one
                starts = tf.range(start=batch_size, dtype=util.tf_dtype(dtype='long'))
                lengths = tf.ones(shape=(batch_size, ), dtype=util.tf_dtype(dtype='long'))
                Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)

            # NOTE(review): the actions are sliced with `values_limit`, whereas states,
            # internals and auxiliaries use `values_limit + one` -- confirm this is intended.
            if self.estimate_actions:
                _actions = OrderedDict()
                for name, buffer in self.buffers['actions'].items():
                    action = tf.gather(params=buffer, indices=buffer_indices)
                    _actions[name] = tf.concat(
                        values=(action, actions[name][:values_limit]), axis=0)
                horizon_estimate = baseline.actions_value(
                    states=_states, internals=_internals, auxiliaries=_auxiliaries,
                    actions=_actions)
            else:
                horizon_estimate = baseline.states_value(
                    states=_states, internals=_internals, auxiliaries=_auxiliaries)

        else:
            # Zero estimate
            horizon_estimate = tf.zeros(shape=(num_overwritten, ),
                                        dtype=util.tf_dtype(dtype='float'))

        # Calculate discounted sum
        # (loop counts `horizon` down to zero, starting from the horizon estimate)
        def cond(discounted_sum, horizon):
            return tf.math.greater_equal(x=horizon, y=zero)

        def body(discounted_sum, horizon):
            # discounted_sum = tf.compat.v1.Print(
            #     discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10
            # )
            discounted_sum = discount * discounted_sum
            discounted_sum = discounted_sum + rewards[horizon:horizon + num_overwritten]
            return discounted_sum, horizon - one

        discounted_sum, _ = self.while_loop(cond=cond, body=body,
                                            loop_vars=(horizon_estimate, horizon),
                                            back_prop=False)

        # Shape sanity checks on the estimate and the collected rewards
        assertions = [
            tf.debugging.assert_equal(x=tf.shape(input=horizon_estimate),
                                      y=tf.shape(input=discounted_sum),
                                      message="Estimation check."),
            tf.debugging.assert_equal(x=tf.shape(
                input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                                      y=(horizon + num_overwritten),
                                      message="Estimation check.")
        ]

        # Overwrite buffer rewards
        with tf.control_dependencies(control_inputs=assertions):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_overwritten))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)
            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum)

        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.no_operation()

    # Only update rewards if something is actually overwritten
    any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
    updated_rewards = self.cond(pred=any_overwritten,
                                true_fn=update_overwritten_rewards,
                                false_fn=util.no_operation)

    # Overwritten buffer indices
    with tf.control_dependencies(control_inputs=(updated_rewards, )):
        indices = tf.range(start=overwritten_start, limit=overwritten_limit)
        indices = tf.math.mod(x=indices, y=capacity)

    # Get overwritten values
    with tf.control_dependencies(control_inputs=(indices, )):
        overwritten_values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                overwritten_values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    overwritten_values[name][inner_name] = tf.gather(
                        params=buffer, indices=indices)
            else:
                overwritten_values[name] = tf.gather(params=buffer, indices=indices)

    # Buffer indices to (over)write
    with tf.control_dependencies(control_inputs=util.flatten(
            xs=overwritten_values)):
        indices = tf.range(start=self.buffer_index,
                           limit=(self.buffer_index + num_values))
        indices = tf.math.mod(x=indices, y=capacity)
        indices = tf.expand_dims(input=indices, axis=1)

    # Write new values
    with tf.control_dependencies(control_inputs=(indices, )):
        values = dict(states=states, internals=internals, auxiliaries=auxiliaries,
                      actions=actions, terminal=terminal, reward=reward)
        assignments = list()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                for inner_name, buffer in buffer.items():
                    assignment = buffer.scatter_nd_update(
                        indices=indices, updates=values[name][inner_name])
                    assignments.append(assignment)
            else:
                assignment = buffer.scatter_nd_update(indices=indices, updates=values[name])
                assignments.append(assignment)

    # Increment buffer index
    with tf.control_dependencies(control_inputs=assignments):
        assignment = self.buffer_index.assign_add(delta=num_values, read_value=False)

    # Return overwritten values or no-op
    with tf.control_dependencies(control_inputs=(assignment, )):
        any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
        overwritten_values = util.fmap(function=util.identity_operation,
                                       xs=overwritten_values)
        return any_overwritten, overwritten_values
def tf_reset(self, baseline=None):
    """Reset the buffer index and return the buffered values with discounted rewards.

    The `min(buffer_index, capacity)` entries preceding the buffer index are gathered, the
    buffer index is set to zero, and the gathered rewards are replaced by discounted sums
    over `horizon` following rewards (zero-padded beyond the terminal), optionally
    bootstrapped with a baseline estimate.

    Args:
        baseline: Used only if `self.estimate_horizon == 'early'`; provides `past_horizon`,
            `actions_value` and `states_value`.

    Returns:
        The `OrderedDict` of gathered values, with `values['reward']` replaced by the
        discounted sums.
    """
    # Constants and parameters
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
    horizon = self.horizon.value()
    discount = self.discount.value()

    # Overwritten buffer indices
    num_overwritten = tf.minimum(x=self.buffer_index, y=capacity)
    indices = tf.range(start=(self.buffer_index - num_overwritten), limit=self.buffer_index)
    indices = tf.math.mod(x=indices, y=capacity)

    # Get overwritten values
    values = OrderedDict()
    for name, buffer in self.buffers.items():
        if util.is_nested(name=name):
            values[name] = OrderedDict()
            for inner_name, buffer in buffer.items():
                values[name][inner_name] = tf.gather(params=buffer, indices=indices)
        else:
            values[name] = tf.gather(params=buffer, indices=indices)

    states = values['states']
    internals = values['internals']
    auxiliaries = values['auxiliaries']
    actions = values['actions']
    terminal = values['terminal']
    reward = values['reward']
    # NOTE(review): redundant, `terminal` was already assigned two lines above
    terminal = values['terminal']

    # Reset buffer index
    # (only after all values have been read)
    with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
        assignment = self.buffer_index.assign(value=zero, read_value=False)

    with tf.control_dependencies(control_inputs=(assignment, )):
        assertions = list()
        # Check whether exactly one terminal (, unless empty?)
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.count_nonzero(input=terminal, dtype=util.tf_dtype(dtype='long')),
                y=one, message="Timesteps do not contain exactly one terminal."))
        # Check whether last value is terminal
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.greater(x=terminal[-1], y=zero),
                y=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool')),
                message="Terminal is not the last timestep."))

    # Get number of values
    with tf.control_dependencies(control_inputs=assertions):
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                        dtype=util.tf_dtype(dtype='long'))

    # Horizon baseline value
    if self.estimate_horizon == 'early' and baseline is not None:
        # Dependency horizon
        # TODO: handle arbitrary non-optimization horizons!
        past_horizon = baseline.past_horizon(is_optimization=False)
        assertion = tf.debugging.assert_equal(
            x=past_horizon, y=zero,
            message="Temporary: baseline cannot depend on previous states."
        )

        # Baseline estimate
        # (at least one timestep is dropped from the front, since the maximum is >= one)
        horizon_start = num_values - tf.maximum(x=(num_values - horizon), y=one)
        _states = OrderedDict()
        for name, state in states.items():
            _states[name] = state[horizon_start:]
        _internals = OrderedDict()
        for name, internal in internals.items():
            _internals[name] = internal[horizon_start:]
        _auxiliaries = OrderedDict()
        for name, auxiliary in auxiliaries.items():
            _auxiliaries[name] = auxiliary[horizon_start:]

        with tf.control_dependencies(control_inputs=(assertion, )):
            # some_state = next(iter(states.values()))
            # if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            #     batch_size = tf.shape(input=some_state, out_type=util.tf_dtype(dtype='long'))[0]
            # else:
            #     batch_size = tf.dtypes.cast(
            #         x=tf.shape(input=some_state)[0], dtype=util.tf_dtype(dtype='long')
            #     )
            batch_size = num_values - horizon_start
            # Every instance is registered as its own dependency sequence of length one
            starts = tf.range(start=batch_size, dtype=util.tf_dtype(dtype='long'))
            lengths = tf.ones(shape=(batch_size, ), dtype=util.tf_dtype(dtype='long'))
            Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)

        if self.estimate_actions:
            _actions = OrderedDict()
            for name, action in actions.items():
                _actions[name] = action[horizon_start:]
            horizon_estimate = baseline.actions_value(
                states=_states, internals=_internals, auxiliaries=_auxiliaries,
                actions=_actions)
        else:
            horizon_estimate = baseline.states_value(
                states=_states, internals=_internals, auxiliaries=_auxiliaries)

        # Expand rewards beyond terminal
        terminal_zeros = tf.zeros(shape=(horizon, ), dtype=util.tf_dtype(dtype='float'))
        if self.estimate_terminal:
            # Last reward is replaced by the last baseline estimate
            rewards = tf.concat(values=(reward[:-1], horizon_estimate[-1:], terminal_zeros),
                                axis=0)

        else:
            with tf.control_dependencies(control_inputs=(assertion, )):
                # Last reward is replaced by the estimate only if the last terminal value is
                # greater than one -- presumably marking an aborted episode; TODO confirm
                last_reward = tf.where(condition=tf.math.greater(
                    x=terminal[-1], y=one), x=horizon_estimate[-1], y=reward[-1])
                rewards = tf.concat(values=(reward[:-1], (last_reward, ), terminal_zeros),
                                    axis=0)

        # Remove last if necessary
        horizon_end = tf.where(condition=tf.math.less_equal(x=num_values, y=horizon),
                               x=zero, y=(num_values - horizon))
        horizon_estimate = horizon_estimate[:horizon_end]

        # Expand missing estimates with zeros
        terminal_size = tf.minimum(x=horizon, y=num_values)
        terminal_estimate = tf.zeros(shape=(terminal_size, ),
                                     dtype=util.tf_dtype(dtype='float'))
        horizon_estimate = tf.concat(values=(horizon_estimate, terminal_estimate), axis=0)

    else:
        # Expand rewards beyond terminal
        terminal_zeros = tf.zeros(shape=(horizon, ), dtype=util.tf_dtype(dtype='float'))
        rewards = tf.concat(values=(reward, terminal_zeros), axis=0)

        # Zero estimate
        horizon_estimate = tf.zeros(shape=(num_values, ), dtype=util.tf_dtype(dtype='float'))

    # Calculate discounted sum
    # (loop counts `horizon` down to zero, starting from the horizon estimate)
    def cond(discounted_sum, horizon):
        return tf.math.greater_equal(x=horizon, y=zero)

    def body(discounted_sum, horizon):
        # discounted_sum = tf.compat.v1.Print(
        #     discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10
        # )
        discounted_sum = discount * discounted_sum
        discounted_sum = discounted_sum + rewards[horizon:horizon + num_values]
        return discounted_sum, horizon - one

    values['reward'], _ = self.while_loop(cond=cond, body=body,
                                          loop_vars=(horizon_estimate, horizon),
                                          back_prop=False)

    return values
def tf_enqueue(self, states, internals, auxiliaries, actions, terminal, reward):
    """Write a batch of timesteps into the memory and update the episode bookkeeping.

    The steps are chained via control dependencies, in this order: remove the previous
    last-observation marker, run consistency assertions, drop the terminal indices of
    episodes about to be overwritten, write the new values (marking a non-terminal last
    timestep with the value 3), advance the buffer index, and register new terminals.

    Args:
        states, internals, auxiliaries, actions: Dicts of tensors, indexed by name.
        terminal: Terminal tensor; its first dimension defines the number of timesteps.
        reward: Reward tensor.

    Returns:
        A no-op which depends on the final episode count update.
    """
    zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
    one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
    # Terminal value used as marker for the most recently written (non-terminal) timestep
    three = tf.constant(value=3, dtype=util.tf_dtype(dtype='long'))
    capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
        num_timesteps = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
    else:
        num_timesteps = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                       dtype=util.tf_dtype(dtype='long'))

    # # Max capacity
    # latest_terminal_index = self.terminal_indices[self.episode_count]
    # max_capacity = self.buffer_index - latest_terminal_index - one
    # max_capacity = capacity - (tf.math.mod(x=max_capacity, y=capacity) + one)

    # Remove last observation terminal marker
    last_index = tf.math.mod(x=(self.buffer_index - one), y=capacity)
    last_terminal = tf.gather(params=self.buffers['terminal'], indices=(last_index, ))[0]
    corrected_terminal = tf.where(condition=tf.math.equal(x=last_terminal, y=three),
                                  x=zero, y=last_terminal)
    assignment = tf.compat.v1.assign(
        ref=self.buffers['terminal'][last_index], value=corrected_terminal)

    # Assertions
    with tf.control_dependencies(control_inputs=(assignment, )):
        assertions = [
            # check: number of timesteps fit into effectively available buffer
            tf.debugging.assert_less_equal(
                x=num_timesteps, y=capacity,
                message="Memory does not have enough capacity."),
            # at most one terminal
            tf.debugging.assert_less_equal(
                x=tf.math.count_nonzero(input=terminal, dtype=util.tf_dtype(dtype='long')),
                y=one, message="Timesteps contain more than one terminal."),
            # if terminal, last timestep in batch
            tf.debugging.assert_equal(
                x=tf.math.reduce_any(input_tensor=tf.math.greater(x=terminal, y=zero)),
                y=tf.math.greater(x=terminal[-1], y=zero),
                message="Terminal is not the last timestep."),
            # general check: all terminal indices true
            tf.debugging.assert_equal(
                x=tf.reduce_all(input_tensor=tf.gather(
                    params=tf.math.greater(x=self.buffers['terminal'], y=zero),
                    indices=self.terminal_indices[:self.episode_count + one])),
                y=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool')),
                message="Memory consistency check."),
            # general check: only terminal indices true
            tf.debugging.assert_equal(
                x=tf.math.count_nonzero(
                    input=self.buffers['terminal'], dtype=util.tf_dtype(dtype='long')),
                y=(self.episode_count + one),
                message="Memory consistency check.")
        ]

    # Buffer indices to overwrite
    with tf.control_dependencies(control_inputs=assertions):
        overwritten_indices = tf.range(start=self.buffer_index,
                                       limit=(self.buffer_index + num_timesteps))
        overwritten_indices = tf.math.mod(x=overwritten_indices, y=capacity)

        # Count number of overwritten episodes
        num_episodes = tf.math.count_nonzero(
            input=tf.gather(params=self.buffers['terminal'], indices=overwritten_indices),
            axis=0, dtype=util.tf_dtype(dtype='long'))

        # Shift remaining terminal indices accordingly
        limit_index = self.episode_count + one
        assertion = tf.debugging.assert_greater_equal(
            x=limit_index, y=num_episodes, message="Memory episode overwriting check.")

    with tf.control_dependencies(control_inputs=(assertion, )):
        # Move the surviving terminal indices to the front
        assignment = tf.compat.v1.assign(
            ref=self.terminal_indices[:limit_index - num_episodes],
            value=self.terminal_indices[num_episodes:limit_index])

    # Decrement episode count accordingly
    with tf.control_dependencies(control_inputs=(assignment, )):
        assignment = self.episode_count.assign_sub(delta=num_episodes, read_value=False)

    # Write new observations
    with tf.control_dependencies(control_inputs=(assignment, )):
        indices = tf.range(start=self.buffer_index,
                           limit=(self.buffer_index + num_timesteps))
        indices = tf.math.mod(x=indices, y=capacity)
        indices = tf.expand_dims(input=indices, axis=1)
        values = dict(states=states, internals=internals, auxiliaries=auxiliaries,
                      actions=actions, terminal=terminal, reward=reward)
        assignments = list()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                for inner_name, buffer in buffer.items():
                    assignment = buffer.scatter_nd_update(
                        indices=indices, updates=values[name][inner_name])
                    assignments.append(assignment)
            else:
                if name == 'terminal':
                    # Add last observation terminal marker
                    # (only if the last timestep is not a terminal itself)
                    corrected_terminal = tf.where(
                        condition=tf.math.equal(x=terminal[-1], y=zero),
                        x=three, y=terminal[-1])
                    assignment = buffer.scatter_nd_update(
                        indices=indices,
                        updates=tf.concat(values=(terminal[:-1], (corrected_terminal, )),
                                          axis=0))
                else:
                    assignment = buffer.scatter_nd_update(
                        indices=indices, updates=values[name])
                assignments.append(assignment)

    # Increment buffer index
    with tf.control_dependencies(control_inputs=assignments):
        assignment = self.buffer_index.assign_add(delta=num_timesteps, read_value=False)

    # Count number of new episodes
    with tf.control_dependencies(control_inputs=(assignment, )):
        num_new_episodes = tf.math.count_nonzero(
            input=terminal, dtype=util.tf_dtype(dtype='long'))

        # Write new terminal indices
        # (buffer positions of the batch timesteps with terminal greater than zero)
        limit_index = self.episode_count + one
        assignment = tf.compat.v1.assign(
            ref=self.terminal_indices[limit_index:limit_index + num_new_episodes],
            value=tf.boolean_mask(tensor=overwritten_indices,
                                  mask=tf.math.greater(x=terminal, y=zero)))

    # Increment episode count accordingly
    with tf.control_dependencies(control_inputs=(assignment, )):
        assignment = self.episode_count.assign_add(delta=num_new_episodes, read_value=False)

    with tf.control_dependencies(control_inputs=(assignment, )):
        return util.no_operation()