Example #1
    def tf_retrieve(self, indices, values):
        if isinstance(values, str):
            is_single_value = True
            values = [values]
        else:
            is_single_value = False
            values = list(values)

        # Retrieve values
        for n, name in enumerate(values):
            if util.is_nested(name=name):
                value = OrderedDict()
                for inner_name in self.values_spec[name]:
                    value[inner_name] = tf.gather(
                        params=self.buffers[name][inner_name], indices=indices
                    )
            else:
                value = tf.gather(params=self.buffers[name], indices=indices)
            values[n] = value

        # Return values or single value
        if is_single_value:
            return values[0]
        else:
            return values
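
The retrieval above reduces to gathering rows from flat or nested buffers by index. A minimal standalone sketch of that pattern in plain TensorFlow 2.x, using a made-up buffers layout (the retrieve helper is illustrative, not the Tensorforce API):

import tensorflow as tf

# Hypothetical buffer layout: one flat buffer and one nested buffer
buffers = {
    'reward': tf.Variable(tf.zeros(shape=(100,))),
    'states': {'obs': tf.Variable(tf.zeros(shape=(100, 4)))},
}

def retrieve(indices, names):
    # Gather the rows at `indices` from each named buffer,
    # recursing one level for nested buffers
    values = []
    for name in names:
        buffer = buffers[name]
        if isinstance(buffer, dict):
            values.append({inner: tf.gather(params=b, indices=indices)
                           for inner, b in buffer.items()})
        else:
            values.append(tf.gather(params=buffer, indices=indices))
    return values

reward, states = retrieve(indices=tf.constant([0, 5, 7]),
                          names=('reward', 'states'))
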
Example #2
    def tf_initialize(self):
        super().tf_initialize()

        # Value buffers
        self.buffers = OrderedDict()
        for name, spec in self.values_spec.items():
            if util.is_nested(name=name):
                self.buffers[name] = OrderedDict()
                for inner_name, inner_spec in spec.items():
                    shape = (self.capacity, ) + inner_spec['shape']
                    initializer = self.initializers.get(inner_name, 'zeros')
                    self.buffers[name][inner_name] = self.add_variable(
                        name=(inner_name + '-buffer'),
                        dtype=inner_spec['type'],
                        shape=shape,
                        is_trainable=False,
                        initializer=initializer)
            else:
                shape = (self.capacity, ) + spec['shape']
                initializer = self.initializers.get(name, 'zeros')
                self.buffers[name] = self.add_variable(name=(name + '-buffer'),
                                                       dtype=spec['type'],
                                                       shape=shape,
                                                       is_trainable=False,
                                                       initializer=initializer)

        # Buffer index (modulo capacity, next index to write to)
        self.buffer_index = self.add_variable(name='buffer-index',
                                              dtype='long',
                                              shape=(),
                                              is_trainable=False,
                                              initializer='zeros')
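
Stripped of the add_variable machinery, the setup amounts to one non-trainable variable of shape (capacity,) + value_shape per value, plus a scalar write index. A rough standalone equivalent, assuming a hypothetical values_spec:

import tensorflow as tf

capacity = 100
values_spec = {'reward': {'type': 'float', 'shape': ()}}  # hypothetical spec
tf_dtypes = {'float': tf.float32, 'long': tf.int64}

buffers = {
    name: tf.Variable(
        initial_value=tf.zeros(shape=(capacity,) + spec['shape'],
                               dtype=tf_dtypes[spec['type']]),
        trainable=False, name=(name + '-buffer'))
    for name, spec in values_spec.items()
}
# Next index to write to, modulo capacity
buffer_index = tf.Variable(0, dtype=tf.int64, trainable=False,
                           name='buffer-index')
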
Example #3
    def tf_initialize(self):
        super().tf_initialize()

        # Value buffers
        self.buffers = OrderedDict()
        for name, spec in self.values_spec.items():
            if util.is_nested(name=name):
                self.buffers[name] = OrderedDict()
                for inner_name, inner_spec in spec.items():
                    shape = (self.capacity, ) + inner_spec['shape']
                    self.buffers[name][inner_name] = self.add_variable(
                        name=(inner_name + '-buffer'),
                        dtype=inner_spec['type'],
                        shape=shape,
                        is_trainable=False)
            else:
                shape = (self.capacity, ) + spec['shape']
                if name == 'terminal':
                    # Terminal initialization has to agree with terminal_indices
                    initializer = np.zeros(shape=(self.capacity, ),
                                           dtype=util.np_dtype(dtype='long'))
                    initializer[-1] = 1
                    self.buffers[name] = self.add_variable(
                        name=(name + '-buffer'),
                        dtype=spec['type'],
                        shape=shape,
                        is_trainable=False,
                        initializer=initializer)
                else:
                    self.buffers[name] = self.add_variable(name=(name +
                                                                 '-buffer'),
                                                           dtype=spec['type'],
                                                           shape=shape,
                                                           is_trainable=False)

        # Buffer index (modulo capacity, next index to write to)
        self.buffer_index = self.add_variable(name='buffer-index',
                                              dtype='long',
                                              shape=(),
                                              is_trainable=False,
                                              initializer='zeros')

        # Terminal indices
        # (oldest episode terminals first, initially the only terminal is last index)
        initializer = np.zeros(shape=(self.capacity + 1, ),
                               dtype=util.np_dtype(dtype='long'))
        initializer[0] = self.capacity - 1
        self.terminal_indices = self.add_variable(name='terminal-indices',
                                                  dtype='long',
                                                  shape=(self.capacity + 1, ),
                                                  is_trainable=False,
                                                  initializer=initializer)

        # Episode count
        self.episode_count = self.add_variable(name='episode-count',
                                               dtype='long',
                                               shape=(),
                                               is_trainable=False,
                                               initializer='zeros')
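
The two special initializers agree with each other: marking the last buffer slot terminal and storing capacity - 1 as the first terminal index describe the same initial episode boundary. The invariant, sketched with NumPy:

import numpy as np

capacity = 100

terminal_init = np.zeros(shape=(capacity,), dtype=np.int64)
terminal_init[-1] = 1  # last slot starts out marked terminal

terminal_indices_init = np.zeros(shape=(capacity + 1,), dtype=np.int64)
terminal_indices_init[0] = capacity - 1  # ... and is the only recorded terminal

episode_count = 0  # no complete episodes recorded yet
# Consistency: every recorded terminal index points at a terminal marker
assert all(terminal_init[i] > 0
           for i in terminal_indices_init[:episode_count + 1])
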
Example #4
        def true_fn():
            reset_values = self.estimator.reset(baseline=self.baseline_policy)

            new_overwritten_values = OrderedDict()
            for name, value1, value2 in util.zip_items(overwritten_values, reset_values):
                if util.is_nested(name=name):
                    new_overwritten_values[name] = OrderedDict()
                    for inner_name, value1, value2 in util.zip_items(value1, value2):
                        new_overwritten_values[name][inner_name] = tf.concat(
                            values=(value1, value2), axis=0
                        )
                else:
                    new_overwritten_values[name] = tf.concat(values=(value1, value2), axis=0)
            return new_overwritten_values
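
The two-level loop can also be written as a small recursive helper. A sketch, assuming both inputs share the same nesting structure:

import tensorflow as tf

def concat_nested(values1, values2):
    # Concatenate two (possibly nested) dicts of tensors along the batch axis
    if isinstance(values1, dict):
        return {name: concat_nested(values1[name], values2[name])
                for name in values1}
    return tf.concat(values=(values1, values2), axis=0)

merged = concat_nested({'a': tf.zeros(shape=(2,))},
                       {'a': tf.ones(shape=(3,))})  # {'a': shape (5,)}
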
Example #5
    def tf_reset(self):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        if not self.return_overwritten:
            # Reset buffer index
            assignment = self.buffer_index.assign(value=zero, read_value=False)

            # Return no-op
            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        # Overwritten buffer indices
        num_values = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_values),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        # Return overwritten values
        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.fmap(function=util.identity_operation, xs=values)
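
The index arithmetic that recovers the still-valid entries before the reset is easiest to see in isolation. A small eager-mode example, assuming the write index has already wrapped past capacity:

import tensorflow as tf

capacity = tf.constant(8, dtype=tf.int64)
buffer_index = tf.Variable(11, dtype=tf.int64)  # 11 writes into a ring of 8

num_values = tf.minimum(x=buffer_index, y=capacity)    # 8 valid entries
indices = tf.range(start=(buffer_index - num_values),
                   limit=buffer_index)                 # [3, 4, ..., 10]
indices = tf.math.mod(x=indices, y=capacity)           # [3, ..., 7, 0, 1, 2]
# gather the overwritten values at `indices` here, then:
buffer_index.assign(0)
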
Example #6
    def tf_successors(self, indices, horizon, sequence_values=(), final_values=()):
        if sequence_values == () and final_values == ():
            raise TensorforceError.unexpected()

        if isinstance(sequence_values, str):
            is_single_sequence_value = True
            sequence_values = [sequence_values]
        else:
            is_single_sequence_value = False
            sequence_values = list(sequence_values)
        if isinstance(final_values, str):
            is_single_final_value = True
            final_values = [final_values]
        else:
            is_single_final_value = False
            final_values = list(final_values)

        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity, dtype=util.tf_dtype(dtype='long'))

        def body(lengths, successor_indices, mask):
            current_index = successor_indices[:, -1:]
            current_terminal = self.retrieve(indices=current_index, values='terminal')
            is_not_terminal = tf.math.logical_and(
                x=tf.math.logical_not(x=tf.math.greater(x=current_terminal, y=zero)),
                y=mask[:, -1:]
            )
            next_index = tf.math.mod(x=(current_index + one), y=capacity)
            successor_indices = tf.concat(values=(successor_indices, next_index), axis=1)
            mask = tf.concat(values=(mask, is_not_terminal), axis=1)
            is_not_terminal = tf.squeeze(input=is_not_terminal, axis=1)
            zeros = tf.zeros_like(tensor=is_not_terminal, dtype=util.tf_dtype(dtype='long'))
            ones = tf.ones_like(tensor=is_not_terminal, dtype=util.tf_dtype(dtype='long'))
            lengths += tf.where(condition=is_not_terminal, x=ones, y=zeros)
            return lengths, successor_indices, mask

        lengths = tf.ones_like(tensor=indices, dtype=util.tf_dtype(dtype='long'))
        successor_indices = tf.expand_dims(input=indices, axis=1)
        mask = tf.ones_like(tensor=successor_indices, dtype=util.tf_dtype(dtype='bool'))
        shape = tf.TensorShape(dims=(None, None))

        lengths, successor_indices, mask = self.while_loop(
            cond=util.tf_always_true, body=body, loop_vars=(lengths, successor_indices, mask),
            shape_invariants=(lengths.get_shape(), shape, shape), back_prop=False,
            maximum_iterations=horizon
        )

        successor_indices = tf.reshape(tensor=successor_indices, shape=(-1,))
        mask = tf.reshape(tensor=mask, shape=(-1,))
        successor_indices = tf.boolean_mask(tensor=successor_indices, mask=mask, axis=0)

        assertion = tf.compat.v1.debugging.assert_greater_equal(
            x=tf.math.mod(x=(self.buffer_index - one - successor_indices), y=capacity), y=zero
        )

        with tf.control_dependencies(control_inputs=(assertion,)):
            starts = tf.math.cumsum(x=lengths, exclusive=True)
            ends = tf.math.cumsum(x=lengths) - one
            final_indices = tf.gather(params=successor_indices, indices=ends)

            for n, name in enumerate(sequence_values):
                if util.is_nested(name=name):
                    sequence_value = OrderedDict()
                    for inner_name in self.values_spec[name]:
                        sequence_value[inner_name] = tf.gather(
                            params=self.buffers[name][inner_name], indices=successor_indices
                        )
                else:
                    sequence_value = tf.gather(
                        params=self.buffers[name], indices=successor_indices
                    )
                sequence_values[n] = sequence_value

            for n, name in enumerate(final_values):
                if util.is_nested(name=name):
                    final_value = OrderedDict()
                    for inner_name in self.values_spec[name]:
                        final_value[inner_name] = tf.gather(
                            params=self.buffers[name][inner_name], indices=final_indices
                        )
                else:
                    final_value = tf.gather(
                        params=self.buffers[name], indices=final_indices
                    )
                final_values[n] = final_value

        if len(sequence_values) == 0:
            if is_single_final_value:
                final_values = final_values[0]
            return lengths, final_values

        elif len(final_values) == 0:
            if is_single_sequence_value:
                sequence_values = sequence_values[0]
            return starts, lengths, sequence_values

        else:
            if is_single_sequence_value:
                sequence_values = sequence_values[0]
            if is_single_final_value:
                final_values = final_values[0]
            return starts, lengths, sequence_values, final_values
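
The heart of the while loop is a step that advances each index by one (modulo capacity) and freezes it once a terminal is reached. A compact eager sketch of the same walk, with a hypothetical terminal buffer and a plain Python loop standing in for self.while_loop:

import tensorflow as tf

capacity = 8
terminal = tf.constant([0, 0, 1, 0, 0, 0, 0, 1], dtype=tf.int64)  # hypothetical

def successors(indices, horizon):
    lengths = tf.ones_like(indices)
    succ = tf.expand_dims(indices, axis=1)
    mask = tf.ones_like(succ, dtype=tf.bool)
    for _ in range(horizon):
        current = succ[:, -1:]
        # Keep walking only while the current step is not terminal
        not_terminal = tf.math.logical_and(
            tf.math.equal(tf.gather(params=terminal, indices=current), 0),
            mask[:, -1:])
        succ = tf.concat(values=(succ, tf.math.mod(current + 1, capacity)),
                         axis=1)
        mask = tf.concat(values=(mask, not_terminal), axis=1)
        lengths += tf.cast(tf.squeeze(not_terminal, axis=1), dtype=tf.int64)
    return lengths, succ, mask

lengths, succ, mask = successors(indices=tf.constant([0, 3], dtype=tf.int64),
                                 horizon=3)
# lengths == [3, 4]: index 0 stops at the terminal in slot 2, index 3 runs on
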
Example #7
    def tf_enqueue(self, **values):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        # Get number of values
        for value in values.values():
            if not isinstance(value, dict):
                break
            elif len(value) > 0:
                value = next(iter(value.values()))
                break
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=value,
                                  out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=value)[0],
                                        dtype=util.tf_dtype(dtype='long'))

        # Check whether instances fit into buffer
        assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

        if self.return_overwritten:
            # Overwritten buffer indices
            with tf.control_dependencies(control_inputs=(assertion, )):
                start = tf.maximum(x=self.buffer_index, y=capacity)
                limit = tf.maximum(x=(self.buffer_index + num_values),
                                   y=capacity)
                num_overwritten = limit - start
                indices = tf.range(start=start, limit=limit)
                indices = tf.math.mod(x=indices, y=capacity)

            # Get overwritten values
            with tf.control_dependencies(control_inputs=(indices, )):
                overwritten_values = OrderedDict()
                for name, buffer in self.buffers.items():
                    if util.is_nested(name=name):
                        overwritten_values[name] = OrderedDict()
                        for inner_name, buffer in buffer.items():
                            overwritten_values[name][inner_name] = tf.gather(
                                params=buffer, indices=indices)
                    else:
                        overwritten_values[name] = tf.gather(params=buffer,
                                                             indices=indices)

        else:
            overwritten_values = (assertion, )

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            if self.return_overwritten:
                any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
                overwritten_values = util.fmap(
                    function=util.identity_operation, xs=overwritten_values)
                return any_overwritten, overwritten_values
            else:
                return util.no_operation()
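
Ignoring the overwritten-values bookkeeping, the write path is a circular scatter followed by an index increment. A minimal eager sketch of that pattern (not the actual enqueue signature):

import tensorflow as tf

capacity = 8
buffer = tf.Variable(tf.zeros(shape=(capacity,)))
buffer_index = tf.Variable(0, dtype=tf.int64)

def enqueue(new_values):
    num_values = tf.shape(input=new_values, out_type=tf.int64)[0]
    indices = tf.range(start=buffer_index, limit=(buffer_index + num_values))
    indices = tf.math.mod(x=indices, y=capacity)
    # scatter_nd_update expects indices of shape (num_values, 1)
    buffer.scatter_nd_update(indices=tf.expand_dims(indices, axis=1),
                             updates=new_values)
    buffer_index.assign_add(num_values)

enqueue(tf.constant([1.0, 2.0, 3.0]))  # writes slots 0-2, buffer_index -> 3
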
Example #8
    def tf_enqueue(self,
                   states,
                   internals,
                   auxiliaries,
                   actions,
                   terminal,
                   reward,
                   baseline=None):
        # Constants and parameters
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        horizon = self.horizon.value()
        discount = self.discount.value()

        assertions = list()
        # Check whether horizon is at most capacity
        assertions.append(
            tf.debugging.assert_less_equal(
                x=horizon,
                y=capacity,
                message=
                "Estimator capacity has to be at least the same as the estimation horizon."
            ))
        # Check that there is at most one terminal
        assertions.append(
            tf.debugging.assert_less_equal(
                x=tf.math.count_nonzero(input=terminal,
                                        dtype=util.tf_dtype(dtype='long')),
                y=one,
                message="Timesteps contain more than one terminal."))
        # Check that the terminal, if any, is the last value
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.reduce_any(
                    input_tensor=tf.math.greater(x=terminal, y=zero)),
                y=tf.math.greater(x=terminal[-1], y=zero),
                message="Terminal is not the last timestep."))

        # Get number of overwritten values
        with tf.control_dependencies(control_inputs=assertions):
            if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                num_values = tf.shape(input=terminal,
                                      out_type=util.tf_dtype(dtype='long'))[0]
            else:
                num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                            dtype=util.tf_dtype(dtype='long'))
            overwritten_start = tf.maximum(x=self.buffer_index, y=capacity)
            overwritten_limit = tf.maximum(x=(self.buffer_index + num_values),
                                           y=capacity)
            num_overwritten = overwritten_limit - overwritten_start

        def update_overwritten_rewards():
            # Get relevant buffer rewards
            buffer_limit = self.buffer_index + tf.minimum(
                x=(num_overwritten + horizon), y=capacity)
            buffer_indices = tf.range(start=self.buffer_index,
                                      limit=buffer_limit)
            buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
            rewards = tf.gather(params=self.buffers['reward'],
                                indices=buffer_indices)

            # Get relevant values rewards
            values_limit = tf.maximum(x=(num_overwritten + horizon - capacity),
                                      y=zero)
            rewards = tf.concat(values=(rewards, reward[:values_limit]),
                                axis=0)

            # Horizon baseline value
            if self.estimate_horizon == 'early':
                assert baseline is not None
                # Baseline estimate
                buffer_indices = buffer_indices[horizon + one:]
                _states = OrderedDict()
                for name, buffer in self.buffers['states'].items():
                    state = tf.gather(params=buffer, indices=buffer_indices)
                    _states[name] = tf.concat(
                        values=(state, states[name][:values_limit + one]),
                        axis=0)
                _internals = OrderedDict()
                for name, buffer in self.buffers['internals'].items():
                    internal = tf.gather(params=buffer, indices=buffer_indices)
                    _internals[name] = tf.concat(
                        values=(internal,
                                internals[name][:values_limit + one]),
                        axis=0)
                _auxiliaries = OrderedDict()
                for name, buffer in self.buffers['auxiliaries'].items():
                    auxiliary = tf.gather(params=buffer,
                                          indices=buffer_indices)
                    _auxiliaries[name] = tf.concat(
                        values=(auxiliary,
                                auxiliaries[name][:values_limit + one]),
                        axis=0)

                # Dependency horizon
                # TODO: handle arbitrary non-optimization horizons!
                past_horizon = baseline.past_horizon(is_optimization=False)
                assertion = tf.debugging.assert_equal(
                    x=past_horizon,
                    y=zero,
                    message=
                    "Temporary: baseline cannot depend on previous states.")
                with tf.control_dependencies(control_inputs=(assertion, )):
                    some_state = next(iter(_states.values()))
                    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                        batch_size = tf.shape(
                            input=some_state,
                            out_type=util.tf_dtype(dtype='long'))[0]
                    else:
                        batch_size = tf.dtypes.cast(
                            x=tf.shape(input=some_state)[0],
                            dtype=util.tf_dtype(dtype='long'))
                    starts = tf.range(start=batch_size,
                                      dtype=util.tf_dtype(dtype='long'))
                    lengths = tf.ones(shape=(batch_size, ),
                                      dtype=util.tf_dtype(dtype='long'))
                    Module.update_tensors(dependency_starts=starts,
                                          dependency_lengths=lengths)

                if self.estimate_actions:
                    _actions = OrderedDict()
                    for name, buffer in self.buffers['actions'].items():
                        action = tf.gather(params=buffer,
                                           indices=buffer_indices)
                        _actions[name] = tf.concat(
                            values=(action, actions[name][:values_limit]),
                            axis=0)
                    horizon_estimate = baseline.actions_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries,
                        actions=_actions)
                else:
                    horizon_estimate = baseline.states_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries)

            else:
                # Zero estimate
                horizon_estimate = tf.zeros(shape=(num_overwritten, ),
                                            dtype=util.tf_dtype(dtype='float'))

            # Calculate discounted sum
            def cond(discounted_sum, horizon):
                return tf.math.greater_equal(x=horizon, y=zero)

            def body(discounted_sum, horizon):
                discounted_sum = discount * discounted_sum
                discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                          num_overwritten]
                return discounted_sum, horizon - one

            discounted_sum, _ = self.while_loop(cond=cond,
                                                body=body,
                                                loop_vars=(horizon_estimate,
                                                           horizon),
                                                back_prop=False)

            assertions = [
                tf.debugging.assert_equal(x=tf.shape(input=horizon_estimate),
                                          y=tf.shape(input=discounted_sum),
                                          message="Estimation check."),
                tf.debugging.assert_equal(x=tf.shape(
                    input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                                          y=(horizon + num_overwritten),
                                          message="Estimation check.")
            ]

            # Overwrite buffer rewards
            with tf.control_dependencies(control_inputs=assertions):
                indices = tf.range(start=self.buffer_index,
                                   limit=(self.buffer_index + num_overwritten))
                indices = tf.math.mod(x=indices, y=capacity)
                indices = tf.expand_dims(input=indices, axis=1)

            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum)

            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
        updated_rewards = self.cond(pred=any_overwritten,
                                    true_fn=update_overwritten_rewards,
                                    false_fn=util.no_operation)

        # Overwritten buffer indices
        with tf.control_dependencies(control_inputs=(updated_rewards, )):
            indices = tf.range(start=overwritten_start,
                               limit=overwritten_limit)
            indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        with tf.control_dependencies(control_inputs=(indices, )):
            overwritten_values = OrderedDict()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    overwritten_values[name] = OrderedDict()
                    for inner_name, buffer in buffer.items():
                        overwritten_values[name][inner_name] = tf.gather(
                            params=buffer, indices=indices)
                else:
                    overwritten_values[name] = tf.gather(params=buffer,
                                                         indices=indices)

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            values = dict(states=states,
                          internals=internals,
                          auxiliaries=auxiliaries,
                          actions=actions,
                          terminal=terminal,
                          reward=reward)
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
            overwritten_values = util.fmap(function=util.identity_operation,
                                           xs=overwritten_values)
            return any_overwritten, overwritten_values
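
The discounted-sum loop in update_overwritten_rewards folds the horizon in backwards: starting from the horizon estimate, each iteration multiplies by the discount and adds a shifted reward slice. The same recurrence in a few eager lines, with made-up numbers:

import tensorflow as tf

discount = 0.99
horizon = 2
num_values = 3
# num_values + horizon rewards: the extra entries extend past the batch
rewards = tf.constant([1.0, 0.0, 2.0, 0.0, 0.5])

discounted_sum = tf.zeros(shape=(num_values,))  # zero horizon estimate
for h in range(horizon, -1, -1):
    discounted_sum = discount * discounted_sum + rewards[h:h + num_values]
# discounted_sum[i] == sum over k in 0..horizon of discount**k * rewards[i + k]
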
Example #9
    def tf_reset(self, baseline=None):
        # Constants and parameters
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        horizon = self.horizon.value()
        discount = self.discount.value()

        # Overwritten buffer indices
        num_overwritten = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_overwritten),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        states = values['states']
        internals = values['internals']
        auxiliaries = values['auxiliaries']
        actions = values['actions']
        terminal = values['terminal']
        reward = values['reward']

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        with tf.control_dependencies(control_inputs=(assignment, )):
            assertions = list()
            # Check that there is exactly one terminal (unless empty?)
            assertions.append(
                tf.debugging.assert_equal(
                    x=tf.math.count_nonzero(input=terminal,
                                            dtype=util.tf_dtype(dtype='long')),
                    y=one,
                    message="Timesteps do not contain exactly one terminal."))
            # Check whether last value is terminal
            assertions.append(
                tf.debugging.assert_equal(
                    x=tf.math.greater(x=terminal[-1], y=zero),
                    y=tf.constant(value=True,
                                  dtype=util.tf_dtype(dtype='bool')),
                    message="Terminal is not the last timestep."))

        # Get number of values
        with tf.control_dependencies(control_inputs=assertions):
            if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                num_values = tf.shape(input=terminal,
                                      out_type=util.tf_dtype(dtype='long'))[0]
            else:
                num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                            dtype=util.tf_dtype(dtype='long'))

        # Horizon baseline value
        if self.estimate_horizon == 'early' and baseline is not None:
            # Dependency horizon
            # TODO: handle arbitrary non-optimization horizons!
            past_horizon = baseline.past_horizon(is_optimization=False)
            assertion = tf.debugging.assert_equal(
                x=past_horizon,
                y=zero,
                message="Temporary: baseline cannot depend on previous states."
            )

            # Baseline estimate
            horizon_start = num_values - tf.maximum(x=(num_values - horizon),
                                                    y=one)
            _states = OrderedDict()
            for name, state in states.items():
                _states[name] = state[horizon_start:]
            _internals = OrderedDict()
            for name, internal in internals.items():
                _internals[name] = internal[horizon_start:]
            _auxiliaries = OrderedDict()
            for name, auxiliary in auxiliaries.items():
                _auxiliaries[name] = auxiliary[horizon_start:]

            with tf.control_dependencies(control_inputs=(assertion, )):
                batch_size = num_values - horizon_start
                starts = tf.range(start=batch_size,
                                  dtype=util.tf_dtype(dtype='long'))
                lengths = tf.ones(shape=(batch_size, ),
                                  dtype=util.tf_dtype(dtype='long'))
                Module.update_tensors(dependency_starts=starts,
                                      dependency_lengths=lengths)

            if self.estimate_actions:
                _actions = OrderedDict()
                for name, action in actions.items():
                    _actions[name] = action[horizon_start:]
                horizon_estimate = baseline.actions_value(
                    states=_states,
                    internals=_internals,
                    auxiliaries=_auxiliaries,
                    actions=_actions)
            else:
                horizon_estimate = baseline.states_value(
                    states=_states,
                    internals=_internals,
                    auxiliaries=_auxiliaries)

            # Expand rewards beyond terminal
            terminal_zeros = tf.zeros(shape=(horizon, ),
                                      dtype=util.tf_dtype(dtype='float'))
            if self.estimate_terminal:
                rewards = tf.concat(values=(reward[:-1], horizon_estimate[-1:],
                                            terminal_zeros),
                                    axis=0)

            else:
                with tf.control_dependencies(control_inputs=(assertion, )):
                    last_reward = tf.where(condition=tf.math.greater(
                        x=terminal[-1], y=one),
                                           x=horizon_estimate[-1],
                                           y=reward[-1])
                    rewards = tf.concat(values=(reward[:-1], (last_reward, ),
                                                terminal_zeros),
                                        axis=0)

            # Remove last if necessary
            horizon_end = tf.where(condition=tf.math.less_equal(x=num_values,
                                                                y=horizon),
                                   x=zero,
                                   y=(num_values - horizon))
            horizon_estimate = horizon_estimate[:horizon_end]

            # Expand missing estimates with zeros
            terminal_size = tf.minimum(x=horizon, y=num_values)
            terminal_estimate = tf.zeros(shape=(terminal_size, ),
                                         dtype=util.tf_dtype(dtype='float'))
            horizon_estimate = tf.concat(values=(horizon_estimate,
                                                 terminal_estimate),
                                         axis=0)

        else:
            # Expand rewards beyond terminal
            terminal_zeros = tf.zeros(shape=(horizon, ),
                                      dtype=util.tf_dtype(dtype='float'))
            rewards = tf.concat(values=(reward, terminal_zeros), axis=0)

            # Zero estimate
            horizon_estimate = tf.zeros(shape=(num_values, ),
                                        dtype=util.tf_dtype(dtype='float'))

        # Calculate discounted sum
        def cond(discounted_sum, horizon):
            return tf.math.greater_equal(x=horizon, y=zero)

        def body(discounted_sum, horizon):
            discounted_sum = discount * discounted_sum
            discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                      num_values]
            return discounted_sum, horizon - one

        values['reward'], _ = self.while_loop(cond=cond,
                                              body=body,
                                              loop_vars=(horizon_estimate,
                                                         horizon),
                                              back_prop=False)

        return values
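
The reward expansion beyond the terminal simply pads the sequence with horizon zeros so that the shifted slices in the discounted-sum loop stay in range:

import tensorflow as tf

horizon = 2
reward = tf.constant([1.0, 0.0, 2.0])
terminal_zeros = tf.zeros(shape=(horizon,))
rewards = tf.concat(values=(reward, terminal_zeros), axis=0)  # length 5
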
Example #10
    def tf_enqueue(self, states, internals, auxiliaries, actions, terminal,
                   reward):
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        three = tf.constant(value=3, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_timesteps = tf.shape(input=terminal,
                                     out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_timesteps = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                           dtype=util.tf_dtype(dtype='long'))

        # Remove last observation terminal marker
        last_index = tf.math.mod(x=(self.buffer_index - one), y=capacity)
        last_terminal = tf.gather(params=self.buffers['terminal'],
                                  indices=(last_index, ))[0]
        corrected_terminal = tf.where(condition=tf.math.equal(x=last_terminal,
                                                              y=three),
                                      x=zero,
                                      y=last_terminal)
        assignment = tf.compat.v1.assign(
            ref=self.buffers['terminal'][last_index], value=corrected_terminal)

        # Assertions
        with tf.control_dependencies(control_inputs=(assignment, )):
            assertions = [
                # check: number of timesteps fits into the effectively available buffer
                tf.debugging.assert_less_equal(
                    x=num_timesteps,
                    y=capacity,
                    message="Memory does not have enough capacity."),
                # at most one terminal
                tf.debugging.assert_less_equal(
                    x=tf.math.count_nonzero(input=terminal,
                                            dtype=util.tf_dtype(dtype='long')),
                    y=one,
                    message="Timesteps contain more than one terminal."),
                # if terminal, last timestep in batch
                tf.debugging.assert_equal(
                    x=tf.math.reduce_any(
                        input_tensor=tf.math.greater(x=terminal, y=zero)),
                    y=tf.math.greater(x=terminal[-1], y=zero),
                    message="Terminal is not the last timestep."),
                # general check: all terminal indices true
                tf.debugging.assert_equal(
                    x=tf.reduce_all(input_tensor=tf.gather(
                        params=tf.math.greater(x=self.buffers['terminal'],
                                               y=zero),
                        indices=self.terminal_indices[:self.episode_count +
                                                      one])),
                    y=tf.constant(value=True,
                                  dtype=util.tf_dtype(dtype='bool')),
                    message="Memory consistency check."),
                # general check: only terminal indices true
                tf.debugging.assert_equal(
                    x=tf.math.count_nonzero(input=self.buffers['terminal'],
                                            dtype=util.tf_dtype(dtype='long')),
                    y=(self.episode_count + one),
                    message="Memory consistency check.")
            ]

        # Buffer indices to overwrite
        with tf.control_dependencies(control_inputs=assertions):
            overwritten_indices = tf.range(start=self.buffer_index,
                                           limit=(self.buffer_index +
                                                  num_timesteps))
            overwritten_indices = tf.math.mod(x=overwritten_indices,
                                              y=capacity)

            # Count number of overwritten episodes
            num_episodes = tf.math.count_nonzero(
                input=tf.gather(params=self.buffers['terminal'],
                                indices=overwritten_indices),
                axis=0,
                dtype=util.tf_dtype(dtype='long'))

            # Shift remaining terminal indices accordingly
            limit_index = self.episode_count + one
            assertion = tf.debugging.assert_greater_equal(
                x=limit_index,
                y=num_episodes,
                message="Memory episode overwriting check.")

        with tf.control_dependencies(control_inputs=(assertion, )):
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[:limit_index - num_episodes],
                value=self.terminal_indices[num_episodes:limit_index])

        # Decrement episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_sub(delta=num_episodes,
                                                       read_value=False)

        # Write new observations
        with tf.control_dependencies(control_inputs=(assignment, )):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_timesteps))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)
            values = dict(states=states,
                          internals=internals,
                          auxiliaries=auxiliaries,
                          actions=actions,
                          terminal=terminal,
                          reward=reward)
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    if name == 'terminal':
                        # Add last observation terminal marker
                        corrected_terminal = tf.where(condition=tf.math.equal(
                            x=terminal[-1], y=zero),
                                                      x=three,
                                                      y=terminal[-1])
                        assignment = buffer.scatter_nd_update(
                            indices=indices,
                            updates=tf.concat(values=(terminal[:-1],
                                                      (corrected_terminal, )),
                                              axis=0))
                    else:
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_timesteps,
                                                      read_value=False)

        # Count number of new episodes
        with tf.control_dependencies(control_inputs=(assignment, )):
            num_new_episodes = tf.math.count_nonzero(
                input=terminal, dtype=util.tf_dtype(dtype='long'))

            # Write new terminal indices
            limit_index = self.episode_count + one
            assignment = tf.compat.v1.assign(
                ref=self.terminal_indices[limit_index:limit_index +
                                          num_new_episodes],
                value=tf.boolean_mask(tensor=overwritten_indices,
                                      mask=tf.math.greater(x=terminal,
                                                           y=zero)))

        # Increment episode count accordingly
        with tf.control_dependencies(control_inputs=(assignment, )):
            assignment = self.episode_count.assign_add(delta=num_new_episodes,
                                                       read_value=False)

        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.no_operation()
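
The terminal-index bookkeeping around the write can be separated from the TensorFlow plumbing: terminals of overwritten episodes are shifted out of the front of terminal_indices, newly written terminals are appended behind the valid prefix, and episode_count tracks that prefix. A NumPy sketch of the invariant (illustrative, not the module's API):

import numpy as np

capacity = 8
terminal_indices = np.zeros(capacity + 1, dtype=np.int64)
terminal_indices[0] = capacity - 1  # initially the only terminal is the last slot
episode_count = 0

def update_terminal_indices(num_overwritten_episodes, new_terminal_positions):
    global episode_count
    # Drop bookkeeping for episodes whose terminals were overwritten
    limit = episode_count + 1
    terminal_indices[:limit - num_overwritten_episodes] = \
        terminal_indices[num_overwritten_episodes:limit]
    episode_count -= num_overwritten_episodes
    # Append the terminals just written, oldest episodes still first
    limit = episode_count + 1
    n = len(new_terminal_positions)
    terminal_indices[limit:limit + n] = new_terminal_positions
    episode_count += n

update_terminal_indices(num_overwritten_episodes=0,
                        new_terminal_positions=[4])  # one new episode ends at 4
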