Example #1
    def tf_optimization(
        self, states, internals, actions, terminal, reward, next_states=None, next_internals=None
    ):
        """
        Creates the TensorFlow operations for performing an optimization update step based
        on the given batch of input states and actions.

        Args:
            states: Dict of state tensors.
            internals: List of prior internal state tensors.
            actions: Dict of action tensors.
            terminal: Terminal boolean tensor.
            reward: Reward tensor.
            next_states: Dict of successor state tensors.
            next_internals: List of posterior internal state tensors.

        Returns:
            The optimization operation.
        """
        parameters_before = OrderedDict()
        embedding = self.network.apply(x=states, internals=internals)
        for name, distribution in self.distributions.items():
            parameters_before[name] = distribution.parametrize(x=embedding)

        with tf.control_dependencies(control_inputs=util.flatten(xs=parameters_before)):
            optimized = super().tf_optimization(
                states=states, internals=internals, actions=actions, terminal=terminal,
                reward=reward, next_states=next_states, next_internals=next_internals
            )

        with tf.control_dependencies(control_inputs=(optimized,)):
            summaries = list()
            embedding = self.network.apply(x=states, internals=internals)
            for name, distribution in self.distributions.items():
                parameters = distribution.parametrize(x=embedding)
                kl_divergence = distribution.kl_divergence(
                    parameters1=parameters_before[name], parameters2=parameters
                )
                collapsed_size = util.product(xs=util.shape(kl_divergence)[1:])
                kl_divergence = tf.reshape(tensor=kl_divergence, shape=(-1, collapsed_size))
                kl_divergence = tf.reduce_mean(input_tensor=kl_divergence, axis=1)
                kl_divergence = self.add_summary(
                    label='kl-divergence', name=(name + '-kldiv'), tensor=kl_divergence
                )
                summaries.append(kl_divergence)

                entropy = distribution.entropy(parameters=parameters)
                entropy = tf.reshape(tensor=entropy, shape=(-1, collapsed_size))
                entropy = tf.reduce_mean(input_tensor=entropy, axis=1)
                entropy = self.add_summary(
                    label='entropy', name=(name + '-entropy'), tensor=entropy
                )
                summaries.append(entropy)

        with tf.control_dependencies(control_inputs=summaries):
            return util.no_operation()
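The summaries above re-parameterize each action distribution after the update, compute the KL divergence against the pre-update parameters, and collapse all non-batch dimensions before averaging. A minimal standalone sketch of that collapse, using the closed-form KL divergence between diagonal Gaussians (not Tensorforce code; assumes TensorFlow 2.x eager execution and hypothetical shapes):

import tensorflow as tf

def mean_kl_divergence(mean1, stddev1, mean2, stddev2):
    # Closed-form KL(N(mean1, stddev1) || N(mean2, stddev2)) for diagonal Gaussians
    kl = (tf.math.log(stddev2 / stddev1)
          + (tf.square(stddev1) + tf.square(mean1 - mean2)) / (2.0 * tf.square(stddev2))
          - 0.5)
    # Collapse all non-batch dimensions and average, mirroring the reshape/reduce_mean above
    kl = tf.reshape(kl, shape=[tf.shape(kl)[0], -1])
    return tf.reduce_mean(kl, axis=1)

# Hypothetical batch of 4 instances with a (3, 2)-shaped action distribution
mean_before = tf.random.normal(shape=(4, 3, 2))
mean_after = tf.random.normal(shape=(4, 3, 2))
stddev = tf.ones(shape=(4, 3, 2))
print(mean_kl_divergence(mean_before, stddev, mean_after, stddev))  # shape (4,)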
Example #2
    def tf_core_experience(self, states, internals, auxiliaries, actions, terminal, reward):
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))

        # Enqueue experience for early reward estimation
        any_overwritten, overwritten_values = self.estimator.enqueue(
            baseline=self.baseline_policy, states=states, internals=internals,
            auxiliaries=auxiliaries, actions=actions, terminal=terminal, reward=reward
        )

        # If terminal, store remaining values in memory

        def true_fn():
            reset_values = self.estimator.reset(baseline=self.baseline_policy)

            new_overwritten_values = OrderedDict()
            for name, value1, value2 in util.zip_items(overwritten_values, reset_values):
                if util.is_nested(name=name):
                    new_overwritten_values[name] = OrderedDict()
                    for inner_name, value1, value2 in util.zip_items(value1, value2):
                        new_overwritten_values[name][inner_name] = tf.concat(
                            values=(value1, value2), axis=0
                        )
                else:
                    new_overwritten_values[name] = tf.concat(values=(value1, value2), axis=0)
            return new_overwritten_values

        def false_fn():
            return overwritten_values

        with tf.control_dependencies(control_inputs=util.flatten(xs=overwritten_values)):
            values = self.cond(pred=(terminal[-1] > zero), true_fn=true_fn, false_fn=false_fn)

        # If any, store overwritten values
        def store():
            return self.memory.enqueue(**values)

        terminal = values['terminal']
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=terminal, out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(
                x=tf.shape(input=terminal)[0], dtype=util.tf_dtype(dtype='long')
            )

        stored = self.cond(pred=(num_values > zero), true_fn=store, false_fn=util.no_operation)

        return stored
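The terminal handling in tf_core_experience relies on tf.cond to either pass the overwritten values through unchanged or merge them with the values flushed from the estimator. A minimal standalone sketch of that merge pattern (not Tensorforce code; assumes TensorFlow 2.x eager execution and hypothetical values):

import tensorflow as tf

values = tf.constant([1.0, 2.0])        # values passed through in the non-terminal case
flushed = tf.constant([3.0, 4.0, 5.0])  # values flushed from the estimator on terminal
is_terminal = tf.constant(True)

merged = tf.cond(
    pred=is_terminal,
    true_fn=lambda: tf.concat(values=(values, flushed), axis=0),
    false_fn=lambda: values
)
print(merged.numpy())  # [1. 2. 3. 4. 5.]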
    def tf_reset(self):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        if not self.return_overwritten:
            # Reset buffer index
            assignment = self.buffer_index.assign(value=zero, read_value=False)

            # Return no-op
            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        # Overwritten buffer indices
        num_values = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_values),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        # Return overwritten values
        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.fmap(function=util.identity_operation, xs=values)
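Both functions above operate on the same circular buffer: a monotonically increasing buffer_index is mapped onto buffer slots with tf.math.mod, and tf.gather reads the most recent entries back in chronological order. A small self-contained sketch of that index arithmetic (hypothetical sizes and values, not Tensorforce code; TensorFlow 2.x eager):

import tensorflow as tf

capacity = tf.constant(4, dtype=tf.int64)
buffer = tf.constant([40.0, 50.0, 20.0, 30.0])  # 6 values written so far, index wrapped around
buffer_index = tf.constant(6, dtype=tf.int64)

num_values = tf.minimum(x=buffer_index, y=capacity)           # at most `capacity` entries
indices = tf.range(start=(buffer_index - num_values), limit=buffer_index)
indices = tf.math.mod(x=indices, y=capacity)                  # [2, 3, 0, 1]
print(tf.gather(params=buffer, indices=indices).numpy())      # [20. 30. 40. 50.], oldest first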
Example #4
    def tf_optimize(self, indices):
        # Baseline optimization
        if self.baseline_optimizer is not None:
            optimized = self.optimize_baseline(indices=indices)
            dependencies = (optimized,)
        else:
            dependencies = (indices,)

        # Reward estimation
        with tf.control_dependencies(control_inputs=dependencies):
            reward = self.memory.retrieve(indices=indices, values='reward')
            reward = self.estimator.complete(
                baseline=self.baseline_policy, memory=self.memory, indices=indices, reward=reward
            )
            reward = self.add_summary(
                label=('empirical-reward', 'rewards'), name='empirical-reward', tensor=reward
            )
            is_baseline_optimized = self.separate_baseline_policy and \
                self.baseline_optimizer is None and self.baseline_objective is None
            reward = self.estimator.estimate(
                baseline=self.baseline_policy, memory=self.memory, indices=indices, reward=reward,
                is_baseline_optimized=is_baseline_optimized
            )
            reward = self.add_summary(
                label=('estimated-reward', 'rewards'), name='estimated-reward', tensor=reward
            )

        # Stop gradients of estimated rewards if separate baseline optimization
        if not is_baseline_optimized:
            reward = tf.stop_gradient(input=reward)

        # Retrieve states, internals and actions
        past_horizon = self.policy.past_horizon(is_optimization=True)
        if self.separate_baseline_policy and self.baseline_optimizer is None:
            assertion = tf.debugging.assert_equal(
                x=past_horizon,
                y=self.baseline_policy.past_horizon(is_optimization=True),
                message="Policy and baseline depend on a different number of previous states."
            )
        else:
            assertion = past_horizon

        with tf.control_dependencies(control_inputs=(assertion,)):
            # horizon change: see timestep-based batch sampling
            starts, lengths, states, internals = self.memory.predecessors(
                indices=indices, horizon=past_horizon, sequence_values='states',
                initial_values='internals'
            )
            Module.update_tensors(dependency_starts=starts, dependency_lengths=lengths)
            auxiliaries, actions = self.memory.retrieve(
                indices=indices, values=('auxiliaries', 'actions')
            )

        # Optimizer arguments
        independent = Module.update_tensor(
            name='independent', tensor=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool'))
        )

        variables = self.get_variables(only_trainable=True)

        arguments = dict(
            states=states, internals=internals, auxiliaries=auxiliaries, actions=actions,
            reward=reward
        )

        fn_loss = self.total_loss

        def fn_kl_divergence(states, internals, auxiliaries, actions, reward, other=None):
            kl_divergence = self.policy.kl_divergence(
                states=states, internals=internals, auxiliaries=auxiliaries, other=other
            )
            if self.baseline_optimizer is None and self.baseline_objective is not None:
                kl_divergence += self.baseline_policy.kl_divergence(
                    states=states, internals=internals, auxiliaries=auxiliaries, other=other
                )
            return kl_divergence

        if self.global_model is None:
            global_variables = None
        else:
            global_variables = self.global_model.get_variables(only_trainable=True)

        kwargs = self.objective.optimizer_arguments(
            policy=self.policy, baseline=self.baseline_policy
        )
        if self.baseline_optimizer is None and self.baseline_objective is not None:
            util.deep_disjoint_update(
                target=kwargs,
                source=self.baseline_objective.optimizer_arguments(policy=self.baseline_policy)
            )

        dependencies = util.flatten(xs=arguments)

        # KL divergence before
        if self.is_summary_logged(
            label=('kl-divergence', 'action-kl-divergences', 'kl-divergences')
        ):
            with tf.control_dependencies(control_inputs=dependencies):
                kldiv_reference = self.policy.kldiv_reference(
                    states=states, internals=internals, auxiliaries=auxiliaries
                )
                dependencies = util.flatten(xs=kldiv_reference)

        # Optimization
        with tf.control_dependencies(control_inputs=dependencies):
            optimized = self.optimizer.minimize(
                variables=variables, arguments=arguments, fn_loss=fn_loss,
                fn_kl_divergence=fn_kl_divergence, global_variables=global_variables, **kwargs
            )

        with tf.control_dependencies(control_inputs=(optimized,)):
            # Loss summaries
            if self.is_summary_logged(label=('loss', 'objective-loss', 'losses')):
                objective_loss = self.objective.loss_per_instance(policy=self.policy, **arguments)
                objective_loss = tf.math.reduce_mean(input_tensor=objective_loss, axis=0)
            if self.is_summary_logged(label=('objective-loss', 'losses')):
                optimized = self.add_summary(
                    label=('objective-loss', 'losses'), name='objective-loss',
                    tensor=objective_loss, pass_tensors=optimized
                )
            if self.is_summary_logged(label=('loss', 'regularization-loss', 'losses')):
                regularization_loss = self.regularize(
                    states=states, internals=internals, auxiliaries=auxiliaries
                )
            if self.is_summary_logged(label=('regularization-loss', 'losses')):
                optimized = self.add_summary(
                    label=('regularization-loss', 'losses'), name='regularization-loss',
                    tensor=regularization_loss, pass_tensors=optimized
                )
            if self.is_summary_logged(label=('loss', 'losses')):
                loss = objective_loss + regularization_loss
            if self.baseline_optimizer is None and self.baseline_objective is not None:
                if self.is_summary_logged(label=('loss', 'baseline-objective-loss', 'losses')):
                    if self.baseline_objective is None:
                        baseline_objective_loss = self.objective.loss_per_instance(
                            policy=self.baseline_policy, **arguments
                        )
                    else:
                        baseline_objective_loss = self.baseline_objective.loss_per_instance(
                            policy=self.baseline_policy, **arguments
                        )
                    baseline_objective_loss = tf.math.reduce_mean(
                        input_tensor=baseline_objective_loss, axis=0
                    )
                if self.is_summary_logged(label=('baseline-objective-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-objective-loss', 'losses'),
                        name='baseline-objective-loss', tensor=baseline_objective_loss,
                        pass_tensors=optimized
                    )
                if self.is_summary_logged(
                    label=('loss', 'baseline-regularization-loss', 'losses')
                ):
                    baseline_regularization_loss = self.baseline_policy.regularize()
                if self.is_summary_logged(label=('baseline-regularization-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-regularization-loss', 'losses'),
                        name='baseline-regularization-loss', tensor=baseline_regularization_loss,
                        pass_tensors=optimized
                    )
                if self.is_summary_logged(label=('loss', 'baseline-loss', 'losses')):
                    baseline_loss = baseline_objective_loss + baseline_regularization_loss
                if self.is_summary_logged(label=('baseline-loss', 'losses')):
                    optimized = self.add_summary(
                        label=('baseline-loss', 'losses'), name='baseline-loss',
                        tensor=baseline_loss, pass_tensors=optimized
                    )
                if self.is_summary_logged(label=('loss', 'losses')):
                    loss += self.baseline_loss_weight * baseline_loss
            if self.is_summary_logged(label=('loss', 'losses')):
                optimized = self.add_summary(
                    label=('loss', 'losses'), name='loss', tensor=loss, pass_tensors=optimized
                )

            # Entropy summaries
            if self.is_summary_logged(label=('entropy', 'action-entropies', 'entropies')):
                entropies = self.policy.entropy(
                    states=states, internals=internals, auxiliaries=auxiliaries,
                    include_per_action=(len(self.actions_spec) > 1)
                )
            if self.is_summary_logged(label=('entropy', 'entropies')):
                if len(self.actions_spec) == 1:
                    optimized = self.add_summary(
                        label=('entropy', 'entropies'), name='entropy', tensor=entropies,
                        pass_tensors=optimized
                    )
                else:
                    optimized = self.add_summary(
                        label=('entropy', 'entropies'), name='entropy', tensor=entropies['*'],
                        pass_tensors=optimized
                    )
            if len(self.actions_spec) > 1 and \
                    self.is_summary_logged(label=('action-entropies', 'entropies')):
                for name in self.actions_spec:
                    optimized = self.add_summary(
                        label=('action-entropies', 'entropies'), name=(name + '-entropy'),
                        tensor=entropies[name], pass_tensors=optimized
                    )

            # KL divergence summaries
            if self.is_summary_logged(
                label=('kl-divergence', 'action-kl-divergences', 'kl-divergences')
            ):
                kl_divergences = self.policy.kl_divergence(
                    states=states, internals=internals, auxiliaries=auxiliaries,
                    other=kldiv_reference, include_per_action=(len(self.actions_spec) > 1)
                )
            if self.is_summary_logged(label=('kl-divergence', 'kl-divergences')):
                if len(self.actions_spec) == 1:
                    optimized = self.add_summary(
                        label=('kl-divergence', 'kl-divergences'), name='kl-divergence',
                        tensor=kl_divergences, pass_tensors=optimized
                    )
                else:
                    optimized = self.add_summary(
                        label=('kl-divergence', 'kl-divergences'), name='kl-divergence',
                        tensor=kl_divergences['*'], pass_tensors=optimized
                    )
            if len(self.actions_spec) > 1 and \
                    self.is_summary_logged(label=('action-kl-divergences', 'kl-divergences')):
                for name in self.actions_spec:
                    optimized = self.add_summary(
                        label=('action-kl-divergences', 'kl-divergences'),
                        name=(name + '-kl-divergence'), tensor=kl_divergences[name],
                        pass_tensors=optimized
                    )

        Module.update_tensor(name='independent', tensor=independent)

        return optimized
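The tf.stop_gradient call in tf_optimize prevents the policy update from backpropagating through the estimated reward. A minimal sketch of its effect on a gradient (not Tensorforce code; TensorFlow 2.x eager, hypothetical values):

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    estimate = 2.0 * x                     # stand-in for a baseline/reward estimate
    loss = tf.stop_gradient(estimate) * x  # gradient flows only through the right-hand x
print(tape.gradient(loss, x).numpy())      # 6.0 (without stop_gradient it would be 12.0)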
    def tf_enqueue(self, **values):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        # Get number of values
        for value in values.values():
            if not isinstance(value, dict):
                break
            elif len(value) > 0:
                value = next(iter(value.values()))
                break
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=value,
                                  out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=value)[0],
                                        dtype=util.tf_dtype(dtype='long'))

        # Check whether instances fit into buffer
        assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

        if self.return_overwritten:
            # Overwritten buffer indices
            with tf.control_dependencies(control_inputs=(assertion, )):
                start = tf.maximum(x=self.buffer_index, y=capacity)
                limit = tf.maximum(x=(self.buffer_index + num_values),
                                   y=capacity)
                num_overwritten = limit - start
                indices = tf.range(start=start, limit=limit)
                indices = tf.math.mod(x=indices, y=capacity)

            # Get overwritten values
            with tf.control_dependencies(control_inputs=(indices, )):
                overwritten_values = OrderedDict()
                for name, buffer in self.buffers.items():
                    if util.is_nested(name=name):
                        overwritten_values[name] = OrderedDict()
                        for inner_name, buffer in buffer.items():
                            overwritten_values[name][inner_name] = tf.gather(
                                params=buffer, indices=indices)
                    else:
                        overwritten_values[name] = tf.gather(params=buffer,
                                                             indices=indices)

        else:
            overwritten_values = (assertion, )

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            if self.return_overwritten:
                any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
                overwritten_values = util.fmap(
                    function=util.identity_operation, xs=overwritten_values)
                return any_overwritten, overwritten_values
            else:
                return util.no_operation()
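tf_enqueue writes a batch of new values into the circular buffer by computing wrap-around indices and scattering into the buffer variables before advancing buffer_index. A self-contained sketch of that write path for a single one-dimensional buffer (hypothetical sizes, not Tensorforce code; TensorFlow 2.x eager):

import tensorflow as tf

capacity = 5
buffer = tf.Variable(tf.zeros(shape=(capacity,)))
buffer_index = tf.Variable(0, dtype=tf.int64)

def enqueue(values):
    num_values = tf.shape(input=values, out_type=tf.int64)[0]
    # Buffer positions to (over)write, wrapping around at capacity
    indices = tf.range(start=buffer_index, limit=(buffer_index + num_values))
    indices = tf.math.mod(x=indices, y=capacity)
    buffer.scatter_nd_update(indices=tf.expand_dims(indices, axis=1), updates=values)
    buffer_index.assign_add(num_values)

enqueue(tf.constant([1.0, 2.0, 3.0]))
enqueue(tf.constant([4.0, 5.0, 6.0]))
print(buffer.numpy())  # [6. 2. 3. 4. 5.] -- the oldest entry has been overwritten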
Example #6
    def create_api_function(self, name, api_function):
        # Call API TensorFlow function
        Module.global_scope = list()
        Module.scope_stack = list()
        Module.while_counter = 0
        Module.cond_counter = 0
        Module.global_tensors = OrderedDict()
        Module.queryable_tensors = OrderedDict()

        if self.device is not None:
            self.device.__enter__()
        scope = tf.name_scope(name=name)
        Module.scope_stack.append(scope)
        scope.__enter__()

        results = api_function()
        self.output_tensors[name[name.index('.') + 1:]] = sorted(
            x.name[len(name) + 1: -9] for x in util.flatten(xs=results)
        )

        # Function-level identity operation for retrieval
        query_tensors = set()
        for scoped_name, tensor in Module.queryable_tensors.items():
            util.identity_operation(x=tensor, operation_name=(scoped_name + '-output'))
            assert scoped_name not in query_tensors
            query_tensors.add(scoped_name)
        self.query_tensors[name[name.index('.') + 1:]] = sorted(query_tensors)

        scope.__exit__(None, None, None)
        Module.scope_stack.pop()
        if self.device is not None:
            self.device.__exit__(None, None, None)

        assert len(Module.global_scope) == 0
        Module.global_scope = None
        assert len(Module.scope_stack) == 0
        Module.scope_stack = None
        Module.while_counter = None
        Module.cond_counter = None
        Module.global_tensors = None
        Module.queryable_tensors = None

        def fn(query=None, **kwargs):
            # Feed_dict dictionary
            feed_dict = dict()
            for key, arg in kwargs.items():
                if arg is None:
                    continue
                elif isinstance(arg, dict):
                    # Support single nesting (for states, internals, actions)
                    for key, arg in arg.items():
                        feed_dict[util.join_scopes(self.name, key) + '-input:0'] = arg
                else:
                    feed_dict[util.join_scopes(self.name, key) + '-input:0'] = arg
            if not all(isinstance(x, str) and x.endswith('-input:0') for x in feed_dict):
                raise TensorforceError.value(
                    name=api_function, argument='inputs', value=list(feed_dict)
                )

            # Fetches value/tuple
            fetches = util.fmap(function=(lambda x: x.name), xs=results)
            if query is not None:
                # If additional tensors are to be fetched
                query = util.fmap(
                    function=(lambda x: util.join_scopes(name, x) + '-output:0'), xs=query
                )
                if util.is_iterable(x=fetches):
                    fetches = tuple(fetches) + (query,)
                else:
                    fetches = (fetches, query)
            if not util.reduce_all(
                predicate=(lambda x: isinstance(x, str) and x.endswith('-output:0')), xs=fetches
            ):
                raise TensorforceError.value(
                    name=api_function, argument='outputs', value=list(fetches)
                )

            # TensorFlow session call
            fetched = self.monitored_session.run(fetches=fetches, feed_dict=feed_dict)

            return fetched

        return fn
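The returned fn builds a feed_dict keyed on '<scope>/<name>-input:0' tensor names and fetches '-output:0' tensors from the monitored session, i.e. everything is addressed by graph tensor name rather than by Python object reference. A minimal sketch of that feed/fetch-by-name convention with a plain TF1-compat session (hypothetical tensor names, not Tensorforce code):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(dtype=tf.float32, shape=(None,), name='x-input')
y = tf.identity(2.0 * x, name='y-output')

with tf.compat.v1.Session() as session:
    # Feed and fetch purely by tensor name, mirroring the '-input:0' / '-output:0' convention
    fetched = session.run(fetches='y-output:0', feed_dict={'x-input:0': [1.0, 2.0, 3.0]})
    print(fetched)  # [2. 4. 6.]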
Example #7
    def tf_enqueue(self,
                   states,
                   internals,
                   auxiliaries,
                   actions,
                   terminal,
                   reward,
                   baseline=None):
        # Constants and parameters
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        horizon = self.horizon.value()
        discount = self.discount.value()

        assertions = list()
        # Check whether horizon at most capacity
        assertions.append(
            tf.debugging.assert_less_equal(
                x=horizon,
                y=capacity,
                message=
                "Estimator capacity has to be at least the same as the estimation horizon."
            ))
        # Check whether at most one terminal
        assertions.append(
            tf.debugging.assert_less_equal(
                x=tf.math.count_nonzero(input=terminal,
                                        dtype=util.tf_dtype(dtype='long')),
                y=one,
                message="Timesteps contain more than one terminal."))
        # Check whether, if any, last value is terminal
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.reduce_any(
                    input_tensor=tf.math.greater(x=terminal, y=zero)),
                y=tf.math.greater(x=terminal[-1], y=zero),
                message="Terminal is not the last timestep."))

        # Get number of overwritten values
        with tf.control_dependencies(control_inputs=assertions):
            if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                num_values = tf.shape(input=terminal,
                                      out_type=util.tf_dtype(dtype='long'))[0]
            else:
                num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                            dtype=util.tf_dtype(dtype='long'))
            overwritten_start = tf.maximum(x=self.buffer_index, y=capacity)
            overwritten_limit = tf.maximum(x=(self.buffer_index + num_values),
                                           y=capacity)
            num_overwritten = overwritten_limit - overwritten_start

        def update_overwritten_rewards():
            # Get relevant buffer rewards
            buffer_limit = self.buffer_index + tf.minimum(
                x=(num_overwritten + horizon), y=capacity)
            buffer_indices = tf.range(start=self.buffer_index,
                                      limit=buffer_limit)
            buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
            rewards = tf.gather(params=self.buffers['reward'],
                                indices=buffer_indices)

            # Get relevant values rewards
            values_limit = tf.maximum(x=(num_overwritten + horizon - capacity),
                                      y=zero)
            rewards = tf.concat(values=(rewards, reward[:values_limit]),
                                axis=0)

            # Horizon baseline value
            if self.estimate_horizon == 'early':
                assert baseline is not None
                # Baseline estimate
                buffer_indices = buffer_indices[horizon + one:]
                _states = OrderedDict()
                for name, buffer in self.buffers['states'].items():
                    state = tf.gather(params=buffer, indices=buffer_indices)
                    _states[name] = tf.concat(
                        values=(state, states[name][:values_limit + one]),
                        axis=0)
                _internals = OrderedDict()
                for name, buffer in self.buffers['internals'].items():
                    internal = tf.gather(params=buffer, indices=buffer_indices)
                    _internals[name] = tf.concat(
                        values=(internal,
                                internals[name][:values_limit + one]),
                        axis=0)
                _auxiliaries = OrderedDict()
                for name, buffer in self.buffers['auxiliaries'].items():
                    auxiliary = tf.gather(params=buffer,
                                          indices=buffer_indices)
                    _auxiliaries[name] = tf.concat(
                        values=(auxiliary,
                                auxiliaries[name][:values_limit + one]),
                        axis=0)

                # Dependency horizon
                # TODO: handle arbitrary non-optimization horizons!
                past_horizon = baseline.past_horizon(is_optimization=False)
                assertion = tf.debugging.assert_equal(
                    x=past_horizon,
                    y=zero,
                    message=
                    "Temporary: baseline cannot depend on previous states.")
                with tf.control_dependencies(control_inputs=(assertion, )):
                    some_state = next(iter(_states.values()))
                    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                        batch_size = tf.shape(
                            input=some_state,
                            out_type=util.tf_dtype(dtype='long'))[0]
                    else:
                        batch_size = tf.dtypes.cast(
                            x=tf.shape(input=some_state)[0],
                            dtype=util.tf_dtype(dtype='long'))
                    starts = tf.range(start=batch_size,
                                      dtype=util.tf_dtype(dtype='long'))
                    lengths = tf.ones(shape=(batch_size, ),
                                      dtype=util.tf_dtype(dtype='long'))
                    Module.update_tensors(dependency_starts=starts,
                                          dependency_lengths=lengths)

                if self.estimate_actions:
                    _actions = OrderedDict()
                    for name, buffer in self.buffers['actions'].items():
                        action = tf.gather(params=buffer,
                                           indices=buffer_indices)
                        _actions[name] = tf.concat(
                            values=(action, actions[name][:values_limit]),
                            axis=0)
                    horizon_estimate = baseline.actions_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries,
                        actions=_actions)
                else:
                    horizon_estimate = baseline.states_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries)

            else:
                # Zero estimate
                horizon_estimate = tf.zeros(shape=(num_overwritten, ),
                                            dtype=util.tf_dtype(dtype='float'))

            # Calculate discounted sum
            def cond(discounted_sum, horizon):
                return tf.math.greater_equal(x=horizon, y=zero)

            def body(discounted_sum, horizon):
                # discounted_sum = tf.compat.v1.Print(
                #     discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10
                # )
                discounted_sum = discount * discounted_sum
                discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                          num_overwritten]
                return discounted_sum, horizon - one

            discounted_sum, _ = self.while_loop(cond=cond,
                                                body=body,
                                                loop_vars=(horizon_estimate,
                                                           horizon),
                                                back_prop=False)

            assertions = [
                tf.debugging.assert_equal(x=tf.shape(input=horizon_estimate),
                                          y=tf.shape(input=discounted_sum),
                                          message="Estimation check."),
                tf.debugging.assert_equal(x=tf.shape(
                    input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                                          y=(horizon + num_overwritten),
                                          message="Estimation check.")
            ]

            # Overwrite buffer rewards
            with tf.control_dependencies(control_inputs=assertions):
                indices = tf.range(start=self.buffer_index,
                                   limit=(self.buffer_index + num_overwritten))
                indices = tf.math.mod(x=indices, y=capacity)
                indices = tf.expand_dims(input=indices, axis=1)

            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum)

            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
        updated_rewards = self.cond(pred=any_overwritten,
                                    true_fn=update_overwritten_rewards,
                                    false_fn=util.no_operation)

        # Overwritten buffer indices
        with tf.control_dependencies(control_inputs=(updated_rewards, )):
            indices = tf.range(start=overwritten_start,
                               limit=overwritten_limit)
            indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        with tf.control_dependencies(control_inputs=(indices, )):
            overwritten_values = OrderedDict()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    overwritten_values[name] = OrderedDict()
                    for inner_name, buffer in buffer.items():
                        overwritten_values[name][inner_name] = tf.gather(
                            params=buffer, indices=indices)
                else:
                    overwritten_values[name] = tf.gather(params=buffer,
                                                         indices=indices)

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            values = dict(states=states,
                          internals=internals,
                          auxiliaries=auxiliaries,
                          actions=actions,
                          terminal=terminal,
                          reward=reward)
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
            overwritten_values = util.fmap(function=util.identity_operation,
                                           xs=overwritten_values)
            return any_overwritten, overwritten_values
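The cond()/body() while-loop above turns per-step rewards plus a horizon baseline estimate into n-step discounted returns by sweeping the horizon backwards. The same recurrence in plain NumPy, with hypothetical numbers (not Tensorforce code):

import numpy as np

def discounted_returns(rewards, horizon_estimate, horizon, discount):
    # rewards has length num_values + horizon; horizon_estimate has length num_values
    num_values = len(horizon_estimate)
    discounted_sum = np.asarray(horizon_estimate, dtype=np.float64)
    for h in range(horizon, -1, -1):  # horizon down to 0, as in cond()/body() above
        discounted_sum = discount * discounted_sum + np.asarray(rewards[h:h + num_values])
    return discounted_sum

rewards = [1.0, 1.0, 1.0, 0.0, 0.0]  # three real steps padded with `horizon` zeros
print(discounted_returns(rewards, horizon_estimate=[0.0, 0.0, 0.0], horizon=2, discount=0.9))
# -> [2.71, 1.9, 1.0], the discounted 2-step-lookahead returns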
Example #8
    def tf_reset(self, baseline=None):
        # Constants and parameters
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        horizon = self.horizon.value()
        discount = self.discount.value()

        # Overwritten buffer indices
        num_overwritten = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_overwritten),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        states = values['states']
        internals = values['internals']
        auxiliaries = values['auxiliaries']
        actions = values['actions']
        terminal = values['terminal']
        reward = values['reward']

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        with tf.control_dependencies(control_inputs=(assignment, )):
            assertions = list()
            # Check whether exactly one terminal (unless empty?)
            assertions.append(
                tf.debugging.assert_equal(
                    x=tf.math.count_nonzero(input=terminal,
                                            dtype=util.tf_dtype(dtype='long')),
                    y=one,
                    message="Timesteps do not contain exactly one terminal."))
            # Check whether last value is terminal
            assertions.append(
                tf.debugging.assert_equal(
                    x=tf.math.greater(x=terminal[-1], y=zero),
                    y=tf.constant(value=True,
                                  dtype=util.tf_dtype(dtype='bool')),
                    message="Terminal is not the last timestep."))

        # Get number of values
        with tf.control_dependencies(control_inputs=assertions):
            if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                num_values = tf.shape(input=terminal,
                                      out_type=util.tf_dtype(dtype='long'))[0]
            else:
                num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                            dtype=util.tf_dtype(dtype='long'))

        # Horizon baseline value
        if self.estimate_horizon == 'early' and baseline is not None:
            # Dependency horizon
            # TODO: handle arbitrary non-optimization horizons!
            past_horizon = baseline.past_horizon(is_optimization=False)
            assertion = tf.debugging.assert_equal(
                x=past_horizon,
                y=zero,
                message="Temporary: baseline cannot depend on previous states."
            )

            # Baseline estimate
            horizon_start = num_values - tf.maximum(x=(num_values - horizon),
                                                    y=one)
            _states = OrderedDict()
            for name, state in states.items():
                _states[name] = state[horizon_start:]
            _internals = OrderedDict()
            for name, internal in internals.items():
                _internals[name] = internal[horizon_start:]
            _auxiliaries = OrderedDict()
            for name, auxiliary in auxiliaries.items():
                _auxiliaries[name] = auxiliary[horizon_start:]

            with tf.control_dependencies(control_inputs=(assertion, )):
                # some_state = next(iter(states.values()))
                # if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                #     batch_size = tf.shape(input=some_state, out_type=util.tf_dtype(dtype='long'))[0]
                # else:
                #     batch_size = tf.dtypes.cast(
                #         x=tf.shape(input=some_state)[0], dtype=util.tf_dtype(dtype='long')
                #     )
                batch_size = num_values - horizon_start
                starts = tf.range(start=batch_size,
                                  dtype=util.tf_dtype(dtype='long'))
                lengths = tf.ones(shape=(batch_size, ),
                                  dtype=util.tf_dtype(dtype='long'))
                Module.update_tensors(dependency_starts=starts,
                                      dependency_lengths=lengths)

            if self.estimate_actions:
                _actions = OrderedDict()
                for name, action in actions.items():
                    _actions[name] = action[horizon_start:]
                horizon_estimate = baseline.actions_value(
                    states=_states,
                    internals=_internals,
                    auxiliaries=_auxiliaries,
                    actions=_actions)
            else:
                horizon_estimate = baseline.states_value(
                    states=_states,
                    internals=_internals,
                    auxiliaries=_auxiliaries)

            # Expand rewards beyond terminal
            terminal_zeros = tf.zeros(shape=(horizon, ),
                                      dtype=util.tf_dtype(dtype='float'))
            if self.estimate_terminal:
                rewards = tf.concat(values=(reward[:-1], horizon_estimate[-1:],
                                            terminal_zeros),
                                    axis=0)

            else:
                with tf.control_dependencies(control_inputs=(assertion, )):
                    last_reward = tf.where(condition=tf.math.greater(
                        x=terminal[-1], y=one),
                                           x=horizon_estimate[-1],
                                           y=reward[-1])
                    rewards = tf.concat(values=(reward[:-1], (last_reward, ),
                                                terminal_zeros),
                                        axis=0)

            # Remove last if necessary
            horizon_end = tf.where(condition=tf.math.less_equal(x=num_values,
                                                                y=horizon),
                                   x=zero,
                                   y=(num_values - horizon))
            horizon_estimate = horizon_estimate[:horizon_end]

            # Expand missing estimates with zeros
            terminal_size = tf.minimum(x=horizon, y=num_values)
            terminal_estimate = tf.zeros(shape=(terminal_size, ),
                                         dtype=util.tf_dtype(dtype='float'))
            horizon_estimate = tf.concat(values=(horizon_estimate,
                                                 terminal_estimate),
                                         axis=0)

        else:
            # Expand rewards beyond terminal
            terminal_zeros = tf.zeros(shape=(horizon, ),
                                      dtype=util.tf_dtype(dtype='float'))
            rewards = tf.concat(values=(reward, terminal_zeros), axis=0)

            # Zero estimate
            horizon_estimate = tf.zeros(shape=(num_values, ),
                                        dtype=util.tf_dtype(dtype='float'))

        # Calculate discounted sum
        def cond(discounted_sum, horizon):
            return tf.math.greater_equal(x=horizon, y=zero)

        def body(discounted_sum, horizon):
            # discounted_sum = tf.compat.v1.Print(
            #     discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10
            # )
            discounted_sum = discount * discounted_sum
            discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                      num_values]
            return discounted_sum, horizon - one

        values['reward'], _ = self.while_loop(cond=cond,
                                              body=body,
                                              loop_vars=(horizon_estimate,
                                                         horizon),
                                              back_prop=False)

        return values