Example #1
    def tf_iterative_step(self, x, previous):
        state = previous['state']

        if self.cell_type == 'gru':
            state = (state, )
        elif self.cell_type == 'lstm':
            state = (state[:, 0, :], state[:, 1, :])

        if util.tf_dtype(dtype='float') not in (tf.float32, tf.float64):
            x = tf.dtypes.cast(x=x, dtype=tf.float32)
            state = util.fmap(
                function=(lambda x: tf.dtypes.cast(x=x, dtype=tf.float32)),
                xs=state)

        x, state = self.cell(inputs=x, states=state)

        if util.tf_dtype(dtype='float') not in (tf.float32, tf.float64):
            x = tf.dtypes.cast(x=x, dtype=util.tf_dtype(dtype='float'))
            state = util.fmap(function=(lambda x: tf.dtypes.cast(
                x=x, dtype=util.tf_dtype(dtype='float'))),
                              xs=state)

        if self.cell_type == 'gru':
            state = state[0]
        elif self.cell_type == 'lstm':
            state = tf.stack(values=state, axis=1)

        return x, OrderedDict(state=state)
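For reference, a minimal standalone sketch (using plain tf.keras cells rather than Tensorforce's own wrappers, which is an assumption for illustration) of why the LSTM branch packs two state tensors along axis 1 while the GRU branch only wraps a single one:

import tensorflow as tf

# An LSTM cell carries two state tensors (h, c); a GRU cell carries one.
lstm = tf.keras.layers.LSTMCell(units=8)
x = tf.zeros(shape=(4, 3))
state = [tf.zeros(shape=(4, 8)), tf.zeros(shape=(4, 8))]
_, state = lstm(inputs=x, states=state)

# Stacking along axis 1 packs (h, c) into a single tensor of shape (batch, 2, units),
# which is the packed representation the 'lstm' branches above convert to and from.
packed = tf.stack(values=state, axis=1)
print(packed.shape)  # (4, 2, 8)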
Example #2
        def fn(query=None, **kwargs):
            # Feed_dict dictionary
            feed_dict = dict()
            for key, arg in kwargs.items():
                if arg is None:
                    continue
                elif isinstance(arg, dict):
                    # Support single nesting (for states, internals, actions)
                    for inner_key, inner_arg in arg.items():
                        feed_dict[util.join_scopes(self.name, inner_key) + '-input:0'] = inner_arg
                else:
                    feed_dict[util.join_scopes(self.name, key) + '-input:0'] = arg
            if not all(isinstance(x, str) and x.endswith('-input:0') for x in feed_dict):
                raise TensorforceError.unexpected()

            # Fetches value/tuple
            fetches = util.fmap(function=(lambda x: x.name), xs=results)
            if query is not None:
                # If additional tensors are to be fetched
                query = util.fmap(
                    function=(lambda x: util.join_scopes(name, x) + '-output:0'), xs=query
                )
                if util.is_iterable(x=fetches):
                    fetches = tuple(fetches) + (query,)
                else:
                    fetches = (fetches, query)
            if not util.reduce_all(
                predicate=(lambda x: isinstance(x, str) and x.endswith('-output:0')), xs=fetches
            ):
                raise TensorforceError.unexpected()

            # TensorFlow session call
            fetched = self.monitored_session.run(fetches=fetches, feed_dict=feed_dict)

            return fetched
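A small self-contained sketch of the feed_dict naming scheme above; join_scopes here is a stand-in that simply joins scope names with '/', which is an assumption for illustration only:

# Hypothetical stand-in for util.join_scopes, assumed to join scopes with '/'.
def join_scopes(*scopes):
    return '/'.join(scopes)

def build_feed_dict(name, **kwargs):
    feed_dict = dict()
    for key, arg in kwargs.items():
        if arg is None:
            continue
        elif isinstance(arg, dict):
            # One level of nesting, e.g. states={'state': ...}
            for inner_key, inner_arg in arg.items():
                feed_dict[join_scopes(name, inner_key) + '-input:0'] = inner_arg
        else:
            feed_dict[join_scopes(name, key) + '-input:0'] = arg
    return feed_dict

print(build_feed_dict('agent', states={'state': [0.0, 1.0]}, parallel=[0]))
# {'agent/state-input:0': [0.0, 1.0], 'agent/parallel-input:0': [0]}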
Example #3
    def pretrain(self, directory, num_iterations, num_traces=1, num_updates=1):
        """
        Pretrain from experience traces.

        Args:
            directory (path): Directory with experience traces, e.g. obtained via recorder; episode
                length has to be consistent with agent configuration
                (<span style="color:#C00000"><b>required</b></span>).
            num_iterations (int > 0): Number of iterations consisting of loading new traces and
                performing multiple updates
                (<span style="color:#C00000"><b>required</b></span>).
            num_traces (int > 0): Number of traces to load per iteration; has to at least satisfy
                the update batch size
                (<span style="color:#00C000"><b>default</b></span>: 1).
            num_updates (int > 0): Number of updates per iteration
                (<span style="color:#00C000"><b>default</b></span>: 1).
        """
        if not os.path.isdir(directory):
            raise TensorforceError.unexpected()
        files = sorted(
            os.path.join(directory, f) for f in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, f))
            and f.startswith('trace-'))
        indices = list(range(len(files)))

        for _ in range(num_iterations):
            shuffle(indices)
            if num_traces is None:
                selection = indices
            else:
                selection = indices[:num_traces]

            states = OrderedDict(((name, list()) for name in self.states_spec))
            for name, spec in self.actions_spec.items():
                if spec['type'] == 'int':
                    states[name + '_mask'] = list()
            actions = OrderedDict(
                ((name, list()) for name in self.actions_spec))
            terminal = list()
            reward = list()
            for index in selection:
                trace = np.load(files[index])
                for name in states:
                    states[name].append(trace[name])
                for name in actions:
                    actions[name].append(trace[name])
                terminal.append(trace['terminal'])
                reward.append(trace['reward'])

            states = util.fmap(function=np.concatenate, xs=states, depth=1)
            actions = util.fmap(function=np.concatenate, xs=actions, depth=1)
            terminal = np.concatenate(terminal)
            reward = np.concatenate(reward)

            self.experience(states=states,
                            actions=actions,
                            terminal=terminal,
                            reward=reward)
            for _ in range(num_updates):
                self.update()
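A sketch of the trace layout this loop assumes: one archive per episode whose arrays get concatenated along the time axis (file names and array keys below are illustrative assumptions):

import numpy as np

# Write two tiny synthetic traces in the layout pretrain() iterates over.
for i in range(2):
    np.savez(
        'trace-%03d.npz' % i,
        state=np.random.rand(5, 4).astype(np.float32),
        action=np.random.randint(0, 3, size=5),
        terminal=np.array([0, 0, 0, 0, 1]),
        reward=np.random.rand(5).astype(np.float32)
    )

# Load the selected traces and concatenate per-name arrays, as done above per iteration.
traces = [np.load('trace-%03d.npz' % i) for i in range(2)]
states = np.concatenate([t['state'] for t in traces])
reward = np.concatenate([t['reward'] for t in traces])
print(states.shape, reward.shape)  # (10, 4) (10,)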
Example #4
    def tf_step(self,
                variables,
                arguments,
                fn_loss,
                fn_initial_gradients=None,
                **kwargs):
        arguments = util.fmap(function=tf.stop_gradient, xs=arguments)
        loss = fn_loss(**arguments)

        # Force loss value and attached control flow to be computed.
        with tf.control_dependencies(control_inputs=(loss, )):
            # Trivial operation to enforce control dependency
            previous_variables = util.fmap(function=util.identity_operation,
                                           xs=variables)

        # Compute gradients only after the current variable values have been captured.
        with tf.control_dependencies(control_inputs=previous_variables):
            # applied = self.optimizer.minimize(loss=loss, var_list=variables)
            # grads_and_vars = self.optimizer.compute_gradients(loss=loss, var_list=variables)
            # gradients, variables = zip(*grads_and_vars)
            if fn_initial_gradients is None:
                initial_gradients = None
            else:
                initial_gradients = fn_initial_gradients(**arguments)
                initial_gradients = tf.stop_gradient(input=initial_gradients)

            gradients = tf.gradients(ys=loss,
                                     xs=variables,
                                     grad_ys=initial_gradients)
            assertions = [
                tf.debugging.assert_all_finite(
                    x=gradient, message="Finite gradients check.")
                for gradient in gradients
            ]

        with tf.control_dependencies(control_inputs=assertions):
            gradient_norm_clipping = self.gradient_norm_clipping.value()
            gradients, gradient_norm = tf.clip_by_global_norm(
                t_list=gradients, clip_norm=gradient_norm_clipping)
            gradients = self.add_summary(label='update-norm',
                                         name='gradient-norm-unclipped',
                                         tensor=gradient_norm,
                                         pass_tensors=gradients)

            applied = self.optimizer.apply_gradients(
                grads_and_vars=zip(gradients, variables))

        # Return deltas after the variables have actually been changed.
        with tf.control_dependencies(control_inputs=(applied, )):
            return [
                variable - previous_variable for variable, previous_variable in
                zip(variables, previous_variables)
            ]
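A brief sketch of what the optional fn_initial_gradients hook feeds into: the grad_ys argument of tf.gradients is the upstream gradient that gets back-propagated (wrapped in tf.function since tf.gradients only works in graph mode):

import tensorflow as tf

@tf.function
def weighted_gradient(x, upstream):
    y = x * x
    # grad_ys weights the gradient that is back-propagated through y.
    return tf.gradients(ys=y, xs=[x], grad_ys=upstream)[0]

print(weighted_gradient(tf.constant(3.0), tf.constant(2.0)))  # dy/dx * upstream = 6 * 2 = 12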
Example #5
        def apply_step():
            # lambda = sqrt(c' / c)
            lagrange_multiplier = tf.sqrt(x=(constant / learning_rate))

            # delta = delta' / lambda
            estimated_deltas = [
                delta / lagrange_multiplier for delta in deltas
            ]

            # improvement = grad(loss) * delta  (= loss_new - loss_old)
            estimated_improvement = tf.add_n(inputs=[
                tf.reduce_sum(input_tensor=(grad * delta))
                for grad, delta in zip(loss_gradients, estimated_deltas)
            ])

            # Apply natural gradient improvement.
            applied = self.apply_step(variables=variables,
                                      deltas=estimated_deltas)

            with tf.control_dependencies(control_inputs=(applied, )):
                # Trivial operation to enforce control dependency
                estimated_delta = util.fmap(function=util.identity_operation,
                                            xs=estimated_deltas)
                if return_estimated_improvement:
                    return estimated_delta, estimated_improvement
                else:
                    return estimated_delta
Example #6
    def remote_receive(cls, connection):
        str_num_bytes = connection.recv(8)
        if len(str_num_bytes) != 8:
            raise TensorforceError.unexpected()
        num_bytes = int(str_num_bytes.decode())
        str_function = b''
        for n in range(num_bytes // cls.MAX_BYTES):
            str_function += connection.recv(cls.MAX_BYTES)
            if len(str_function) != (n + 1) * cls.MAX_BYTES:
                raise TensorforceError.unexpected()
        str_function += connection.recv(num_bytes % cls.MAX_BYTES)
        if len(str_function) != num_bytes:
            raise TensorforceError.unexpected()
        function = str_function.decode()

        str_num_bytes = connection.recv(8)
        if len(str_num_bytes) != 8:
            raise TensorforceError.unexpected()
        num_bytes = int(str_num_bytes.decode())
        str_kwargs = b''
        for n in range(num_bytes // cls.MAX_BYTES):
            str_kwargs += connection.recv(cls.MAX_BYTES)
            if len(str_kwargs) != (n + 1) * cls.MAX_BYTES:
                raise TensorforceError.unexpected()
        str_kwargs += connection.recv(num_bytes % cls.MAX_BYTES)
        if len(str_kwargs) != num_bytes:
            raise TensorforceError.unexpected()
        kwargs = msgpack.unpackb(packed=str_kwargs)
        decode = (lambda x: x.decode() if isinstance(x, bytes) else x)
        kwargs = util.fmap(function=decode, xs=kwargs, map_keys=True)

        return function, kwargs
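A self-contained sketch of the framing used above: an 8-byte zero-padded ASCII length header followed by the payload, received in chunks of at most MAX_BYTES. The socketpair and plain-text payload are simplifications for illustration; the real implementation additionally msgpack-encodes the kwargs:

import socket

MAX_BYTES = 4096

sender, receiver = socket.socketpair()
payload = b'{"function": "execute"}'
sender.sendall(b'%08d' % len(payload))
sender.sendall(payload)

# Receive the length header, then the payload in chunks of at most MAX_BYTES.
num_bytes = int(receiver.recv(8).decode())
data = b''
while len(data) < num_bytes:
    data += receiver.recv(min(MAX_BYTES, num_bytes - len(data)))
print(data.decode())

sender.close()
receiver.close()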
Example #7
    def tf_step(self, variables, arguments, **kwargs):
        """
        Creates the TensorFlow operations for performing an optimization step.

        Args:
            variables: List of variables to optimize.
            arguments: Dict of arguments for callables, like fn_loss.
            **kwargs: Additional arguments passed on to the internal optimizer.

        Returns:
            List of delta tensors corresponding to the updates for each optimized variable.
        """
        # Get some (batched) argument to determine batch size.
        arguments_iter = iter(arguments.values())
        some_argument = next(arguments_iter)

        try:
            while not isinstance(some_argument, tf.Tensor) or util.rank(
                    x=some_argument) == 0:
                if isinstance(some_argument, dict):
                    if some_argument:
                        arguments_iter = iter(some_argument.values())
                    some_argument = next(arguments_iter)
                elif isinstance(some_argument, list):
                    if some_argument:
                        arguments_iter = iter(some_argument)
                    some_argument = next(arguments_iter)
                elif some_argument is None or util.rank(x=some_argument) == 0:
                    # Non-batched argument
                    some_argument = next(arguments_iter)
                else:
                    raise TensorforceError("Invalid argument type.")
        except StopIteration:
            raise TensorforceError("Invalid argument type.")

        batch_size = tf.shape(input=some_argument)[0]
        fraction = self.fraction.value()
        num_samples = fraction * tf.cast(x=batch_size,
                                         dtype=util.tf_dtype('float'))
        one = tf.constant(value=1, dtype=util.tf_dtype('int'))
        num_samples = tf.maximum(x=tf.cast(x=num_samples,
                                           dtype=util.tf_dtype('int')),
                                 y=one)
        indices = tf.random.uniform(shape=(num_samples, ),
                                    maxval=batch_size,
                                    dtype=tf.int32)

        function = (lambda x: tf.gather(params=x, indices=indices))
        subsampled_arguments = util.fmap(function=function, xs=arguments)

        return self.optimizer.step(variables=variables,
                                   arguments=subsampled_arguments,
                                   **kwargs)
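The core subsampling trick above in isolation: draw random row indices for a fraction of the batch and gather the same rows from every batched argument (a minimal sketch with a plain tensor):

import tensorflow as tf

batch = tf.reshape(tf.range(start=0, limit=12, dtype=tf.float32), shape=(6, 2))
fraction = 0.5
num_samples = tf.maximum(tf.cast(fraction * 6, dtype=tf.int32), 1)
indices = tf.random.uniform(shape=(num_samples,), maxval=6, dtype=tf.int32)

# Every batched argument would be gathered with the same indices.
subsampled = tf.gather(params=batch, indices=indices)
print(subsampled.shape)  # (3, 2)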
Example #8
    def tf_kl_divergences(self, states, internals, auxiliaries, other=None):
        parameters = self.kldiv_reference(states=states,
                                          internals=internals,
                                          auxiliaries=auxiliaries)

        if other is None:
            other = util.fmap(function=tf.stop_gradient, xs=parameters)
        elif isinstance(other, ParametrizedDistributions):
            other = other.kldiv_reference(states=states,
                                          internals=internals,
                                          auxiliaries=auxiliaries)
            other = util.fmap(function=tf.stop_gradient, xs=other)
        elif isinstance(other, dict):
            if any(name not in other for name in self.actions_spec):
                raise TensorforceError.unexpected()
        else:
            raise TensorforceError.unexpected()

        kl_divergences = OrderedDict()
        for name, distribution in self.distributions.items():
            kl_divergences[name] = distribution.kl_divergence(
                parameters1=parameters[name], parameters2=other[name])

        return kl_divergences
Example #9
        def apply_sync():
            update_weight = self.update_weight.value()
            deltas = list()
            for source_variable, target_variable in zip(
                    source_variables, variables):
                delta = update_weight * (source_variable - target_variable)
                deltas.append(delta)

            applied = self.apply_step(variables=variables, deltas=deltas)
            last_sync_updated = self.last_sync.assign(value=timestep)

            with tf.control_dependencies(control_inputs=(applied,
                                                         last_sync_updated)):
                # Trivial operation to enforce control dependency
                return util.fmap(function=util.identity_operation, xs=deltas)
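What apply_sync computes, reduced to a standalone sketch: a partial (Polyak-style) move of the target variables towards the source variables by update_weight:

import tensorflow as tf

update_weight = 0.1
source = tf.Variable([1.0, 2.0, 3.0])
target = tf.Variable([0.0, 0.0, 0.0])

# delta = w * (source - target), then applied to the target variable.
delta = update_weight * (source - target)
target.assign_add(delta)
print(target.numpy())  # [0.1 0.2 0.3]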
Example #10
    def tf_step(self, variables, arguments, **kwargs):
        # Get some (batched) argument to determine batch size.
        arguments_iter = iter(arguments.values())
        some_argument = next(arguments_iter)

        try:
            while not isinstance(some_argument, tf.Tensor) or util.rank(x=some_argument) == 0:
                if isinstance(some_argument, dict):
                    if some_argument:
                        arguments_iter = iter(some_argument.values())
                    some_argument = next(arguments_iter)
                elif isinstance(some_argument, list):
                    if some_argument:
                        arguments_iter = iter(some_argument)
                    some_argument = next(arguments_iter)
                elif some_argument is None or util.rank(x=some_argument) == 0:
                    # Non-batched argument
                    some_argument = next(arguments_iter)
                else:
                    raise TensorforceError("Invalid argument type.")
        except StopIteration:
            raise TensorforceError("Invalid argument type.")

        if util.tf_dtype(dtype='int') in (tf.int32, tf.int64):
            batch_size = tf.shape(input=some_argument, out_type=util.tf_dtype(dtype='int'))[0]
        else:
            batch_size = tf.dtypes.cast(
                x=tf.shape(input=some_argument)[0], dtype=util.tf_dtype(dtype='int')
            )
        fraction = self.fraction.value()
        num_samples = fraction * tf.dtypes.cast(x=batch_size, dtype=util.tf_dtype('float'))
        num_samples = tf.dtypes.cast(x=num_samples, dtype=util.tf_dtype('int'))
        one = tf.constant(value=1, dtype=util.tf_dtype('int'))
        num_samples = tf.maximum(x=num_samples, y=one)
        indices = tf.random.uniform(
            shape=(num_samples,), maxval=batch_size, dtype=util.tf_dtype(dtype='int')
        )

        function = (lambda x: tf.gather(params=x, indices=indices))
        subsampled_arguments = util.fmap(function=function, xs=arguments)

        return self.optimizer.step(variables=variables, arguments=subsampled_arguments, **kwargs)
Example #11
            def optimize():
                if self.update_unit == 'timesteps':
                    # Timestep-based batch
                    batch = self.memory.retrieve_timesteps(n=batch_size)
                elif self.update_unit == 'episodes':
                    # Episode-based batch
                    batch = self.memory.retrieve_episodes(n=batch_size)
                elif self.update_unit == 'sequences':
                    # Timestep-sequence-based batch
                    batch = self.memory.retrieve_sequences(
                        n=batch_size, sequence_length=sequence_length
                    )

                # Do not calculate gradients for memory-internal operations.
                batch = util.fmap(function=tf.stop_gradient, xs=batch)
                Module.update_tensors(
                    update=tf.constant(value=True, dtype=util.tf_dtype(dtype='bool'))
                )
                optimized = self.optimization(**batch)

                return optimized
Example #12
    def tf_reset(self):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        if not self.return_overwritten:
            # Reset buffer index
            assignment = self.buffer_index.assign(value=zero, read_value=False)

            # Return no-op
            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        # Overwritten buffer indices
        num_values = tf.minimum(x=self.buffer_index, y=capacity)
        indices = tf.range(start=(self.buffer_index - num_values),
                           limit=self.buffer_index)
        indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        values = OrderedDict()
        for name, buffer in self.buffers.items():
            if util.is_nested(name=name):
                values[name] = OrderedDict()
                for inner_name, buffer in buffer.items():
                    values[name][inner_name] = tf.gather(params=buffer,
                                                         indices=indices)
            else:
                values[name] = tf.gather(params=buffer, indices=indices)

        # Reset buffer index
        with tf.control_dependencies(control_inputs=util.flatten(xs=values)):
            assignment = self.buffer_index.assign(value=zero, read_value=False)

        # Return overwritten values
        with tf.control_dependencies(control_inputs=(assignment, )):
            return util.fmap(function=util.identity_operation, xs=values)
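The index arithmetic above in isolation: the buffer is circular, so the positions of the last num_values entries written before buffer_index are recovered with a range followed by a modulo (a sketch with a capacity of 5):

import tensorflow as tf

capacity = tf.constant(5, dtype=tf.int64)
buffer_index = tf.constant(7, dtype=tf.int64)    # 7 values written so far
num_values = tf.minimum(buffer_index, capacity)  # only the last 5 are still stored

indices = tf.range(start=buffer_index - num_values, limit=buffer_index)
indices = tf.math.mod(indices, capacity)
print(indices.numpy())  # [2 3 4 0 1]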
Example #13
    def proxy_receive(cls, connection):
        str_success = connection.recv(1)
        if len(str_success) != 1:
            raise TensorforceError.unexpected()
        success = bool(str_success)

        str_num_bytes = connection.recv(8)
        if len(str_num_bytes) != 8:
            raise TensorforceError.unexpected()
        num_bytes = int(str_num_bytes.decode())
        str_result = b''
        for n in range(num_bytes // cls.MAX_BYTES):
            str_result += connection.recv(cls.MAX_BYTES)
            if len(str_result) != (n + 1) * cls.MAX_BYTES:
                raise TensorforceError.unexpected()
        str_result += connection.recv(num_bytes % cls.MAX_BYTES)
        if len(str_result) != num_bytes:
            raise TensorforceError.unexpected()
        result = msgpack.unpackb(packed=str_result)
        decode = (lambda x: x.decode() if isinstance(x, bytes) else x)
        result = util.fmap(function=decode, xs=result, map_keys=True)

        return success, result
Example #14
    def tf_step(self, variables, **kwargs):
        deltas = self.optimizer.step(variables=variables, **kwargs)

        with tf.control_dependencies(control_inputs=deltas):
            threshold = self.threshold.value()
            if self.mode == 'global_norm':
                clipped_deltas, update_norm = tf.clip_by_global_norm(
                    t_list=deltas, clip_norm=threshold)
            else:
                update_norm = tf.linalg.global_norm(t_list=deltas)
                clipped_deltas = list()
                for delta in deltas:
                    if self.mode == 'norm':
                        clipped_delta = tf.clip_by_norm(t=delta,
                                                        clip_norm=threshold)
                    elif self.mode == 'value':
                        clipped_delta = tf.clip_by_value(
                            t=delta,
                            clip_value_min=-threshold,
                            clip_value_max=threshold)
                    clipped_deltas.append(clipped_delta)

            clipped_deltas = self.add_summary(label='update-norm',
                                              name='update-norm-unclipped',
                                              tensor=update_norm,
                                              pass_tensors=clipped_deltas)

            exceeding_deltas = list()
            for delta, clipped_delta in zip(deltas, clipped_deltas):
                exceeding_deltas.append(clipped_delta - delta)

        applied = self.apply_step(variables=variables, deltas=exceeding_deltas)

        with tf.control_dependencies(control_inputs=(applied, )):
            return util.fmap(function=util.identity_operation,
                             xs=clipped_deltas)
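The three clipping modes above, shown on a single example tensor (a minimal sketch; the actual step applies them to the deltas produced by the inner optimizer):

import tensorflow as tf

delta = tf.constant([3.0, -4.0])   # L2 norm 5.0
threshold = 1.0

by_value = tf.clip_by_value(t=delta, clip_value_min=-threshold, clip_value_max=threshold)
by_norm = tf.clip_by_norm(t=delta, clip_norm=threshold)
by_global, global_norm = tf.clip_by_global_norm(t_list=[delta], clip_norm=threshold)

print(by_value.numpy())      # [ 1. -1.]
print(by_norm.numpy())       # [ 0.6 -0.8]
print(by_global[0].numpy())  # [ 0.6 -0.8] (same as clip_by_norm for a single-tensor list)
print(global_norm.numpy())   # 5.0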
Example #15
    def tf_enqueue(self, **values):
        # Constants
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))

        # Get number of values
        for value in values.values():
            if not isinstance(value, dict):
                break
            elif len(value) > 0:
                value = next(iter(value.values()))
                break
        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            num_values = tf.shape(input=value,
                                  out_type=util.tf_dtype(dtype='long'))[0]
        else:
            num_values = tf.dtypes.cast(x=tf.shape(input=value)[0],
                                        dtype=util.tf_dtype(dtype='long'))

        # Check whether instances fit into buffer
        assertion = tf.debugging.assert_less_equal(x=num_values, y=capacity)

        if self.return_overwritten:
            # Overwritten buffer indices
            with tf.control_dependencies(control_inputs=(assertion, )):
                start = tf.maximum(x=self.buffer_index, y=capacity)
                limit = tf.maximum(x=(self.buffer_index + num_values),
                                   y=capacity)
                num_overwritten = limit - start
                indices = tf.range(start=start, limit=limit)
                indices = tf.math.mod(x=indices, y=capacity)

            # Get overwritten values
            with tf.control_dependencies(control_inputs=(indices, )):
                overwritten_values = OrderedDict()
                for name, buffer in self.buffers.items():
                    if util.is_nested(name=name):
                        overwritten_values[name] = OrderedDict()
                        for inner_name, buffer in buffer.items():
                            overwritten_values[name][inner_name] = tf.gather(
                                params=buffer, indices=indices)
                    else:
                        overwritten_values[name] = tf.gather(params=buffer,
                                                             indices=indices)

        else:
            overwritten_values = (assertion, )

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            if self.return_overwritten:
                any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
                overwritten_values = util.fmap(
                    function=util.identity_operation, xs=overwritten_values)
                return any_overwritten, overwritten_values
            else:
                return util.no_operation()
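The write path above in isolation: indices are wrapped modulo the capacity, expanded to shape (n, 1), and the new rows are written into the buffer variable with scatter_nd_update (a sketch with a capacity of 4):

import tensorflow as tf

capacity = 4
buffer = tf.Variable(tf.zeros(shape=(capacity, 2)))
buffer_index = 3
updates = tf.constant([[1.0, 1.0], [2.0, 2.0]])  # two new values to enqueue

indices = tf.range(start=buffer_index, limit=buffer_index + 2)
indices = tf.math.mod(indices, capacity)         # [3, 0] -> wraps around
indices = tf.expand_dims(indices, axis=1)

buffer.scatter_nd_update(indices=indices, updates=updates)
print(buffer.numpy())  # row 3 == [1, 1], row 0 == [2, 2]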
Example #16
    def tf_step(self, variables, arguments, **kwargs):
        # # Get some (batched) argument to determine batch size.
        # arguments_iter = iter(arguments.values())
        # some_argument = next(arguments_iter)

        # try:
        #     while not isinstance(some_argument, tf.Tensor) or util.rank(x=some_argument) == 0:
        #         if isinstance(some_argument, dict):
        #             if some_argument:
        #                 arguments_iter = iter(some_argument.values())
        #             some_argument = next(arguments_iter)
        #         elif isinstance(some_argument, list):
        #             if some_argument:
        #                 arguments_iter = iter(some_argument)
        #             some_argument = next(arguments_iter)
        #         elif some_argument is None or util.rank(x=some_argument) == 0:
        #             # Non-batched argument
        #             some_argument = next(arguments_iter)
        #         else:
        #             raise TensorforceError("Invalid argument type.")
        # except StopIteration:
        #     raise TensorforceError("Invalid argument type.")

        some_argument = arguments['reward']

        if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
            batch_size = tf.shape(input=some_argument,
                                  out_type=util.tf_dtype(dtype='long'))[0]
        else:
            batch_size = tf.dtypes.cast(x=tf.shape(input=some_argument)[0],
                                        dtype=util.tf_dtype(dtype='long'))
        fraction = self.fraction.value()
        num_samples = fraction * tf.dtypes.cast(x=batch_size,
                                                dtype=util.tf_dtype('float'))
        num_samples = tf.dtypes.cast(x=num_samples,
                                     dtype=util.tf_dtype('long'))
        one = tf.constant(value=1, dtype=util.tf_dtype('long'))
        num_samples = tf.maximum(x=num_samples, y=one)
        indices = tf.random.uniform(shape=(num_samples, ),
                                    maxval=batch_size,
                                    dtype=util.tf_dtype(dtype='long'))
        states = arguments.pop('states')
        function = (lambda x: tf.gather(params=x, indices=indices))
        subsampled_arguments = util.fmap(function=function, xs=arguments)

        dependency_starts = Module.retrieve_tensor(name='dependency_starts')
        dependency_lengths = Module.retrieve_tensor(name='dependency_lengths')
        subsampled_starts = tf.gather(params=dependency_starts,
                                      indices=indices)
        subsampled_lengths = tf.gather(params=dependency_lengths,
                                       indices=indices)
        trivial_dependencies = tf.reduce_all(input_tensor=tf.math.equal(
            x=dependency_lengths, y=one),
                                             axis=0)

        def dependency_state_indices():
            fold = (lambda acc, args: tf.concat(values=(
                acc, tf.range(start=args[0], limit=(args[0] + args[1]))),
                                                axis=0))
            return tf.foldl(fn=fold,
                            elems=(subsampled_starts, subsampled_lengths),
                            initializer=indices[:0],
                            parallel_iterations=10,
                            back_prop=False,
                            swap_memory=False)

        states_indices = self.cond(pred=trivial_dependencies,
                                   true_fn=(lambda: indices),
                                   false_fn=dependency_state_indices)
        function = (lambda x: tf.gather(params=x, indices=states_indices))
        subsampled_arguments['states'] = util.fmap(function=function,
                                                   xs=states)

        subsampled_starts = tf.math.cumsum(x=subsampled_lengths,
                                           exclusive=True)
        Module.update_tensors(dependency_starts=subsampled_starts,
                              dependency_lengths=subsampled_lengths)

        deltas = self.optimizer.step(variables=variables,
                                     arguments=subsampled_arguments,
                                     **kwargs)

        Module.update_tensors(dependency_starts=dependency_starts,
                              dependency_lengths=dependency_lengths)

        return deltas
Example #17
    def experience(
        self, states, actions, terminal, reward, internals=None, query=None, **kwargs
    ):
        """
        Feed experience traces.

        Args:
            states (dict[array[state]]): Dictionary containing arrays of states
                (<span style="color:#C00000"><b>required</b></span>).
            actions (dict[array[action]]): Dictionary containing arrays of actions
                (<span style="color:#C00000"><b>required</b></span>).
            terminal (array[bool]): Array of terminals
                (<span style="color:#C00000"><b>required</b></span>).
            reward (array[float]): Array of rewards
                (<span style="color:#C00000"><b>required</b></span>).
            internals (dict[array[internal]]): Dictionary containing arrays of internal agent states
                (<span style="color:#00C000"><b>default</b></span>: no internal states).
            query (list[str]): Names of tensors to retrieve
                (<span style="color:#00C000"><b>default</b></span>: none).
            kwargs: Additional input values, for instance, for dynamic hyperparameters.
        """
        assert (self.buffer_indices == 0).all()
        assert util.reduce_all(predicate=util.not_nan_inf, xs=states)
        assert internals is None or util.reduce_all(predicate=util.not_nan_inf, xs=internals)
        assert util.reduce_all(predicate=util.not_nan_inf, xs=actions)
        assert util.reduce_all(predicate=util.not_nan_inf, xs=reward)

        # Auxiliaries
        auxiliaries = OrderedDict()
        if isinstance(states, dict):
            for name, spec in self.actions_spec.items():
                if spec['type'] == 'int' and name + '_mask' in states:
                    auxiliaries[name + '_mask'] = np.asarray(states.pop(name + '_mask'))
        auxiliaries = util.fmap(function=np.asarray, xs=auxiliaries, depth=1)

        # Normalize states/actions dictionaries
        states = util.normalize_values(
            value_type='state', values=states, values_spec=self.states_spec
        )
        actions = util.normalize_values(
            value_type='action', values=actions, values_spec=self.actions_spec
        )
        if internals is None:
            internals = OrderedDict()

        if isinstance(terminal, (bool, int)):
            states = util.fmap(function=(lambda x: [x]), xs=states, depth=1)
            internals = util.fmap(function=(lambda x: [x]), xs=internals, depth=1)
            auxiliaries = util.fmap(function=(lambda x: [x]), xs=auxiliaries, depth=1)
            actions = util.fmap(function=(lambda x: [x]), xs=actions, depth=1)
            terminal = [terminal]
            reward = [reward]

        states = util.fmap(function=np.asarray, xs=states, depth=1)
        internals = util.fmap(function=np.asarray, xs=internals, depth=1)
        auxiliaries = util.fmap(function=np.asarray, xs=auxiliaries, depth=1)
        actions = util.fmap(function=np.asarray, xs=actions, depth=1)

        if isinstance(terminal, np.ndarray):
            if terminal.dtype is util.np_dtype(dtype='bool'):
                zeros = np.zeros_like(terminal, dtype=util.np_dtype(dtype='long'))
                ones = np.ones_like(terminal, dtype=util.np_dtype(dtype='long'))
                terminal = np.where(terminal, ones, zeros)
        else:
            terminal = np.asarray([int(x) if isinstance(x, bool) else x for x in terminal])
        reward = np.asarray(reward)

        # Batch experiences split into episodes and at most size buffer_observe
        last = 0
        for index in range(1, len(terminal) + 1):
            if terminal[index - 1] == 0 and index - last < self.experience_size:
                continue

            # Include terminal in batch if possible
            if index < len(terminal) and terminal[index - 1] == 0 and terminal[index] > 0 and \
                    index - last < self.experience_size:
                index += 1

            function = (lambda x: x[last: index])
            states_batch = util.fmap(function=function, xs=states, depth=1)
            internals_batch = util.fmap(function=function, xs=internals, depth=1)
            auxiliaries_batch = util.fmap(function=function, xs=auxiliaries, depth=1)
            actions_batch = util.fmap(function=function, xs=actions, depth=1)
            terminal_batch = terminal[last: index]
            reward_batch = reward[last: index]
            last = index

            # Model.experience()
            if query is None:
                self.timesteps, self.episodes, self.updates = self.model.experience(
                    states=states_batch, internals=internals_batch,
                    auxiliaries=auxiliaries_batch, actions=actions_batch, terminal=terminal_batch,
                    reward=reward_batch, **kwargs
                )

            else:
                self.timesteps, self.episodes, self.updates, queried = self.model.experience(
                    states=states_batch, internals=internals_batch,
                    auxiliaries=auxiliaries_batch, actions=actions_batch, terminal=terminal_batch,
                    reward=reward_batch, query=query, **kwargs
                )

        if query is not None:
            return queried
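A sketch of the per-episode array layout experience() consumes; the state/action names are illustrative, and the final agent call is left commented since it requires a constructed agent:

import numpy as np

episode_length = 5
states = dict(state=np.random.rand(episode_length, 4).astype(np.float32))
actions = dict(action=np.random.randint(0, 3, size=episode_length))
terminal = np.array([False, False, False, False, True])
reward = np.random.rand(episode_length).astype(np.float32)

# Boolean terminals are converted to 0/1 integers, as in the method above.
terminal = np.where(terminal, np.ones_like(terminal, dtype=np.int64),
                    np.zeros_like(terminal, dtype=np.int64))
print(terminal)  # [0 0 0 0 1]

# With a constructed agent, this episode would then be fed via:
# agent.experience(states=states, actions=actions, terminal=terminal, reward=reward)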
Example #18
    def tf_step(self,
                variables,
                arguments,
                fn_loss,
                fn_kl_divergence,
                return_estimated_improvement=False,
                **kwargs):
        # Optimize: argmin(w) loss(w + delta) such that kldiv(P(w) || P(w + delta)) = learning_rate
        # For more details, see our blogpost:
        # https://reinforce.io/blog/end-to-end-computation-graphs-for-reinforcement-learning/

        # Calculates the product x * F of a given vector x with the fisher matrix F.
        # Incorporating the product prevents having to calculate the entire matrix explicitly.
        def fisher_matrix_product(deltas):
            # Gradient is not propagated through solver.
            deltas = [tf.stop_gradient(input=delta) for delta in deltas]

            # kldiv
            kldiv = fn_kl_divergence(**arguments)

            # grad(kldiv)
            kldiv_gradients = [
                tf.convert_to_tensor(value=grad)
                for grad in tf.gradients(ys=kldiv, xs=variables)
            ]

            # delta' * grad(kldiv)
            delta_kldiv_gradients = tf.add_n(inputs=[
                tf.reduce_sum(input_tensor=(delta * grad))
                for delta, grad in zip(deltas, kldiv_gradients)
            ])

            # [delta' * F] = grad(delta' * grad(kldiv))
            return [
                tf.convert_to_tensor(value=grad)
                for grad in tf.gradients(ys=delta_kldiv_gradients,
                                         xs=variables)
            ]

        # loss
        arguments = util.fmap(function=tf.stop_gradient, xs=arguments)
        loss = fn_loss(**arguments)

        # grad(loss)
        loss_gradients = tf.gradients(ys=loss, xs=variables)

        # Solve the following system for delta' via the conjugate gradient solver.
        # [delta' * F] * delta' = -grad(loss)
        # --> delta'  (= lambda * delta)
        deltas = self.solver.solve(fn_x=fisher_matrix_product,
                                   x_init=None,
                                   b=[-grad for grad in loss_gradients])

        # delta' * F
        delta_fisher_matrix_product = fisher_matrix_product(deltas=deltas)

        # c' = 0.5 * delta' * F * delta'  (= lambda * c)
        # TODO: Why constant and hence KL-divergence sometimes negative?
        half = tf.constant(value=0.5, dtype=util.tf_dtype(dtype='float'))
        constant = half * tf.add_n(inputs=[
            tf.reduce_sum(input_tensor=(delta_F * delta))
            for delta_F, delta in zip(delta_fisher_matrix_product, deltas)
        ])

        learning_rate = self.learning_rate.value()

        # Zero step if constant <= 0
        def no_step():
            zero_deltas = [tf.zeros_like(input=delta) for delta in deltas]
            if return_estimated_improvement:
                return zero_deltas, tf.constant(
                    value=0.0, dtype=util.tf_dtype(dtype='float'))
            else:
                return zero_deltas

        # Natural gradient step if constant > 0
        def apply_step():
            # lambda = sqrt(c' / c)
            lagrange_multiplier = tf.sqrt(x=(constant / learning_rate))

            # delta = delta' / lambda
            estimated_deltas = [
                delta / lagrange_multiplier for delta in deltas
            ]

            # improvement = grad(loss) * delta  (= loss_new - loss_old)
            estimated_improvement = tf.add_n(inputs=[
                tf.reduce_sum(input_tensor=(grad * delta))
                for grad, delta in zip(loss_gradients, estimated_deltas)
            ])

            # Apply natural gradient improvement.
            applied = self.apply_step(variables=variables,
                                      deltas=estimated_deltas)

            with tf.control_dependencies(control_inputs=(applied, )):
                # Trivial operation to enforce control dependency
                estimated_delta = util.fmap(function=util.identity_operation,
                                            xs=estimated_deltas)
                if return_estimated_improvement:
                    return estimated_delta, estimated_improvement
                else:
                    return estimated_delta

        # Natural gradient step only works if constant > 0
        perform_step = constant > tf.constant(value=0.0,
                                              dtype=util.tf_dtype(dtype='float'))
        return self.cond(pred=perform_step, true_fn=apply_step, false_fn=no_step)
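A NumPy check of the scaling performed in apply_step: solving F delta' = -grad(loss), computing c' = 0.5 * delta'^T F delta' and dividing by lambda = sqrt(c' / learning_rate) yields a step whose quadratic KL approximation 0.5 * delta^T F delta equals the learning rate. The explicit random Fisher matrix below is a stand-in for the matrix-free fisher_matrix_product:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
F = A @ A.T + 4.0 * np.eye(4)   # positive-definite stand-in for the Fisher matrix
g = rng.normal(size=4)          # stand-in for grad(loss)
learning_rate = 0.01

delta_prime = np.linalg.solve(F, -g)                     # F delta' = -grad(loss)
constant = 0.5 * delta_prime @ F @ delta_prime           # c' = 0.5 * delta'^T F delta'
lagrange_multiplier = np.sqrt(constant / learning_rate)  # lambda = sqrt(c' / lr)
delta = delta_prime / lagrange_multiplier

# The rescaled step satisfies the KL constraint 0.5 * delta^T F delta = learning_rate.
print(0.5 * delta @ F @ delta)  # ~0.01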
Example #19
    def act(self,
            states,
            parallel=0,
            deterministic=False,
            independent=False,
            query=None,
            **kwargs):
        """
        Return action(s) for given state(s). State preprocessing and exploration are applied if
        configured accordingly.

        Args:
            states (any): One state (usually a value tuple) or dict of states if multiple states are expected.
            deterministic (bool): If true, no exploration or sampling is applied.
            independent (bool): If true, action is not followed by observe (and hence not included
                in updates).
            query (list): Optional list of names of tensors to fetch.
        Returns:
            Scalar value of the action, or dict of multiple actions, that the agent wants to
            execute, plus an optional dict of the fetched tensor values if `query` is given.
        """
        # self.current_internals = self.next_internals

        # Normalize states dictionary
        states = util.normalize_values(value_type='state',
                                       values=states,
                                       values_spec=self.states_spec)

        # Batch states
        states = util.fmap(function=(lambda x: [x]), xs=states)

        # Model.act()
        if query is None:
            actions, self.timestep = self.model.act(
                states=states,
                parallel=parallel,
                deterministic=deterministic,
                independent=independent,
                **kwargs)

        else:
            actions, self.timestep, query = self.model.act(
                states=states,
                parallel=parallel,
                deterministic=deterministic,
                independent=independent,
                query=query,
                **kwargs)

        # Unbatch actions
        actions = util.fmap(function=(lambda x: x[0]), xs=actions)

        # Reverse normalized actions dictionary
        actions = util.unpack_values(value_type='action',
                                     values=actions,
                                     values_spec=self.actions_spec)

        # if independent, return processed state as well?

        if query is None:
            return actions
        else:
            return actions, query
Example #20
    def tf_enqueue(self,
                   states,
                   internals,
                   auxiliaries,
                   actions,
                   terminal,
                   reward,
                   baseline=None):
        # Constants and parameters
        zero = tf.constant(value=0, dtype=util.tf_dtype(dtype='long'))
        one = tf.constant(value=1, dtype=util.tf_dtype(dtype='long'))
        capacity = tf.constant(value=self.capacity,
                               dtype=util.tf_dtype(dtype='long'))
        horizon = self.horizon.value()
        discount = self.discount.value()

        assertions = list()
        # Check whether horizon at most capacity
        assertions.append(
            tf.debugging.assert_less_equal(
                x=horizon,
                y=capacity,
                message=
                "Estimator capacity has to be at least the same as the estimation horizon."
            ))
        # Check whether at most one terminal
        assertions.append(
            tf.debugging.assert_less_equal(
                x=tf.math.count_nonzero(input=terminal,
                                        dtype=util.tf_dtype(dtype='long')),
                y=one,
                message="Timesteps contain more than one terminal."))
        # Check whether, if any, last value is terminal
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.reduce_any(
                    input_tensor=tf.math.greater(x=terminal, y=zero)),
                y=tf.math.greater(x=terminal[-1], y=zero),
                message="Terminal is not the last timestep."))

        # Get number of overwritten values
        with tf.control_dependencies(control_inputs=assertions):
            if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                num_values = tf.shape(input=terminal,
                                      out_type=util.tf_dtype(dtype='long'))[0]
            else:
                num_values = tf.dtypes.cast(x=tf.shape(input=terminal)[0],
                                            dtype=util.tf_dtype(dtype='long'))
            overwritten_start = tf.maximum(x=self.buffer_index, y=capacity)
            overwritten_limit = tf.maximum(x=(self.buffer_index + num_values),
                                           y=capacity)
            num_overwritten = overwritten_limit - overwritten_start

        def update_overwritten_rewards():
            # Get relevant buffer rewards
            buffer_limit = self.buffer_index + tf.minimum(
                x=(num_overwritten + horizon), y=capacity)
            buffer_indices = tf.range(start=self.buffer_index,
                                      limit=buffer_limit)
            buffer_indices = tf.math.mod(x=buffer_indices, y=capacity)
            rewards = tf.gather(params=self.buffers['reward'],
                                indices=buffer_indices)

            # Get relevant rewards from the incoming values
            values_limit = tf.maximum(x=(num_overwritten + horizon - capacity),
                                      y=zero)
            rewards = tf.concat(values=(rewards, reward[:values_limit]),
                                axis=0)

            # Horizon baseline value
            if self.estimate_horizon == 'early':
                assert baseline is not None
                # Baseline estimate
                buffer_indices = buffer_indices[horizon + one:]
                _states = OrderedDict()
                for name, buffer in self.buffers['states'].items():
                    state = tf.gather(params=buffer, indices=buffer_indices)
                    _states[name] = tf.concat(
                        values=(state, states[name][:values_limit + one]),
                        axis=0)
                _internals = OrderedDict()
                for name, buffer in self.buffers['internals'].items():
                    internal = tf.gather(params=buffer, indices=buffer_indices)
                    _internals[name] = tf.concat(
                        values=(internal,
                                internals[name][:values_limit + one]),
                        axis=0)
                _auxiliaries = OrderedDict()
                for name, buffer in self.buffers['auxiliaries'].items():
                    auxiliary = tf.gather(params=buffer,
                                          indices=buffer_indices)
                    _auxiliaries[name] = tf.concat(
                        values=(auxiliary,
                                auxiliaries[name][:values_limit + one]),
                        axis=0)

                # Dependency horizon
                # TODO: handle arbitrary non-optimization horizons!
                past_horizon = baseline.past_horizon(is_optimization=False)
                assertion = tf.debugging.assert_equal(
                    x=past_horizon,
                    y=zero,
                    message=
                    "Temporary: baseline cannot depend on previous states.")
                with tf.control_dependencies(control_inputs=(assertion, )):
                    some_state = next(iter(_states.values()))
                    if util.tf_dtype(dtype='long') in (tf.int32, tf.int64):
                        batch_size = tf.shape(
                            input=some_state,
                            out_type=util.tf_dtype(dtype='long'))[0]
                    else:
                        batch_size = tf.dtypes.cast(
                            x=tf.shape(input=some_state)[0],
                            dtype=util.tf_dtype(dtype='long'))
                    starts = tf.range(start=batch_size,
                                      dtype=util.tf_dtype(dtype='long'))
                    lengths = tf.ones(shape=(batch_size, ),
                                      dtype=util.tf_dtype(dtype='long'))
                    Module.update_tensors(dependency_starts=starts,
                                          dependency_lengths=lengths)

                if self.estimate_actions:
                    _actions = OrderedDict()
                    for name, buffer in self.buffers['actions'].items():
                        action = tf.gather(params=buffer,
                                           indices=buffer_indices)
                        _actions[name] = tf.concat(
                            values=(action, actions[name][:values_limit]),
                            axis=0)
                    horizon_estimate = baseline.actions_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries,
                        actions=_actions)
                else:
                    horizon_estimate = baseline.states_value(
                        states=_states,
                        internals=_internals,
                        auxiliaries=_auxiliaries)

            else:
                # Zero estimate
                horizon_estimate = tf.zeros(shape=(num_overwritten, ),
                                            dtype=util.tf_dtype(dtype='float'))

            # Calculate discounted sum
            def cond(discounted_sum, horizon):
                return tf.math.greater_equal(x=horizon, y=zero)

            def body(discounted_sum, horizon):
                # discounted_sum = tf.compat.v1.Print(
                #     discounted_sum, (horizon, discounted_sum, rewards[horizon:]), summarize=10
                # )
                discounted_sum = discount * discounted_sum
                discounted_sum = discounted_sum + rewards[horizon:horizon +
                                                          num_overwritten]
                return discounted_sum, horizon - one

            discounted_sum, _ = self.while_loop(cond=cond,
                                                body=body,
                                                loop_vars=(horizon_estimate,
                                                           horizon),
                                                back_prop=False)

            assertions = [
                tf.debugging.assert_equal(x=tf.shape(input=horizon_estimate),
                                          y=tf.shape(input=discounted_sum),
                                          message="Estimation check."),
                tf.debugging.assert_equal(x=tf.shape(
                    input=rewards, out_type=util.tf_dtype(dtype='long'))[0],
                                          y=(horizon + num_overwritten),
                                          message="Estimation check.")
            ]

            # Overwrite buffer rewards
            with tf.control_dependencies(control_inputs=assertions):
                indices = tf.range(start=self.buffer_index,
                                   limit=(self.buffer_index + num_overwritten))
                indices = tf.math.mod(x=indices, y=capacity)
                indices = tf.expand_dims(input=indices, axis=1)

            assignment = self.buffers['reward'].scatter_nd_update(
                indices=indices, updates=discounted_sum)

            with tf.control_dependencies(control_inputs=(assignment, )):
                return util.no_operation()

        any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
        updated_rewards = self.cond(pred=any_overwritten,
                                    true_fn=update_overwritten_rewards,
                                    false_fn=util.no_operation)

        # Overwritten buffer indices
        with tf.control_dependencies(control_inputs=(updated_rewards, )):
            indices = tf.range(start=overwritten_start,
                               limit=overwritten_limit)
            indices = tf.math.mod(x=indices, y=capacity)

        # Get overwritten values
        with tf.control_dependencies(control_inputs=(indices, )):
            overwritten_values = OrderedDict()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    overwritten_values[name] = OrderedDict()
                    for inner_name, buffer in buffer.items():
                        overwritten_values[name][inner_name] = tf.gather(
                            params=buffer, indices=indices)
                else:
                    overwritten_values[name] = tf.gather(params=buffer,
                                                         indices=indices)

        # Buffer indices to (over)write
        with tf.control_dependencies(control_inputs=util.flatten(
                xs=overwritten_values)):
            indices = tf.range(start=self.buffer_index,
                               limit=(self.buffer_index + num_values))
            indices = tf.math.mod(x=indices, y=capacity)
            indices = tf.expand_dims(input=indices, axis=1)

        # Write new values
        with tf.control_dependencies(control_inputs=(indices, )):
            values = dict(states=states,
                          internals=internals,
                          auxiliaries=auxiliaries,
                          actions=actions,
                          terminal=terminal,
                          reward=reward)
            assignments = list()
            for name, buffer in self.buffers.items():
                if util.is_nested(name=name):
                    for inner_name, buffer in buffer.items():
                        assignment = buffer.scatter_nd_update(
                            indices=indices, updates=values[name][inner_name])
                        assignments.append(assignment)
                else:
                    assignment = buffer.scatter_nd_update(indices=indices,
                                                          updates=values[name])
                    assignments.append(assignment)

        # Increment buffer index
        with tf.control_dependencies(control_inputs=assignments):
            assignment = self.buffer_index.assign_add(delta=num_values,
                                                      read_value=False)

        # Return overwritten values or no-op
        with tf.control_dependencies(control_inputs=(assignment, )):
            any_overwritten = tf.math.greater(x=num_overwritten, y=zero)
            overwritten_values = util.fmap(function=util.identity_operation,
                                           xs=overwritten_values)
            return any_overwritten, overwritten_values
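What the cond/body while-loop above computes, reduced to NumPy: starting from the horizon estimate, repeatedly discount and add the reward slice shifted by the remaining horizon, yielding n-step discounted returns bootstrapped by the estimate (a sketch for a tiny buffer):

import numpy as np

discount = 0.9
horizon = 2
num_overwritten = 3
rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])     # length num_overwritten + horizon
horizon_estimate = np.array([10.0, 20.0, 30.0])   # baseline value beyond the horizon

discounted_sum = horizon_estimate.copy()
for h in range(horizon, -1, -1):
    discounted_sum = discount * discounted_sum + rewards[h:h + num_overwritten]

# Element i equals sum_k discount^k * rewards[i + k] + discount^(horizon + 1) * estimate[i]
print(discounted_sum[0])                                # ~12.52
print(1.0 + 0.9 * 2.0 + 0.81 * 3.0 + 0.9 ** 3 * 10.0)   # ~12.52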
Example #21
    def act(self,
            states,
            internals=None,
            parallel=0,
            independent=False,
            deterministic=False,
            evaluation=False,
            query=None,
            **kwargs):
        """
        Returns action(s) for the given state(s); needs to be followed by `observe(...)` unless
        independent mode is set via `independent`/`evaluation`.

        Args:
            states (dict[state] | iter[dict[state]]): Dictionary containing state(s) to be acted on
                (<span style="color:#C00000"><b>required</b></span>).
            internals (dict[internal] | iter[dict[internal]]): Dictionary containing current
                internal agent state(s)
                (<span style="color:#C00000"><b>required</b></span> if independent mode).
            parallel (int | iter[int]): Parallel execution index
                (<span style="color:#00C000"><b>default</b></span>: 0).
            independent (bool): Whether act is not part of the main agent-environment interaction,
                and this call is thus not followed by observe
                (<span style="color:#00C000"><b>default</b></span>: false).
            deterministic (bool): If in independent mode, whether to act deterministically, i.e.
                without exploration and sampling
                (<span style="color:#00C000"><b>default</b></span>: false).
            evaluation (bool): Whether the agent is currently evaluated, implies independent and
                deterministic
                (<span style="color:#00C000"><b>default</b></span>: false).
            query (list[str]): Names of tensors to retrieve
                (<span style="color:#00C000"><b>default</b></span>: none).
            kwargs: Additional input values, for instance, for dynamic hyperparameters.

        Returns:
            dict[action] | iter[dict[action]], plus dict[internal] | iter[dict[internal]] if in
            independent mode, plus optional list[str]: Dictionary containing action(s), dictionary
            containing next internal agent state(s) if in independent mode, plus queried tensor
            values if requested.
        """
        assert util.reduce_all(predicate=util.not_nan_inf, xs=states)

        if evaluation:
            if deterministic:
                raise TensorforceError.invalid(name='agent.act',
                                               argument='deterministic',
                                               condition='evaluation = true')
            if independent:
                raise TensorforceError.invalid(name='agent.act',
                                               argument='independent',
                                               condition='evaluation = true')
            deterministic = independent = True

        if not independent:
            if internals is not None:
                raise TensorforceError.invalid(name='agent.act',
                                               argument='internals',
                                               condition='independent = false')
            if deterministic:
                raise TensorforceError.invalid(name='agent.act',
                                               argument='deterministic',
                                               condition='independent = false')

        if independent:
            internals_is_none = (internals is None)
            if internals_is_none:
                internals = OrderedDict()

        # Batch states
        batched = (not isinstance(parallel, int))
        if batched:
            if len(parallel) == 0:
                raise TensorforceError.value(name='agent.act',
                                             argument='parallel',
                                             value=parallel,
                                             hint='zero-length')
            parallel = np.asarray(list(parallel))
            if isinstance(states[0], dict):
                states = OrderedDict(
                    ((name,
                      np.asarray(
                          [states[n][name] for n in range(len(parallel))]))
                     for name in states[0]))
            else:
                states = np.asarray(states)
            if independent:
                internals = OrderedDict(
                    ((name,
                      np.asarray(
                          [internals[n][name] for n in range(len(parallel))]))
                     for name in internals[0]))
        else:
            parallel = np.asarray([parallel])
            states = util.fmap(function=(lambda x: np.asarray([x])),
                               xs=states,
                               depth=int(isinstance(states, dict)))
            if independent:
                internals = util.fmap(function=(lambda x: np.asarray([x])),
                                      xs=internals,
                                      depth=1)

        if not independent and not all(self.timestep_completed[n]
                                       for n in parallel):
            raise TensorforceError(
                message="Calling agent.act must be preceded by agent.observe.")

        # Auxiliaries
        auxiliaries = OrderedDict()
        if isinstance(states, dict):
            states = dict(states)
            for name, spec in self.actions_spec.items():
                if spec['type'] == 'int' and name + '_mask' in states:
                    auxiliaries[name + '_mask'] = states.pop(name + '_mask')

        # Normalize states dictionary
        states = util.normalize_values(value_type='state',
                                       values=states,
                                       values_spec=self.states_spec)

        # Model.act()
        if independent:
            if query is None:
                actions, internals = self.model.independent_act(
                    states=states,
                    internals=internals,
                    auxiliaries=auxiliaries,
                    parallel=parallel,
                    deterministic=deterministic,
                    **kwargs)

            else:
                actions, internals, queried = self.model.independent_act(
                    states=states,
                    internals=internals,
                    auxiliaries=auxiliaries,
                    parallel=parallel,
                    deterministic=deterministic,
                    query=query,
                    **kwargs)

        else:
            if query is None:
                actions, self.timesteps = self.model.act(
                    states=states,
                    auxiliaries=auxiliaries,
                    parallel=parallel,
                    **kwargs)

            else:
                actions, self.timesteps, queried = self.model.act(
                    states=states,
                    auxiliaries=auxiliaries,
                    parallel=parallel,
                    query=query,
                    **kwargs)

        if not independent:
            for n in parallel:
                self.timestep_completed[n] = False

        if self.recorder_spec is not None and not independent and \
                self.episodes >= self.recorder_spec.get('start', 0):
            for n in range(len(parallel)):
                index = self.buffer_indices[parallel[n]]
                for name in self.states_spec:
                    self.states_buffers[name][parallel[n],
                                              index] = states[name][n]
                for name, spec in self.actions_spec.items():
                    self.actions_buffers[name][parallel[n],
                                               index] = actions[name][n]
                    if spec['type'] == 'int':
                        name = name + '_mask'
                        if name in auxiliaries:
                            self.states_buffers[name][
                                parallel[n], index] = auxiliaries[name][n]
                        else:
                            shape = (1, ) + spec['shape'] + (
                                spec['num_values'], )
                            self.states_buffers[name][parallel[n],
                                                      index] = np.full(
                                                          shape=shape,
                                                          fill_value=True,
                                                          dtype=util.np_dtype(
                                                              dtype='bool'))

        # Reverse normalized actions dictionary
        actions = util.unpack_values(value_type='action',
                                     values=actions,
                                     values_spec=self.actions_spec)

        # Unbatch actions
        if batched:
            if isinstance(actions, dict):
                actions = [
                    OrderedDict(((name, actions[name][n]) for name in actions))
                    for n in range(len(parallel))
                ]
        else:
            actions = util.fmap(function=(lambda x: x[0]),
                                xs=actions,
                                depth=int(isinstance(actions, dict)))
            if independent:
                internals = util.fmap(function=(lambda x: x[0]),
                                      xs=internals,
                                      depth=1)

        if independent and not internals_is_none:
            if query is None:
                return actions, internals
            else:
                return actions, internals, queried

        else:
            if query is None:
                return actions
            else:
                return actions, queried
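
A hedged usage sketch of the act/observe contract documented above, assuming an already-configured Tensorforce-style `agent` and `environment` (both objects and the `environment.reset()`/`environment.execute()` calls are illustrative and not defined in this example):

states = environment.reset()
terminal = False
while not terminal:
    actions = agent.act(states=states)                    # must be followed by observe
    states, terminal, reward = environment.execute(actions=actions)
    agent.observe(terminal=terminal, reward=reward)

# Evaluation mode implies independent and deterministic; no observe follows
actions = agent.act(states=states, evaluation=True)
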
Example No. 22
0
    def observe(self, reward, terminal=False, parallel=0, query=None, **kwargs):
        """
        Observes reward and whether a terminal state is reached; needs to be preceded by
        `act(...)`.

        Args:
            reward (float): Reward
                (<span style="color:#C00000"><b>required</b></span>).
            terminal (bool | 0 | 1 | 2): Whether a terminal state is reached or 2 if the
                episode was aborted (<span style="color:#00C000"><b>default</b></span>: false).
            parallel (int): Parallel execution index
                (<span style="color:#00C000"><b>default</b></span>: 0).
            query (list[str]): Names of tensors to retrieve
                (<span style="color:#00C000"><b>default</b></span>: none).
            kwargs: Additional input values, for instance, for dynamic hyperparameters.

        Returns:
            (bool, optional list[str]): Whether an update was performed, plus queried tensor values
            if requested.
        """
        assert util.reduce_all(predicate=util.not_nan_inf, xs=reward)

        if query is not None and self.parallel_interactions > 1:
            raise TensorforceError.unexpected()

        if isinstance(terminal, bool):
            terminal = int(terminal)

        # Update terminal/reward buffer
        index = self.buffer_indices[parallel]
        self.terminal_buffers[parallel, index] = terminal
        self.reward_buffers[parallel, index] = reward
        index += 1

        if self.max_episode_timesteps is not None and index > self.max_episode_timesteps:
            raise TensorforceError.unexpected()

        if terminal > 0 or index == self.buffer_observe or query is not None:
            terminal = self.terminal_buffers[parallel, :index]
            reward = self.reward_buffers[parallel, :index]

            if self.recorder_spec is not None and \
                    self.episodes >= self.recorder_spec.get('start', 0):
                for name in self.states_spec:
                    self.record_states[name].append(
                        np.array(self.states_buffers[name][parallel, :index])
                    )
                for name, spec in self.actions_spec.items():
                    self.record_actions[name].append(
                        np.array(self.actions_buffers[name][parallel, :index])
                    )
                    if spec['type'] == 'int':
                        self.record_states[name + '_mask'].append(
                            np.array(self.states_buffers[name + '_mask'][parallel, :index])
                        )
                self.record_terminal.append(np.array(terminal))
                self.record_reward.append(np.array(reward))

                if terminal[-1] > 0:
                    self.num_episodes += 1

                    if self.num_episodes == self.recorder_spec.get('frequency', 1):
                        directory = self.recorder_spec['directory']
                        if os.path.isdir(directory):
                            files = sorted(
                                f for f in os.listdir(directory)
                                if os.path.isfile(os.path.join(directory, f))
                                and f.startswith('trace-')
                            )
                        else:
                            os.makedirs(directory)
                            files = list()
                        max_traces = self.recorder_spec.get('max-traces')
                        if max_traces is not None and len(files) > max_traces - 1:
                            for filename in files[:-max_traces + 1]:
                                filename = os.path.join(directory, filename)
                                os.remove(filename)

                        filename = 'trace-{}-{}.npz'.format(
                            self.episodes, time.strftime('%Y%m%d-%H%M%S')
                        )
                        filename = os.path.join(directory, filename)
                        self.record_states = util.fmap(
                            function=np.concatenate, xs=self.record_states, depth=1
                        )
                        self.record_actions = util.fmap(
                            function=np.concatenate, xs=self.record_actions, depth=1
                        )
                        self.record_terminal = np.concatenate(self.record_terminal)
                        self.record_reward = np.concatenate(self.record_reward)
                        np.savez_compressed(
                            filename, **self.record_states, **self.record_actions,
                            terminal=self.record_terminal, reward=self.record_reward
                        )
                        self.record_states = util.fmap(
                            function=(lambda x: list()), xs=self.record_states, depth=1
                        )
                        self.record_actions = util.fmap(
                            function=(lambda x: list()), xs=self.record_actions, depth=1
                        )
                        self.record_terminal = list()
                        self.record_reward = list()
                        self.num_episodes = 0

            # Model.observe()
            if query is None:
                updated, self.episodes, self.updates = self.model.observe(
                    terminal=terminal, reward=reward, parallel=parallel, **kwargs
                )

            else:
                updated, self.episodes, self.updates, queried = self.model.observe(
                    terminal=terminal, reward=reward, parallel=parallel, query=query, **kwargs
                )

            # Reset buffer index
            self.buffer_indices[parallel] = 0

        else:
            # Increment buffer index
            self.buffer_indices[parallel] = index
            updated = False

        if query is None:
            return updated
        else:
            return updated, queried
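
The recorder branch above writes each batch of recorded episodes to a compressed trace-*.npz file containing one array per state and action (plus the *_mask auxiliaries), together with 'terminal' and 'reward'. A hedged sketch of inspecting such a file (directory and filename are illustrative):

import numpy as np

trace = np.load('traces/trace-120-20240101-120000.npz')
for key in trace.files:
    print(key, trace[key].shape)
# Expected keys: one entry per state/action (and per '<name>_mask'),
# plus 'terminal' and 'reward', each concatenated over the recorded episodes.
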
Example No. 23
0
    def act(self,
            states,
            internals=None,
            parallel=0,
            independent=False,
            **kwargs):
        # Independent and internals
        is_internals_none = (internals is None)
        if independent:
            if parallel != 0:
                raise TensorforceError.invalid(name='Agent.act',
                                               argument='parallel',
                                               condition='independent is true')
            if is_internals_none and len(self.internals_spec) > 0:
                raise TensorforceError.required(
                    name='Agent.act',
                    argument='internals',
                    condition='independent is true')
        else:
            if not is_internals_none:
                raise TensorforceError.invalid(
                    name='Agent.act',
                    argument='internals',
                    condition='independent is false')

        # Process states input and infer batching structure
        states, batched, num_parallel, is_iter_of_dicts, input_type = self._process_states_input(
            states=states, function_name='Agent.act')

        if independent:
            # Independent mode: handle internals argument
            if is_internals_none:
                # Default input internals=None
                pass

            elif is_iter_of_dicts:
                # Input structure iter[dict[internal]]
                if not isinstance(internals, (tuple, list)):
                    raise TensorforceError.type(name='Agent.act',
                                                argument='internals',
                                                dtype=type(internals),
                                                hint='is not tuple/list')
                internals = [ArrayDict(internal) for internal in internals]
                internals = internals[0].fmap(
                    function=(lambda *xs: np.stack(xs, axis=0)),
                    zip_values=internals[1:])

            else:
                # Input structure dict[iter[internal]]
                if not isinstance(internals, dict):
                    raise TensorforceError.type(name='Agent.act',
                                                argument='internals',
                                                dtype=type(internals),
                                                hint='is not dict')
                internals = ArrayDict(internals)

            if not independent or not is_internals_none:
                # Expand inputs if not batched
                if not batched:
                    internals = internals.fmap(
                        function=(lambda x: np.expand_dims(x, axis=0)))

                # Check number of inputs
                for name, internal in internals.items():
                    if internal.shape[0] != num_parallel:
                        raise TensorforceError.value(
                            name='Agent.act',
                            argument='len(internals[{}])'.format(name),
                            value=internal.shape[0],
                            hint='!= len(states)')

        else:
            # Non-independent mode: handle parallel input
            if parallel == 0:
                # Default input parallel=0
                if batched:
                    assert num_parallel == self.parallel_interactions
                    parallel = np.asarray(list(range(num_parallel)))
                else:
                    parallel = np.asarray([parallel])

            elif batched:
                # Batched input
                parallel = np.asarray(parallel)

            else:
                # Expand input if not batched
                parallel = np.asarray([parallel])

            # Check number of inputs
            if parallel.shape[0] != num_parallel:
                raise TensorforceError.value(name='Agent.act',
                                             argument='len(parallel)',
                                             value=len(parallel),
                                             hint='!= len(states)')

        # If not independent, check whether previous timesteps were completed
        if not independent:
            if not self.timestep_completed[parallel].all():
                raise TensorforceError(
                    message=
                    "Calling agent.act must be preceded by agent.observe.")
            self.timestep_completed[parallel] = False

        # Buffer inputs for recording
        if self.recorder_spec is not None and not independent and \
                self.num_episodes >= self.recorder_spec.get('start', 0):
            for n in range(num_parallel):
                for name in self.states_spec:
                    self.buffers['states'][name][parallel[n]].append(
                        states[name][n])

        # fn_act()
        if self._is_agent:
            actions, internals = self.fn_act(
                states=states,
                internals=internals,
                parallel=parallel,
                independent=independent,
                is_internals_none=is_internals_none,
                num_parallel=num_parallel)
        else:
            if batched:
                assert False
            else:
                if self.single_state:
                    if states['state'].shape == (1, ):
                        states = states['state'][0].item()
                    else:
                        states = states['state'][0]
                else:
                    states = util.fmap(function=(lambda x: x[0].item()
                                                 if x.shape ==
                                                 (1, ) else x[0]),
                                       xs=states)
                actions = self.fn_act(states)
                if self.single_action:
                    actions = dict(action=np.asarray([actions]))
                else:
                    actions = util.fmap(function=(lambda x: np.asarray([x])),
                                        xs=actions)

        # Buffer outputs for recording
        if self.recorder_spec is not None and not independent and \
                self.num_episodes >= self.recorder_spec.get('start', 0):
            for n in range(num_parallel):
                for name in self.actions_spec:
                    self.buffers['actions'][name][parallel[n]].append(
                        actions[name][n])

        # Unbatch actions
        if batched:
            # If inputs were batched, turn list of dicts into dict of lists
            function = (lambda x: x.item() if x.shape == () else x)
            if self.single_action:
                actions = input_type(
                    function(actions['action'][n])
                    for n in range(num_parallel))
            else:
                # TODO: recursive
                actions = input_type(
                    OrderedDict(((name, function(x[n]))
                                 for name, x in actions.items()))
                    for n in range(num_parallel))

            if independent and not is_internals_none and is_iter_of_dicts:
                # TODO: recursive
                internals = input_type(
                    OrderedDict(((name, function(x[n]))
                                 for name, x in internals.items()))
                    for n in range(num_parallel))

        else:
            # If inputs were not batched, unbatch outputs
            function = (lambda x: x.item() if x.shape == (1, ) else x[0])
            if self.single_action:
                actions = function(actions['action'])
            else:
                actions = actions.fmap(function=function, cls=OrderedDict)
            if independent and not is_internals_none:
                internals = internals.fmap(function=function, cls=OrderedDict)

        if independent and not is_internals_none:
            return actions, internals
        else:
            return actions
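
When internals are passed as iter[dict[internal]], the code above stacks them into a single dict of batched arrays via ArrayDict.fmap. A hedged NumPy-only sketch of that re-batching step (the 'lstm_state' name is illustrative):

import numpy as np

def stack_internals(internals_list):
    # iter[dict[internal]] -> dict[internal] with a leading batch axis
    return {
        name: np.stack([internals[name] for internals in internals_list], axis=0)
        for name in internals_list[0]
    }

batched = stack_internals([
    {'lstm_state': np.zeros(8)},
    {'lstm_state': np.ones(8)},
])
print(batched['lstm_state'].shape)  # -> (2, 8)
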
Example No. 24
0
    def model_observe(self, parallel, query=None, **kwargs):
        assert self.timestep_completed[parallel]
        index = self.buffer_indices[parallel]
        terminal = self.terminal_buffers[parallel, :index]
        reward = self.reward_buffers[parallel, :index]

        if self.recorder_spec is not None and \
                self.episodes >= self.recorder_spec.get('start', 0):
            for name in self.states_spec:
                self.record_states[name].append(
                    np.array(self.states_buffers[name][parallel, :index]))
            for name, spec in self.actions_spec.items():
                self.record_actions[name].append(
                    np.array(self.actions_buffers[name][parallel, :index]))
                if spec['type'] == 'int':
                    self.record_states[name + '_mask'].append(
                        np.array(
                            self.states_buffers[name +
                                                '_mask'][parallel, :index]))
            self.record_terminal.append(np.array(terminal))
            self.record_reward.append(np.array(reward))

            if terminal[-1] > 0:
                self.num_episodes += 1

                if self.num_episodes == self.recorder_spec.get('frequency', 1):
                    directory = self.recorder_spec['directory']
                    if os.path.isdir(directory):
                        files = sorted(
                            f for f in os.listdir(directory)
                            if os.path.isfile(os.path.join(directory, f))
                            and f.startswith('trace-'))
                    else:
                        os.makedirs(directory)
                        files = list()
                    max_traces = self.recorder_spec.get('max-traces')
                    if max_traces is not None and len(files) > max_traces - 1:
                        for filename in files[:-max_traces + 1]:
                            filename = os.path.join(directory, filename)
                            os.remove(filename)

                    filename = 'trace-{}-{}.npz'.format(
                        self.episodes, time.strftime('%Y%m%d-%H%M%S'))
                    filename = os.path.join(directory, filename)
                    self.record_states = util.fmap(function=np.concatenate,
                                                   xs=self.record_states,
                                                   depth=1)
                    self.record_actions = util.fmap(function=np.concatenate,
                                                    xs=self.record_actions,
                                                    depth=1)
                    self.record_terminal = np.concatenate(self.record_terminal)
                    self.record_reward = np.concatenate(self.record_reward)
                    np.savez_compressed(filename,
                                        **self.record_states,
                                        **self.record_actions,
                                        terminal=self.record_terminal,
                                        reward=self.record_reward)
                    self.record_states = util.fmap(function=(lambda x: list()),
                                                   xs=self.record_states,
                                                   depth=1)
                    self.record_actions = util.fmap(
                        function=(lambda x: list()),
                        xs=self.record_actions,
                        depth=1)
                    self.record_terminal = list()
                    self.record_reward = list()
                    self.num_episodes = 0

        # Reset buffer index
        self.buffer_indices[parallel] = 0

        # Model.observe()
        if query is None:
            updated, self.episodes, self.updates = self.model.observe(
                terminal=terminal,
                reward=reward,
                parallel=[parallel],
                **kwargs)
            return updated

        else:
            updated, self.episodes, self.updates, queried = self.model.observe(
                terminal=terminal,
                reward=reward,
                parallel=[parallel],
                query=query,
                **kwargs)
            return updated, queried
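
The 'max-traces' handling above is a retention rule: before a new trace file is written, the oldest trace-* files are deleted so that at most max_traces files remain afterwards. A hedged standalone sketch of that rule (function name and arguments are illustrative, not part of the module):

import os

def prune_traces(directory, max_traces):
    # Trace filenames embed episode count and timestamp, so name order is chronological
    files = sorted(
        f for f in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, f)) and f.startswith('trace-')
    )
    # Delete the oldest files so that max_traces - 1 remain before the new trace is saved
    for filename in files[:max(0, len(files) - (max_traces - 1))]:
        os.remove(os.path.join(directory, filename))
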
Example No. 25
0
    def act(
        self, states, parallel=0, deterministic=False, independent=False, evaluation=False,
        query=None, **kwargs
    ):
        """
        Returns action(s) for the given state(s); needs to be followed by `observe(...)` unless
        `independent` is true.

        Args:
            states (dict[state]): Dictionary containing state(s) to be acted on
                (<span style="color:#C00000"><b>required</b></span>).
            parallel (int): Parallel execution index
                (<span style="color:#00C000"><b>default</b></span>: 0).
            deterministic (bool): Whether to act deterministically, without exploration and sampling
                (<span style="color:#00C000"><b>default</b></span>: false).
            independent (bool): Whether action is not remembered, and this call is thus not
                followed by observe
                (<span style="color:#00C000"><b>default</b></span>: false).
            evaluation (bool): Whether the agent is currently evaluated, implies and overwrites
                deterministic and independent
                (<span style="color:#00C000"><b>default</b></span>: false).
            query (list[str]): Names of tensors to retrieve
                (<span style="color:#00C000"><b>default</b></span>: none).
            kwargs: Additional input values, for instance, for dynamic hyperparameters.

        Returns:
            (dict[action], plus optional list[str]): Dictionary containing action(s), plus queried
            tensor values if requested.
        """
        assert util.reduce_all(predicate=util.not_nan_inf, xs=states)

        # self.current_internals = self.next_internals
        if evaluation:
            if deterministic or independent:
                raise TensorforceError.unexpected()
            deterministic = independent = True

        # Auxiliaries
        auxiliaries = OrderedDict()
        if isinstance(states, dict):
            states = dict(states)
            for name, spec in self.actions_spec.items():
                if spec['type'] == 'int' and name + '_mask' in states:
                    auxiliaries[name + '_mask'] = states.pop(name + '_mask')

        # Normalize states dictionary
        states = util.normalize_values(
            value_type='state', values=states, values_spec=self.states_spec
        )

        # Batch states
        states = util.fmap(function=(lambda x: np.asarray([x])), xs=states, depth=1)
        auxiliaries = util.fmap(function=(lambda x: np.asarray([x])), xs=auxiliaries, depth=1)

        # Model.act()
        if query is None:
            actions, self.timesteps = self.model.act(
                states=states, auxiliaries=auxiliaries, parallel=parallel,
                deterministic=deterministic, independent=independent, **kwargs
            )

        else:
            actions, self.timesteps, queried = self.model.act(
                states=states, auxiliaries=auxiliaries, parallel=parallel,
                deterministic=deterministic, independent=independent, query=query, **kwargs
            )

        if self.recorder_spec is not None and not independent and \
                self.episodes >= self.recorder_spec.get('start', 0):
            index = self.buffer_indices[parallel]
            for name in self.states_spec:
                self.states_buffers[name][parallel, index] = states[name][0]
            for name, spec in self.actions_spec.items():
                self.actions_buffers[name][parallel, index] = actions[name][0]
                if spec['type'] == 'int':
                    name = name + '_mask'
                    if name in auxiliaries:
                        self.states_buffers[name][parallel, index] = auxiliaries[name][0]
                    else:
                        shape = (1,) + spec['shape'] + (spec['num_values'],)
                        self.states_buffers[name][parallel, index] = np.full(
                            shape=shape, fill_value=True, dtype=util.np_dtype(dtype='bool')
                        )

        # Unbatch actions
        actions = util.fmap(function=(lambda x: x[0]), xs=actions, depth=1)

        # Reverse normalized actions dictionary
        actions = util.unpack_values(
            value_type='action', values=actions, values_spec=self.actions_spec
        )

        # if independent, return processed state as well?

        if query is None:
            return actions
        else:
            return actions, queried
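
The auxiliaries handling above relies on a naming convention: for an 'int' action named name, an optional boolean entry name + '_mask' in the states dict marks which of its num_values choices are currently allowed. A hedged sketch of building such an input (state and action names are illustrative; the final act call assumes a configured agent and is therefore only shown as a comment):

import numpy as np

states = {
    'state': np.array([0.3, -1.2, 0.7], dtype=np.float32),
    'action_mask': np.array([True, False, True, True]),  # value 1 is currently disallowed
}
# The mask entry is popped into auxiliaries before state normalization, e.g.:
# actions = agent.act(states=states)
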
Example No. 26
0
    def add_summary(
        self, label, name, tensor, pass_tensors=None, return_summaries=False, mean_variance=False,
        enumerate_last_rank=False
    ):
        # should be "labels" !!!
        # label
        if util.is_iterable(x=label):
            if not all(isinstance(x, str) for x in label):
                raise TensorforceError.value(
                    name='Module.add_summary', argument='label', value=label
                )
        else:
            if not isinstance(label, str):
                raise TensorforceError.type(
                    name='Module.add_summary', argument='label', dtype=type(label)
                )
        # name
        if not isinstance(name, str):
            raise TensorforceError.type(
                name='Module.add_summary', argument='name', dtype=type(name)
            )
        # tensor
        if not isinstance(tensor, (tf.Tensor, tf.Variable)):
            raise TensorforceError.type(
                name='Module.add_summary', argument='tensor', dtype=type(tensor)
            )
        # pass_tensors
        if util.is_iterable(x=pass_tensors):
            if not all(isinstance(x, (tf.Tensor, tf.IndexedSlices)) for x in pass_tensors):
                raise TensorforceError.value(
                    name='Module.add_summary', argument='pass_tensors', value=pass_tensors
                )
        elif pass_tensors is not None:
            if not isinstance(pass_tensors, tf.Tensor):
                raise TensorforceError.type(
                    name='Module.add_summary', argument='pass_tensors', dtype=type(pass_tensors)
                )
        # enumerate_last_rank
        if not isinstance(enumerate_last_rank, bool):
            raise TensorforceError.type(
                name='Module.add_summary', argument='enumerate_last_rank', dtype=type(tensor)
            )

        if pass_tensors is None:
            pass_tensors = tensor

        # Check whether summary is logged
        if not self.is_summary_logged(label=label):
            return pass_tensors

        # Add to available summaries
        if util.is_iterable(x=label):
            self.available_summaries.update(label)
        else:
            self.available_summaries.add(label)

        # Handle enumerate_last_rank
        if enumerate_last_rank:
            dims = util.shape(x=tensor)[-1]
            tensors = OrderedDict([(name + str(n), tensor[..., n]) for n in range(dims)])
        else:
            tensors = OrderedDict([(name, tensor)])

        if mean_variance:
            for name in list(tensors):
                tensor = tensors.pop(name)
                mean, variance = tf.nn.moments(x=tensor, axes=tuple(range(util.rank(x=tensor))))
                tensors[name + '-mean'] = mean
                tensors[name + '-variance'] = variance

        # Scope handling
        if Module.scope_stack is not None:
            for scope in reversed(Module.scope_stack[1:]):
                scope.__exit__(None, None, None)
            if len(Module.global_scope) > 0:
                temp_scope = tf.name_scope(name='/'.join(Module.global_scope))
                temp_scope.__enter__()
            tensors = util.fmap(function=util.identity_operation, xs=tensors)

        # TensorFlow summaries
        assert Module.global_summary_step is not None
        step = Module.retrieve_tensor(name=Module.global_summary_step)
        summaries = list()
        for name, tensor in tensors.items():
            shape = util.shape(x=tensor)
            if shape == ():
                summaries.append(tf.summary.scalar(name=name, data=tensor, step=step))
            elif shape == (-1,):
                tensor = tf.math.reduce_sum(input_tensor=tensor, axis=0)
                summaries.append(tf.summary.scalar(name=name, data=tensor, step=step))
            elif shape == (1,):
                tensor = tf.squeeze(input=tensor, axis=-1)
                summaries.append(tf.summary.scalar(name=name, data=tensor, step=step))
            elif shape == (-1, 1):
                tensor = tf.math.reduce_sum(input_tensor=tf.squeeze(input=tensor, axis=-1), axis=0)
                summaries.append(tf.summary.scalar(name=name, data=tensor, step=step))
            else:
                # General tensor as histogram
                assert not util.is_iterable(x=label) and label.endswith('-histogram')
                summaries.append(tf.summary.histogram(name=name, data=tensor, step=step))

        # Scope handling
        if Module.scope_stack is not None:
            if len(Module.global_scope) > 0:
                temp_scope.__exit__(None, None, None)
            for scope in Module.scope_stack[1:]:
                scope.__enter__()

        with tf.control_dependencies(control_inputs=summaries):
            return util.fmap(function=util.identity_operation, xs=pass_tensors)
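
Depending on the tensor shape, the summaries above are emitted either as scalars (reduced or squeezed first where necessary) or as histograms. A hedged sketch of the underlying TF2 summary calls outside the module machinery (writer path and values are illustrative):

import tensorflow as tf

writer = tf.summary.create_file_writer('summaries')
with writer.as_default():
    tf.summary.scalar(name='loss', data=0.42, step=100)
    tf.summary.histogram(name='weights', data=tf.random.normal(shape=(128,)), step=100)
writer.flush()
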
Example No. 27
0
    def tf_step(self, variables, arguments, fn_loss, **kwargs):
        learning_rate = self.learning_rate.value()
        unperturbed_loss = fn_loss(**arguments)

        deltas = [tf.zeros_like(input=variable) for variable in variables]
        previous_perturbations = [tf.zeros_like(input=variable) for variable in variables]

        if self.unroll_loop:
            # Unrolled for loop
            for sample in range(self.num_samples):
                with tf.control_dependencies(control_inputs=deltas):
                    perturbations = [
                        tf.random.normal(shape=util.shape(variable)) * learning_rate
                        for variable in variables
                    ]
                    perturbation_deltas = [
                        pert - prev_pert
                        for pert, prev_pert in zip(perturbations, previous_perturbations)
                    ]
                    applied = self.apply_step(variables=variables, deltas=perturbation_deltas)
                    previous_perturbations = perturbations

                with tf.control_dependencies(control_inputs=(applied,)):
                    perturbed_loss = fn_loss(**arguments)
                    direction = tf.sign(x=(unperturbed_loss - perturbed_loss))
                    deltas = [
                        delta + direction * perturbation
                        for delta, perturbation in zip(deltas, perturbations)
                    ]

        else:
            # TensorFlow while loop
            def body(deltas, previous_perturbations):
                with tf.control_dependencies(control_inputs=deltas):
                    perturbations = [
                        learning_rate * tf.random.normal(
                            shape=util.shape(x=variable), dtype=util.tf_dtype(dtype='float')
                        ) for variable in variables
                    ]
                    perturbation_deltas = [
                        pert - prev_pert
                        for pert, prev_pert in zip(perturbations, previous_perturbations)
                    ]
                    applied = self.apply_step(variables=variables, deltas=perturbation_deltas)

                with tf.control_dependencies(control_inputs=(applied,)):
                    perturbed_loss = fn_loss(**arguments)
                    direction = tf.sign(x=(unperturbed_loss - perturbed_loss))
                    deltas = [
                        delta + direction * perturbation
                        for delta, perturbation in zip(deltas, perturbations)
                    ]

                return deltas, perturbations

            num_samples = self.num_samples.value()
            deltas, perturbations = self.while_loop(
                cond=util.tf_always_true, body=body, loop_vars=(deltas, previous_perturbations),
                back_prop=False, maximum_iterations=num_samples
            )

        with tf.control_dependencies(control_inputs=deltas):
            num_samples = tf.dtypes.cast(x=num_samples, dtype=util.tf_dtype(dtype='float'))
            deltas = [delta / num_samples for delta in deltas]
            perturbation_deltas = [delta - pert for delta, pert in zip(deltas, perturbations)]
            applied = self.apply_step(variables=variables, deltas=perturbation_deltas)

        with tf.control_dependencies(control_inputs=(applied,)):
            # Trivial operation to enforce control dependency
            return util.fmap(function=util.identity_operation, xs=deltas)
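
The optimizer step above estimates an update direction by sampling random perturbations of the variables, checking whether each one decreases the loss, and averaging the signed perturbations. A hedged NumPy sketch of one such step for a single flat parameter vector (all names and hyper-parameters are illustrative, not the library's API):

import numpy as np

def evolutionary_step(theta, fn_loss, learning_rate=1e-2, num_samples=8):
    unperturbed_loss = fn_loss(theta)
    delta = np.zeros_like(theta)
    for _ in range(num_samples):
        perturbation = learning_rate * np.random.normal(size=theta.shape)
        perturbed_loss = fn_loss(theta + perturbation)
        # Accumulate perturbations signed by whether they decreased the loss
        delta += np.sign(unperturbed_loss - perturbed_loss) * perturbation
    # Average over the samples and apply, mirroring the final apply_step above
    return theta + delta / num_samples

theta = evolutionary_step(theta=np.array([1.0, -2.0]), fn_loss=lambda t: np.sum(t ** 2))
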
Example No. 28
0
    def undo_deltas():
        # Apply the negated deltas to revert the previously applied update
        value = self.fn_x([-delta for delta in deltas])
        # Enforce the control dependency, then return the final values unchanged
        with tf.control_dependencies(control_inputs=(value,)):
            return util.fmap(function=util.identity_operation, xs=x_final)