Example #1
    def tf_optimization(
        self, states, internals, actions, terminal, reward, next_states=None, next_internals=None
    ):
        """
        Creates the TensorFlow operations for performing an optimization update step based
        on the given input states and actions batch.

        Args:
            states: Dict of state tensors.
            internals: List of prior internal state tensors.
            actions: Dict of action tensors.
            terminal: Terminal boolean tensor.
            reward: Reward tensor.
            next_states: Dict of successor state tensors.
            next_internals: List of posterior internal state tensors.

        Returns:
            The optimization operation.
        """
        parameters_before = OrderedDict()
        embedding = self.network.apply(x=states, internals=internals)
        for name, distribution in self.distributions.items():
            parameters_before[name] = distribution.parametrize(x=embedding)

        with tf.control_dependencies(control_inputs=util.flatten(xs=parameters_before)):
            optimized = super().tf_optimization(
                states=states, internals=internals, actions=actions, terminal=terminal,
                reward=reward, next_states=next_states, next_internals=next_internals
            )

        with tf.control_dependencies(control_inputs=(optimized,)):
            summaries = list()
            embedding = self.network.apply(x=states, internals=internals)
            for name, distribution in self.distributions.items():
                parameters = distribution.parametrize(x=embedding)
                kl_divergence = distribution.kl_divergence(
                    parameters1=parameters_before[name], parameters2=parameters
                )
                collapsed_size = util.product(xs=util.shape(kl_divergence)[1:])
                kl_divergence = tf.reshape(tensor=kl_divergence, shape=(-1, collapsed_size))
                kl_divergence = tf.reduce_mean(input_tensor=kl_divergence, axis=1)
                kl_divergence = self.add_summary(
                    label='kl-divergence', name=(name + '-kldiv'), tensor=kl_divergence
                )
                summaries.append(kl_divergence)

                entropy = distribution.entropy(parameters=parameters)
                entropy = tf.reshape(tensor=entropy, shape=(-1, collapsed_size))
                entropy = tf.reduce_mean(input_tensor=entropy, axis=1)
                entropy = self.add_summary(
                    label='entropy', name=(name + '-entropy'), tensor=entropy
                )
                summaries.append(entropy)

        with tf.control_dependencies(control_inputs=summaries):
            return util.no_operation()
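
The snapshot-then-compare structure above parametrizes each action distribution before and after the update, takes the KL divergence between the two parameter sets, collapses all non-batch dimensions and averages them into one scalar per instance. A minimal NumPy sketch of that reduction, assuming diagonal Gaussian action distributions with unit standard deviation (illustrative names, not Tensorforce API):

    import numpy as np

    def diagonal_gaussian_kl(mean1, stddev1, mean2, stddev2):
        # Closed-form KL(N(mean1, stddev1) || N(mean2, stddev2)), element-wise.
        return (np.log(stddev2 / stddev1)
                + (stddev1 ** 2 + (mean1 - mean2) ** 2) / (2.0 * stddev2 ** 2) - 0.5)

    batch_size, action_shape = 4, (2, 3)
    rng = np.random.default_rng(seed=0)
    mean_before = rng.normal(size=(batch_size,) + action_shape)
    mean_after = mean_before + 0.05 * rng.normal(size=(batch_size,) + action_shape)
    stddev_before = stddev_after = np.ones((batch_size,) + action_shape)

    kl_divergence = diagonal_gaussian_kl(mean_before, stddev_before, mean_after, stddev_after)
    collapsed_size = int(np.prod(kl_divergence.shape[1:]))      # util.product(util.shape(...)[1:])
    kl_divergence = kl_divergence.reshape(-1, collapsed_size)   # (batch, collapsed action dims)
    kl_divergence = kl_divergence.mean(axis=1)                  # one summary value per instance
    print(kl_divergence.shape)  # (4,)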
Example #2
    def tf_apply_step(self, variables, deltas):
        if len(variables) != len(deltas):
            raise TensorforceError("Invalid variables and deltas lists.")

        assignments = list()
        for variable, delta in zip(variables, deltas):
            assignments.append(tf.assign_add(ref=variable, value=delta))

        with tf.control_dependencies(control_inputs=assignments):
            return util.no_operation()
Example #3
    def tf_minimize(self, variables, **kwargs):
        if any(variable.dtype != util.tf_dtype(dtype='float')
               for variable in variables):
            raise TensorforceError.unexpected()

        deltas = self.step(variables=variables, **kwargs)

        update_norm = tf.linalg.global_norm(t_list=deltas)
        deltas = self.add_summary(label='update-norm',
                                  name='update-norm',
                                  tensor=update_norm,
                                  pass_tensors=deltas)

        for n in range(len(variables)):
            name = variables[n].name
            if name[-2:] != ':0':
                raise TensorforceError.unexpected()
            deltas[n] = self.add_summary(label='updates',
                                         name=('update-' + name[:-2]),
                                         tensor=deltas[n],
                                         mean_variance=True)
            deltas[n] = self.add_summary(label='updates-histogram',
                                         name=('update-' + name[:-2]),
                                         tensor=deltas[n])

        # TODO: experimental
        # with tf.control_dependencies(control_inputs=deltas):
        #     zero = tf.constant(value=0.0, dtype=util.tf_dtype(dtype='float'))
        #     false = tf.constant(value=False, dtype=util.tf_dtype(dtype='bool'))
        #     deltas = [self.cond(
        #         pred=tf.math.reduce_all(input_tensor=tf.math.equal(x=delta, y=zero)),
        #         true_fn=(lambda: tf.Print(delta, (variable.name,))),
        #         false_fn=(lambda: delta)) for delta, variable in zip(deltas, variables)
        #     ]
        #     assertions = [
        #         tf.debugging.assert_equal(
        #             x=tf.math.reduce_all(input_tensor=tf.math.equal(x=delta, y=zero)), y=false,
        #             message="Zero delta check."
        #         ) for delta, variable in zip(deltas, variables)
        #         if util.product(xs=util.shape(x=delta)) > 4 and 'distribution' not in variable.name
        #     ]

        # with tf.control_dependencies(control_inputs=assertions):
        with tf.control_dependencies(control_inputs=deltas):
            return util.no_operation()
Example #4
    def tf_store(self, states, internals, actions, terminal, reward):
        # We first store new experiences into a buffer that is separate from main memory.
        # We insert these into the main memory once we have computed priorities on a given batch.
        num_instances = tf.shape(input=terminal)[0]

        # Simple way to prevent buffer overflows.
        start_index = self.cond(
            # Why + 1? Because of the next state; otherwise it would have to be handled separately.
            pred=(self.buffer_index + num_instances + 1 >= self.buffer_size),
            true_fn=(lambda: 0),
            false_fn=(lambda: self.buffer_index))
        end_index = start_index + num_instances

        # Assign new observations.
        assignments = list()
        for name in sorted(states):
            assignments.append(
                tf.assign(ref=self.states_buffer[name][start_index:end_index],
                          value=states[name]))
        for name in sorted(internals):
            assignments.append(
                tf.assign(
                    ref=self.internals_buffer[name][start_index:end_index],
                    value=internals[name]))
        for name in sorted(actions):
            assignments.append(
                tf.assign(ref=self.actions_buffer[name][start_index:end_index],
                          value=actions[name]))

        assignments.append(
            tf.assign(ref=self.terminal_buffer[start_index:end_index],
                      value=terminal))
        assignments.append(
            tf.assign(ref=self.reward_buffer[start_index:end_index],
                      value=reward))

        # Increment memory index.
        with tf.control_dependencies(control_inputs=assignments):
            assignment = tf.assign(ref=self.buffer_index,
                                   value=(start_index + num_instances))

        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.no_operation()
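
The comments above note that new experiences are first staged in a buffer separate from the main memory, and that the start index simply snaps back to 0 whenever the incoming batch (plus one slot for the successor state) would overflow it. A plain-Python sketch of that overflow guard, with illustrative names:

    def next_buffer_slice(buffer_index, num_instances, buffer_size):
        # Reserve num_instances slots (+ 1 for the successor state); restart at 0 instead of overflowing.
        if buffer_index + num_instances + 1 >= buffer_size:
            start_index = 0
        else:
            start_index = buffer_index
        return start_index, start_index + num_instances

    # With a buffer of 10 slots and 4 already filled, a batch of 6 wraps back to the front.
    print(next_buffer_slice(buffer_index=4, num_instances=6, buffer_size=10))  # (0, 6)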
Example #5
    def tf_apply_step(self, variables, deltas):
        """
        Applies the given (and already calculated) step deltas to the variable values.

        Args:
            variables: List of variables.
            deltas: List of deltas of same length.

        Returns:
            The step-applied operation. A tf.group of tf.assign_add ops.
        """
        if len(variables) != len(deltas):
            raise TensorforceError("Invalid variables and deltas lists.")

        assignments = list()
        for variable, delta in zip(variables, deltas):
            assignments.append(tf.assign_add(ref=variable, value=delta))

        with tf.control_dependencies(control_inputs=assignments):
            return util.no_operation()
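
The docstring describes the result as, in effect, a group of tf.assign_add operations hidden behind a no-op. A self-contained TensorFlow 1.x sketch of that pattern, assuming util.no_operation() is essentially a wrapper around tf.no_op() (the wrapper itself is not shown here):

    import tensorflow as tf  # TensorFlow 1.x graph-mode style

    variables = [tf.Variable(initial_value=[1.0, 2.0]), tf.Variable(initial_value=[3.0])]
    deltas = [tf.constant(value=[0.5, 0.5]), tf.constant(value=[-1.0])]

    assignments = [tf.assign_add(ref=variable, value=delta)
                   for variable, delta in zip(variables, deltas)]

    # The returned no-op only runs once every assign_add has run.
    with tf.control_dependencies(control_inputs=assignments):
        apply_step = tf.no_op()

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(apply_step)
        print(session.run(variables))  # [array([1.5, 2.5], ...), array([2.], ...)]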
Example #6
    def tf_minimize(self, variables, **kwargs):
        """
        Performs an optimization step.

        Args:
            variables: List of variables to optimize.
            **kwargs: Additional optimizer-specific arguments. The following arguments are used
                by some optimizers:
            - arguments: Dict of arguments for callables, like fn_loss.
            - fn_loss: A callable returning the loss of the current model.
            - fn_reference: A callable returning the reference values, in case of a comparative loss.
            - fn_kl_divergence: A callable returning the KL-divergence relative to the
                current model.
            - sampled_loss: A sampled loss (integer).
            - return_estimated_improvement: Returns the estimated improvement resulting from
                the natural gradient calculation if true.
            - source_variables: List of source variables to synchronize with.
            - global_variables: List of global variables to apply the proposed optimization
                step to.


        Returns:
            The optimization operation.
        """
        deltas = self.step(variables=variables, **kwargs)

        for n in range(len(variables)):
            name = variables[n].name
            if name[-2:] != ':0':
                raise TensorforceError.unexpected()
            deltas[n] = self.add_summary(label=('updates', 'updates-full'),
                                         name=(name[:-2] + '-update'),
                                         tensor=deltas[n],
                                         mean_variance=True)
            deltas[n] = self.add_summary(label='updates-full',
                                         name=(name[:-2] + '-update'),
                                         tensor=deltas[n])

        with tf.control_dependencies(control_inputs=deltas):
            return util.no_operation()
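
Of the keyword arguments listed above, fn_loss is the central one: a callable returning the scalar loss of the current model, from which self.step() derives the deltas that are then wrapped with summaries. A TensorFlow 1.x sketch of that flow, using a plain gradient-descent step as a stand-in for self.step() (not the Tensorforce implementation):

    import tensorflow as tf  # TensorFlow 1.x graph-mode style

    def gradient_step(variables, fn_loss, learning_rate=0.1):
        # Stand-in for self.step(): deltas are plain negative gradients of fn_loss().
        gradients = tf.gradients(ys=fn_loss(), xs=variables)
        return [-learning_rate * gradient for gradient in gradients]

    weight = tf.Variable(initial_value=3.0)
    fn_loss = (lambda: tf.square(x=weight - 1.0))  # callable returning the loss of the current model

    deltas = gradient_step(variables=[weight], fn_loss=fn_loss)
    update_norm = tf.linalg.global_norm(t_list=deltas)  # the kind of per-update statistic the summaries record

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        print(session.run([deltas[0], update_norm]))  # approximately [-0.4, 0.4]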
Example #7
    def tf_optimization(self,
                        states,
                        internals,
                        actions,
                        terminal,
                        reward,
                        next_states=None,
                        next_internals=None):
        assert next_states is None and next_internals is None  # temporary

        estimated_reward = self.reward_estimation(states=states,
                                                  internals=internals,
                                                  terminal=terminal,
                                                  reward=reward)
        if self.baseline_optimizer is not None:
            estimated_reward = tf.stop_gradient(input=estimated_reward)

        optimization = super().tf_optimization(states=states,
                                               internals=internals,
                                               actions=actions,
                                               terminal=terminal,
                                               reward=estimated_reward,
                                               next_states=next_states,
                                               next_internals=next_internals)

        if self.baseline_optimizer is not None:
            cumulative_reward = self.discounted_cumulative_reward(
                terminal=terminal, reward=reward)

            arguments = self.baseline_optimizer_arguments(
                states=states, internals=internals, reward=cumulative_reward)
            baseline_optimization = self.baseline_optimizer.minimize(
                **arguments)

            with tf.control_dependencies(
                    control_inputs=(optimization, baseline_optimization)):
                optimization = util.no_operation()

        return optimization
Example #8
    def tf_reset(self):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        if not self.return_overwritten:
            # Reset buffer index
            assignment = self.buffer_index.assign(value=zero, read_value=False)

            # Return no-op
            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        # Overwritten buffer indices
        num_values = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_values),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        # Return overwritten values
        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.fmap(function=util.identity_operation, xs=values)
Example #9
    def tf_optimization(self,
                        states,
                        internals,
                        actions,
                        terminal,
                        reward,
                        next_states=None,
                        next_internals=None):
        optimization = super().tf_optimization(states=states,
                                               internals=internals,
                                               actions=actions,
                                               terminal=terminal,
                                               reward=reward,
                                               next_states=next_states,
                                               next_internals=next_internals)

        arguments = self.target_optimizer_arguments()
        target_optimization = self.target_optimizer.minimize(**arguments)

        with tf.control_dependencies(control_inputs=(optimization,
                                                     target_optimization)):
            optimization = util.no_operation()

        return optimization
Example #10
    def tf_enqueue(self, states, internals, auxiliaries, actions, terminal, reward):
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))
        num_timesteps = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]

        # Max capacity
        latest_terminal_index = self.terminal_indices[self.episode_count]
        max_capacity = self.buffer_index - latest_terminal_index - one
        max_capacity = capacity - (tf.math.mod(x=max_capacity, y=capacity) + one)

        # Assertions
        assertions = [
            # check: number of timesteps fit into effectively available buffer
            tf.compat.v1.debugging.assert_less_equal(x=num_timesteps, y=max_capacity),
            # at most one terminal
            tf.compat.v1.debugging.assert_less_equal(
                x=tf.math.count_nonzero(input_tensor=terminal, dtype=util.tf_dtype(dtype='long')),
                y=one
            ),
            # if terminal, last timestep in batch
            tf.compat.v1.debugging.assert_equal(
                x=tf.math.reduce_any(input_tensor=tf.math.greater(x=terminal, y=zero)),
                y=tf.math.greater(x=terminal[-1], y=zero)
            ),
            # general check: all terminal indices true
            tf.compat.v1.debugging.assert_equal(
                x=tf.reduce_all(
                    input_tensor=tf.gather(
                        params=tf.math.greater(x=self.buffers['terminal'], y=zero),
                        indices=self.terminal_indices[:self.episode_count + one]
                    )
                ),
                y=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool'))
            ),
            # general check: only terminal indices true
            tf.compat.v1.debugging.assert_equal(
                x=tf.math.count_nonzero(
                    input_tensor=self.buffers['terminal'], dtype=util.tf_dtype(dtype='long')
                ),
                y=(self.episode_count + one)
            ),
            # # general check: no terminal after last
            # tf.compat.v1.debugging.assert_equal(
            #     x=tf.reduce_any(
            #         input_tensor=tf.gather(
            #             params=self.buffers['terminal'], indices=tf.math.mod(
            #                 x=(self.terminal_indices[self.episode_count] + one + tf.range(
            #                     start=tf.math.mod(
            #                         x=(self.buffer_index - self.terminal_indices[self.episode_count] - one),
            #                         y=capacity
            #                     )
            #                 )),
            #                 y=capacity
            #             )
            #         )
            #     ),
            #     y=tf.constant(value=False, dtype=util.tf_dtype(dtype='bool'))
            # )
        ]

        # Buffer indices to overwrite
        with tf.control_dependencies(control_inputs=assertions):
            indices = tf.range(start=self.buffer_index, limit=(self.buffer_index + num_timesteps))
            indices = tf.math.mod(x=indices, y=capacity)

            # Count number of overwritten episodes
            num_episodes = tf.math.count_nonzero(
                input_tensor=tf.gather(params=self.buffers['terminal'], indices=indices), axis=0,
                dtype=util.tf_dtype(dtype='long')
            )

        # Shift remaining terminal indices accordingly
        limit_index = self.episode_count + one
        assertion = tf.compat.v1.debugging.assert_greater_equal(x=limit_index, y=num_episodes)

        with tf.control_dependencies(control_inputs=(assertion,)):
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[:limit_index - num_episodes],
                value=self.terminal_indices[num_episodes: limit_index]
            )

        # Decrement episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment,)):
            assignment = self.episode_count.assign_sub(delta=num_episodes, read_value=False)

        # Write new observations
        with tf.control_dependencies(control_inputs=(assignment,)):
            assignment = super().tf_enqueue(
                states=states, internals=internals, auxiliaries=auxiliaries, actions=actions,
                terminal=terminal, reward=reward
            )

        # # Write new observations
        # with tf.control_dependencies(control_inputs=(assignment,)):
        #     assignments = list()
        #     for name, state in states.items():
        #         assignments.append(
        #             tf.scatter_update(ref=self.memories[name], indices=indices, updates=state)
        #         )
        #     for name, internal in internals.items():
        #         assignments.append(
        #             tf.scatter_update(ref=self.memories[name], indices=indices, updates=internal)
        #         )
        #     for name, action in actions.items():
        #         assignments.append(
        #             tf.scatter_update(ref=self.memories[name], indices=indices, updates=action)
        #         )
        #     assignments.append(
        #         tf.scatter_update(ref=self.memories['terminal'], indices=indices, updates=terminal)
        #     )
        #     assignments.append(
        #         tf.scatter_update(ref=self.memories['reward'], indices=indices, updates=reward)
        #     )

        # # Increment memory index
        # with tf.control_dependencies(control_inputs=assignments):
        #     new_memory_index = tf.math.mod(x=(self.memory_index + num_timesteps), y=capacity)
        #     assignment = self.memory_index.assign(value=new_memory_index)

        # Count number of new episodes
        with tf.control_dependencies(control_inputs=(assignment,)):
            num_new_episodes = tf.math.count_nonzero(
                input_tensor=terminal, dtype=util.tf_dtype(dtype='long')
            )

            # Write new terminal indices
            limit_index = self.episode_count + one
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[limit_index: limit_index + num_new_episodes],
                value=tf.boolean_mask(tensor=indices, mask=tf.math.greater(x=terminal, y=zero))
            )

        # Increment episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment,)):
            assignment = self.episode_count.assign_add(delta=num_new_episodes, read_value=False)

        with tf.control_dependencies(control_inputs=(assignment,)):
            return util.no_operation()
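
Before writing the new timesteps, the enqueue above counts how many stored episode terminals fall into the slots about to be overwritten, shifts the surviving entries of terminal_indices to the front, and decrements episode_count accordingly. A small NumPy sketch of that bookkeeping on a toy circular buffer (numbers are illustrative only):

    import numpy as np

    capacity, buffer_index, episode_count = 8, 6, 1
    terminal_buffer = np.array([0, 0, 1, 0, 0, 1, 0, 0])  # stored episodes end at indices 2 and 5
    terminal_indices = np.array([2, 5, 0, 0])             # first episode_count + 1 entries are valid
    num_timesteps = 5

    # Buffer slots about to be overwritten (wrapping around the circular buffer).
    indices = np.arange(buffer_index, buffer_index + num_timesteps) % capacity  # [6 7 0 1 2]
    num_overwritten = int(np.count_nonzero(terminal_buffer[indices]))           # 1: the episode ending at index 2

    # Shift the surviving terminal indices left and decrement the episode count.
    limit = episode_count + 1
    terminal_indices[:limit - num_overwritten] = terminal_indices[num_overwritten:limit]
    episode_count -= num_overwritten
    print(terminal_indices[:episode_count + 1])  # [5]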
Example #11
    def tf_enqueue(self, **values):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        # Get number of values
        for value in values.values():
            if not isinstance(value, dict):
                break
            elif len(value) > 0:
                value = next(iter(value.values()))
                break
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=value,
                                  out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=value)[0],
                                        dtype=util.tf_dtype(dtype='long'))

        # Check whether instances fit into buffer
        assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

        if self.return_overwritten:
            # Overwritten buffer indices
            with tf.control_dependencies(control_inputs=(assertion, )):
                start = tf.maximum(x=self.buffer_index, y=capacity)
                limit = tf.maximum(x=(self.buffer_index + num_values),
                                   y=capacity)
                num_overwritten = limit - start
                indices = tf.range(start=start, limit=limit)
                indices = tf.math.mod(x=indices, y=capacity)

            # Get overwritten values
            with tf.control_dependencies(control_inputs=(indices, )):
                overwritten_values = OrderedDict()
                for name, buffer in self.buffers.items():
                    if util.is_nested(name=name):
                        overwritten_values[name] = OrderedDict()
                        for inner_name, buffer in buffer.items():
                            overwritten_values[name][inner_name] = tf.gather(
                                params=buffer, indices=indices)
                    else:
                        overwritten_values[name] = tf.gather(params=buffer,
                                                             indices=indices)

        else:
            overwritten_values = (assertion, )

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            if self.return_overwritten:
                any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
                overwritten_values = util.fmap(
                    function=util.identity_operation, xs=overwritten_values)
                return any_overwritten, overwritten_values
            else:
                return util.no_operation()
Example #12
    def tf_core_observe(self, states, internals, auxiliaries, actions, terminal, reward):
        return util.no_operation()
Example #13
        def update_overwritten_rewards():
            # Get relevant buffer rewards
            buffer_limit = self.buffer_index + tf.minimum(
                x=(num_overwritten + horizon), y=capacity)
            buffer_indices = tf.range(start=self.buffer_index,
                                      limit=buffer_limit)
            buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
            rewards = tf.gather(params=self.buffers['reward'],
                                indices=buffer_indices)

            # Get relevant values rewards
            values_limit = tf.maximum(x=(num_overwritten + horizon - capacity),
                                      y=zero)
            rewards = tf.concat(values=(rewards,
                                        values['reward'][:values_limit]),
                                axis=0)

            # Horizon baseline value
            if self.estimate_horizon == 'early':
                assert baseline is not None
                # Baseline estimate
                buffer_indices = buffer_indices[horizon + one:]
                states = OrderedDict()
                for name, buffer in self.buffers['states'].items():
                    state = tf.gather(params=buffer, indices=buffer_indices)
                    states[name] = tf.concat(
                        values=(state, values['states'][name][:values_limit]),
                        axis=0)
                internals = OrderedDict()
                for name, buffer in self.buffers['internals'].items():
                    internal = tf.gather(params=buffer, indices=buffer_indices)
                    internals[name] = tf.concat(
                        values=(internal,
                                values['internals'][name][:values_limit]),
                        axis=0)
                auxiliaries = OrderedDict()
                for name, buffer in self.buffers['auxiliaries'].items():
                    auxiliary = tf.gather(params=buffer,
                                          indices=buffer_indices)
                    auxiliaries[name] = tf.concat(
                        values=(auxiliary,
                                values['auxiliaries'][name][:values_limit]),
                        axis=0)

                # Dependency horizon
                # TODO: handle arbitrary non-optimization horizons!
                dependency_horizon = baseline.dependency_horizon(
                    is_optimization=False)
                assertion = tf.debugging.assert_equal(
                    x=dependency_horizon,
                    y=zero,
                    message=
                    "Temporary: baseline cannot depend on previous states.")
                with tf.control_dependencies(control_inputs=(assertion, )):
                    some_state = next(iter(states.values()))
                    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                        batch_size = tf.shape(
                            input=some_state,
                            out_type=util.tf_dtype(dtype='long'))[0]
                    else:
                        batch_size = tf.dtypes.cast(
                            x=tf.shape(input=some_state)[0],
                            dtype=util.tf_dtype(dtype='long'))
                    starts = tf.range(start=batch_size,
                                      dtype=util.tf_dtype(dtype='long'))
                    lengths = tf.ones(shape=(batch_size, ),
                                      dtype=util.tf_dtype(dtype='long'))
                    Module.update_tensors(dependency_starts=starts,
                                          dependency_lengths=lengths)

                if self.estimate_actions:
                    actions = OrderedDict()
                    for name, buffer in self.buffers['actions'].items():
                        action = tf.gather(params=buffer,
                                           indices=buffer_indices)
                        actions[name] = tf.concat(
                            values=(action,
                                    values['actions'][name][:values_limit]),
                            axis=0)
                    horizon_estimate = baseline.actions_value(
                        states=states,
                        internals=internals,
                        auxiliaries=auxiliaries,
                        actions=actions)
                else:
                    horizon_estimate = baseline.states_value(
                        states=states,
                        internals=internals,
                        auxiliaries=auxiliaries)

            else:
                # Zero estimate
                horizon_estimate = tf.zeros(shape=(num_overwritten, ),
                                            dtype=util.tf_dtype(dtype='float'))

            # Calculate discounted sum
            def cond(discounted_sum, horizon):
                return tf.math.greater_equal(x=horizon, y=zero)

            def body(discounted_sum, horizon):
                # discounted_sum = tf.Print(discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10)
                discounted_sum = discount * discounted_sum
                discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                          num_overwritten]
                return discounted_sum, horizon - one

            discounted_sum, _ = self.while_loop(cond=cond,
                                                body=body,
                                                loop_vars=(horizon_estimate,
                                                           horizon),
                                                back_prop=False)

            assertions = list()
            assertions.append(
                tf.debugging.assert_equal(x=tf.shape(input=horizon_estimate),
                                          y=tf.shape(input=discounted_sum),
                                          message="Estimation check."))
            assertions.append(
                tf.debugging.assert_equal(x=tf.shape(
                    input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                                          y=(horizon + num_overwritten),
                                          message="Estimation check."))

            # Overwrite buffer rewards
            with tf.control_dependencies(control_inputs=assertions):
                indices = tf.range(start=self.buffer_index,
                                   limit=(self.buffer_index + num_overwritten))
                indices = tf.math.mod(x=indices, y=capacity)
                indices = tf.expand_dims(input=indices, axis=1)

            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum)

            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()
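
The while loop at the end implements an n-step discounted return: starting from the horizon baseline estimate, it walks the horizon backwards, each step discounting the running sum and adding the matching slice of rewards. A NumPy sketch of the same recurrence with a horizon of 2 (variable names mirror the loop above):

    import numpy as np

    discount, horizon, num_overwritten = 0.9, 2, 3
    # horizon + num_overwritten rewards are needed, as the shape assertion above checks.
    rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    horizon_estimate = np.zeros(num_overwritten)  # zero estimate variant; 'early' would use the baseline

    discounted_sum, h = horizon_estimate, horizon
    while h >= 0:  # mirrors cond()/body(): discount, add rewards[h:h + num_overwritten], decrement h
        discounted_sum = discount * discounted_sum + rewards[h:h + num_overwritten]
        h -= 1

    # discounted_sum[i] == rewards[i] + discount * rewards[i + 1] + discount**2 * rewards[i + 2]
    print(discounted_sum)  # [ 5.23  7.94 10.65]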
Example #14
    def tf_enqueue(self, states, internals, auxiliaries, actions, terminal,
                   reward):
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        three = tf.constant(value=3, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_timesteps = tf.shape(input=terminal,
                                     out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_timesteps = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                           dtype=util.tf_dtype(dtype='long'))

        # # Max capacity
        # latest_terminal_index = self.terminal_indices[self.episode_count]
        # max_capacity = self.buffer_index - latest_terminal_index - one
        # max_capacity = capacity - (tf.math.mod(x=max_capacity, y=capacity) + one)

        # Remove last observation terminal marker
        last_index = tf.math.mod(x=(self.buffer_index - one), y=capacity)
        last_terminal = tf.gather(params=self.buffers['terminal'],
                                  indices=(last_index, ))[0]
        corrected_terminal = tf.where(condition=tf.math.equal(x=last_terminal,
                                                              y=three),
                                      x=zero,
                                      y=last_terminal)
        assignment = tf.compat.v1.assign(
            ref=self.buffers['terminal'][last_index], value=corrected_terminal)

        # Assertions
        with tf.control_dependencies(control_inputs=(assignment, )):
            assertions = [
                # check: number of timesteps fit into effectively available buffer
                tf.debugging.assert_less_equal(
                    x=num_timesteps,
                    y=capacity,
                    message="Memory does not have enough capacity."),
                # at most one terminal
                tf.debugging.assert_less_equal(
                    x=tf.math.count_nonzero(input=terminal,
                                            dtype=util.tf_dtype(dtype='long')),
                    y=one,
                    message="Timesteps contain more than one terminal."),
                # if terminal, last timestep in batch
                tf.debugging.assert_equal(
                    x=tf.math.reduce_any(
                        input_tensor=tf.math.greater(x=terminal, y=zero)),
                    y=tf.math.greater(x=terminal[-1], y=zero),
                    message="Terminal is not the last timestep."),
                # general check: all terminal indices true
                tf.debugging.assert_equal(
                    x=tf.reduce_all(input_tensor=tf.gather(
                        params=tf.math.greater(x=self.buffers['terminal'], y=zero),
                        indices=self.terminal_indices[:self.episode_count + one])),
                    y=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool')),
                    message="Memory consistency check."),
                # general check: only terminal indices true
                tf.debugging.assert_equal(
                    x=tf.math.count_nonzero(input=self.buffers['terminal'],
                                            dtype=util.tf_dtype(dtype='long')),
                    y=(self.episode_count + one),
                    message="Memory consistency check.")
            ]

        # Buffer indices to overwrite
        with tf.control_dependencies(control_inputs=assertions):
            overwritten_indices = tf.range(start=self.buffer_index,
                                           limit=(self.buffer_index +
                                                  num_timesteps))
            overwritten_indices = tf.math.mod(x=overwritten_indices,
                                              y=capacity)

            # Count number of overwritten episodes
            num_episodes = tf.math.count_nonzero(
                input=tf.gather(params=self.buffers['terminal'],
                                indices=overwritten_indices),
                axis=0,
                dtype=util.tf_dtype(dtype='long'))

            # Shift remaining terminal indices accordingly
            limit_index = self.episode_count + one
            assertion = tf.debugging.assert_greater_equal(
                x=limit_index,
                y=num_episodes,
                message="Memory episode overwriting check.")

        with tf.control_dependencies(control_inputs=(assertion, )):
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[:limit_index - num_episodes],
                value=self.terminal_indices[num_episodes:limit_index])

        # Decrement episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_sub(delta=num_episodes,
                                                       read_value=False)

        # Write new observations
        with tf.control_dependencies(control_inputs=(assignment, )):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_timesteps))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)
            values = dict(states=states,
                          internals=internals,
                          auxiliaries=auxiliaries,
                          actions=actions,
                          terminal=terminal,
                          reward=reward)
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    if name == 'terminal':
                        # Add last observation terminal marker
                        corrected_terminal = tf.where(condition=tf.math.equal(
                            x=terminal[-1], y=zero),
                                                      x=three,
                                                      y=terminal[-1])
                        assignment = buffer.scatter_nd_update(
                            indices=indices,
                            updates=tf.concat(values=(terminal[:-1],
                                                      (corrected_terminal, )),
                                              axis=0))
                    else:
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_timesteps,
                                                      read_value=False)

        # Count number of new episodes
        with tf.control_dependencies(control_inputs=(assignment, )):
            num_new_episodes = tf.math.count_nonzero(
                input=terminal, dtype=util.tf_dtype(dtype='long'))

            # Write new terminal indices
            limit_index = self.episode_count + one
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[limit_index:limit_index +
                                          num_new_episodes],
                value=tf.boolean_mask(tensor=overwritten_indices,
                                      mask=tf.math.greater(x=terminal,
                                                           y=zero)))

        # Increment episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_add(delta=num_new_episodes,
                                                       read_value=False)

        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.no_operation()
Example #15
    def tf_store(self, states, internals, actions, terminal, reward):
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        num_timesteps = tf.shape(input=terminal,
                                 out_type=util.tf_dtype(dtype='long'))[0]

        # Assertions
        assertions = list()
        # general check: all terminal indices true
        assertions.append(
            tf.debugging.assert_equal(x=tf.reduce_all(input_tensor=tf.gather(
                params=self.memories['terminal'],
                indices=self.terminal_indices[:self.episode_count + one])),
                                      y=tf.constant(
                                          value=True,
                                          dtype=util.tf_dtype(dtype='bool'))))
        # general check: only terminal indices true
        assertions.append(
            tf.debugging.assert_equal(x=tf.math.count_nonzero(
                input_tensor=self.memories['terminal'],
                dtype=util.tf_dtype(dtype='long')),
                                      y=(self.episode_count + one)))
        # instances fit into memory
        assertions.append(
            tf.debugging.assert_less_equal(x=num_timesteps, y=capacity))
        # at most one terminal
        assertions.append(
            tf.debugging.assert_less_equal(x=tf.math.count_nonzero(
                input_tensor=terminal, dtype=util.tf_dtype(dtype='long')),
                                           y=one))
        # if terminal, last timestep in batch
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.reduce_any(input_tensor=terminal), y=terminal[-1]))

        # Memory indices to overwrite
        with tf.control_dependencies(control_inputs=assertions):
            indices = tf.range(start=self.memory_index,
                               limit=(self.memory_index + num_timesteps))
            indices = tf.mod(x=indices, y=capacity)

            # Count number of overwritten episodes
            num_episodes = tf.count_nonzero(input_tensor=tf.gather(
                params=self.memories['terminal'], indices=indices),
                                            axis=0,
                                            dtype=util.tf_dtype(dtype='long'))

            # Shift remaining terminal indices accordingly
            limit_index = self.episode_count + one
            assignment = tf.assign(
                ref=self.terminal_indices[:limit_index - num_episodes],
                value=self.terminal_indices[num_episodes:limit_index])

        # Decrement episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_sub(delta=num_episodes,
                                                       read_value=False)

        # Write new observations
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignments = list()
            for name, state in states.items():
                assignments.append(
                    tf.scatter_update(ref=self.memories[name],
                                      indices=indices,
                                      updates=state))
            for name, internal in internals.items():
                assignments.append(
                    tf.scatter_update(ref=self.memories[name],
                                      indices=indices,
                                      updates=internal))
            for name, action in actions.items():
                assignments.append(
                    tf.scatter_update(ref=self.memories[name],
                                      indices=indices,
                                      updates=action))
            assignments.append(
                tf.scatter_update(ref=self.memories['terminal'],
                                  indices=indices,
                                  updates=terminal))
            assignments.append(
                tf.scatter_update(ref=self.memories['reward'],
                                  indices=indices,
                                  updates=reward))

        # Increment memory index
        with tf.control_dependencies(control_inputs=assignments):
            new_memory_index = tf.mod(x=(self.memory_index + num_timesteps),
                                      y=capacity)
            assignment = self.memory_index.assign(value=new_memory_index)

        # Count number of new episodes
        with tf.control_dependencies(control_inputs=(assignment, )):
            num_new_episodes = tf.count_nonzero(
                input_tensor=terminal,
                axis=0,
                dtype=util.tf_dtype(dtype='long'))

            # Write new terminal indices
            limit_index = self.episode_count + one
            assignment = tf.assign(
                ref=self.terminal_indices[limit_index:limit_index +
                                          num_new_episodes],
                value=tf.boolean_mask(tensor=indices, mask=terminal))

        # Increment episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_add(delta=num_new_episodes,
                                                       read_value=False)

        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.no_operation()
Example #16
    def tf_update_batch(self, loss_per_instance):
        """
        Updates priority memory by performing the following steps:

        1. Use saved indices from prior retrieval to reconstruct the batch
        elements which will have their priorities updated.
        2. Compute priorities for these elements.
        3. Insert buffer elements to memory, potentially overwriting existing elements.
        4. Update priorities of existing memory elements.
        5. Re-sort memory.
        6. Update buffer insertion index.

        Note that this implementation could be made more efficient by maintaining
        a sorted version via sum trees.

        :param loss_per_instance: Losses from the most recent batch, used to update priorities.
        """
        # 1. We reconstruct the batch from the buffer and the priority memory via
        # the TensorFlow variables holding the respective indices.
        mask = tf.not_equal(x=self.batch_indices,
                            y=tf.zeros(
                                shape=tf.shape(input=self.batch_indices),
                                dtype=tf.int32))
        priority_indices = tf.reshape(tensor=tf.where(condition=mask),
                                      shape=[-1])

        # These are elements from the buffer which first need to be inserted into the main memory.
        sampled_buffer_batch = self.tf_retrieve_indices(
            buffer_elements=self.last_batch_buffer_elems,
            priority_indices=priority_indices)

        # Extract batch elements.
        states = sampled_buffer_batch['states']
        internals = sampled_buffer_batch['internals']
        actions = sampled_buffer_batch['actions']
        terminal = sampled_buffer_batch['terminal']
        reward = sampled_buffer_batch['reward']

        # 2. Compute priorities for all batch elements.
        priorities = loss_per_instance**self.prioritization_weight
        assignments = list()

        # 3. Insert the buffer elements from the recent batch into memory,
        # overwrite memory if full.
        memory_end_index = self.memory_index + self.last_batch_buffer_elems
        memory_insert_indices = tf.range(
            start=self.memory_index, limit=memory_end_index) % self.capacity

        for name in sorted(states):
            assignments.append(
                tf.scatter_update(
                    ref=self.states_memory[name],
                    indices=memory_insert_indices,
                    # Only buffer elements from batch.
                    updates=states[name][0:self.last_batch_buffer_elems]))
        for name in sorted(internals):
            assignments.append(
                tf.scatter_update(
                    ref=self.internals_buffer[name],
                    indices=memory_insert_indices,
                    updates=internals[name][0:self.last_batch_buffer_elems]))
        assignments.append(
            tf.scatter_update(
                ref=self.priorities,
                indices=memory_insert_indices,
                updates=priorities[0:self.last_batch_buffer_elems]))
        assignments.append(
            tf.scatter_update(
                ref=self.terminal_memory,
                indices=memory_insert_indices,
                updates=terminal[0:self.last_batch_buffer_elems]))
        assignments.append(
            tf.scatter_update(ref=self.reward_memory,
                              indices=memory_insert_indices,
                              updates=reward[0:self.last_batch_buffer_elems]))
        for name in sorted(actions):
            assignments.append(
                tf.scatter_update(
                    ref=self.actions_memory[name],
                    indices=memory_insert_indices,
                    updates=actions[name][0:self.last_batch_buffer_elems]))

        # 4. Update the priorities of the elements already in the memory.
        # Slice out remaining elements - [] if all batch elements were from buffer.
        main_memory_priorities = priorities[self.last_batch_buffer_elems:]
        # Note that priority indices can have a different shape because multiple
        # samples can be from the same index.
        main_memory_priorities = main_memory_priorities[
            0:tf.shape(priority_indices)[0]]
        assignments.append(
            tf.scatter_update(ref=self.priorities,
                              indices=priority_indices,
                              updates=main_memory_priorities))

        with tf.control_dependencies(control_inputs=assignments):
            # 5. Re-sort memory according to priorities.
            assignments = list()

            # Obtain sorted order and indices.
            sorted_priorities, sorted_indices = tf.nn.top_k(
                input=self.priorities, k=self.capacity, sorted=True)
            # Re-assign elements according to priorities.
            # Priorities was the tensor we used to sort, so this can be directly assigned.
            assignments.append(
                tf.assign(ref=self.priorities, value=sorted_priorities))

            # All other memory variables are assigned via scatter updates using the indices
            # returned by the sort:
            assignments.append(
                tf.scatter_update(ref=self.terminal_memory,
                                  indices=sorted_indices,
                                  updates=self.terminal_memory))
            for name in sorted(self.states_memory):
                assignments.append(
                    tf.scatter_update(ref=self.states_memory[name],
                                      indices=sorted_indices,
                                      updates=self.states_memory[name]))
            for name in sorted(self.actions_memory):
                assignments.append(
                    tf.scatter_update(ref=self.actions_memory[name],
                                      indices=sorted_indices,
                                      updates=self.actions_memory[name]))
            for name in sorted(self.internals_memory):
                assignments.append(
                    tf.scatter_update(ref=self.internals_memory[name],
                                      indices=sorted_indices,
                                      updates=self.internals_memory[name]))
            assignments.append(
                tf.scatter_update(ref=self.reward_memory,
                                  indices=sorted_indices,
                                  updates=self.reward_memory))

        # 6. Reset buffer index and increment memory index by inserted elements.
        with tf.control_dependencies(control_inputs=assignments):
            assignments = list()
            # Decrement pointer of last elements used.
            assignments.append(
                tf.assign_sub(ref=self.buffer_index,
                              value=self.last_batch_buffer_elems))

            # Keep track of the memory size so we know whether we can sample from the main memory.
            # Since the memory pointer can be reset to 0, we want to know whether we are at capacity.
            total_inserted_elements = self.memory_size + self.last_batch_buffer_elems
            assignments.append(
                tf.assign(ref=self.memory_size,
                          value=tf.minimum(x=total_inserted_elements,
                                           y=self.capacity)))

            # Update memory insertion index.
            assignments.append(
                tf.assign(ref=self.memory_index, value=memory_end_index))

            # Reset batch indices.
            assignments.append(
                tf.assign(ref=self.batch_indices,
                          value=tf.zeros(shape=tf.shape(self.batch_indices),
                                         dtype=tf.int32)))

        with tf.control_dependencies(control_inputs=assignments):
            return util.no_operation()
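
Steps 2 and 5 of the docstring reduce to exponentiating the per-instance loss and keeping every memory column sorted by the resulting priority. A NumPy sketch of just those two steps (illustrative, not the Tensorforce implementation):

    import numpy as np

    prioritization_weight = 0.6
    loss_per_instance = np.array([0.10, 2.50, 0.01, 1.30])

    # Step 2: priorities from the most recent batch losses.
    priorities = loss_per_instance ** prioritization_weight

    # Step 5: re-sort by descending priority (tf.nn.top_k with k=capacity above); every
    # memory column is reordered with the same indices.
    sorted_indices = np.argsort(-priorities)
    priorities = priorities[sorted_indices]
    reward_memory = np.array([1.0, -1.0, 0.5, 2.0])[sorted_indices]
    print(sorted_indices)  # [1 3 0 2]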