Example #1
    def test_enqueue_dequeue(self):
        """
        Simply tests the enqueue (insert) and dequeue (get) ops without checking internal logic.
        """
        fifo_queue = FIFOQueue(capacity=self.capacity,
                               record_space=self.record_space)
        test = ComponentTest(component=fifo_queue,
                             input_spaces=self.input_spaces)

        first_record = self.record_space.sample(size=1)
        test.test(("insert_records", first_record), expected_outputs=None)
        test.test("get_size", expected_outputs=1)

        further_records = self.record_space.sample(size=5)
        test.test(("insert_records", further_records), expected_outputs=None)
        test.test("get_size", expected_outputs=6)

        expected = dict()
        for (k1, v1), (k2, v2) in zip(
                flatten_op(first_record).items(),
                flatten_op(further_records).items()):
            expected[k1] = np.concatenate((v1, v2[:4]))
        expected = unflatten_op(expected)

        test.test(("get_records", 5), expected_outputs=expected)
        test.test("get_size", expected_outputs=1)
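The expected output above is built by flattening both record batches, concatenating them per flat key, and re-nesting the result. As a minimal pure-Python sketch of that flatten/unflatten round trip on plain dicts, assuming "/"-separated flat keys (the helpers below are illustrative only, not the rlgraph implementation):

def flatten(nested, prefix=""):
    """Flatten a nested dict into {"/outer/inner": leaf} pairs."""
    flat = {}
    for key, value in nested.items():
        path = prefix + "/" + key
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=path))
        else:
            flat[path] = value
    return flat


def unflatten(flat):
    """Re-nest a flat dict produced by `flatten`."""
    nested = {}
    for path, value in flat.items():
        node = nested
        keys = path.strip("/").split("/")
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value
    return nested


record = {"states": {"s1": [1, 2], "s2": [3, 4]}, "actions": [0, 1]}
assert flatten(record) == {"/states/s1": [1, 2], "/states/s2": [3, 4], "/actions": [0, 1]}
assert unflatten(flatten(record)) == record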
Example #2
    def _graph_fn_split_batch(self, *inputs):
        """
        Splits all DataOps in *inputs along their batch dimension into n equally sized shards. The number of shards
        is determined by `self.num_shards` (int) and the size of each shard depends on the incoming batch size with
        possibly a few superfluous items in the batch being discarded
        (effective batch size = num_shards x shard_size).

        Args:
            *inputs (FlattenedDataOp): Input tensors which must all have the same batch dimension.

        Returns:
            tuple: One DataOpTuple per shard, each with len = number of input args.
                Each item in a DataOpTuple is a FlattenedDataOp whose (flat) keys describe the
                input piece (e.g. "/states1") and whose values are the (now sharded) batch data
                for that input piece.

                E.g. for 2 shards the return value is:
                (DataOpTuple(input1_flatdict, input2_flatdict, input3_flatdict, input4_flatdict),
                 DataOpTuple([same]))
        """
        if get_backend() == "tf":
            #batch_size = tf.shape(next(iter(inputs[0].values())))[0]
            #shard_size = tf.cast(batch_size / self.num_shards, dtype=tf.int32)

            # Must be evenly divisible so we slice out an evenly divisible tensor.
            # E.g. 203 items in batch with 4 shards -> Only 4 x 50 = 200 are usable.
            usable_size = self.shard_size * self.num_shards

            # List (one item for each input arg). Each item in the list looks like:
            # A FlattenedDataOp with (flat) keys (describing the input-piece (e.g. "/states1")) and values being
            # lists of len n for the n shards' data.
            inputs_flattened_and_split = list()

            for input_arg_data in inputs:
                shard_dict = FlattenedDataOp()
                for flat_key, data in input_arg_data.items():
                    usable_input_tensor = data[:usable_size]
                    shard_dict[flat_key] = tf.split(
                        value=usable_input_tensor,
                        num_or_size_splits=self.num_shards)
                inputs_flattened_and_split.append(shard_dict)

            # Flip the list to generate a new list where each item represents one shard.
            shard_list = list()
            for shard_idx in range(self.num_shards):
                # To be converted into FlattenedDataOps over the input-arg-pieces once complete.
                input_arg_list = list()
                for input_elem in range(len(inputs)):
                    sharded_data_dict = FlattenedDataOp()
                    for flat_key, shards in inputs_flattened_and_split[
                            input_elem].items():
                        sharded_data_dict[flat_key] = shards[shard_idx]
                    input_arg_list.append(unflatten_op(sharded_data_dict))
                # Must store everything as FlattenedDataOp otherwise the re-nesting will not work.
                shard_list.append(DataOpTuple(input_arg_list))

            # Return n values (n = number of batch shards).
            return tuple(shard_list)
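The second loop above is a transposition: the first loop yields, per input arg, a FlattenedDataOp mapping each flat key to a list of n shards, and the second loop flips this into one tuple entry per shard. A pure-Python sketch of the same idea, with plain dicts, list slicing and tuples standing in for FlattenedDataOp, tf.split and DataOpTuple (names below are illustrative only):

def split_batch(num_shards, *inputs):
    """Split each flat-dict input along the batch axis into `num_shards` equal shards."""
    # Per input arg: {flat_key: [shard_0, shard_1, ...]}.
    per_input = []
    for flat_dict in inputs:
        split = {}
        for flat_key, data in flat_dict.items():
            usable = len(data) - len(data) % num_shards  # drop superfluous items
            shard_size = usable // num_shards
            split[flat_key] = [data[i * shard_size:(i + 1) * shard_size]
                               for i in range(num_shards)]
        per_input.append(split)

    # Flip into one tuple entry per shard, each holding one dict per input arg.
    return tuple(
        tuple({flat_key: shards[shard_idx] for flat_key, shards in split.items()}
              for split in per_input)
        for shard_idx in range(num_shards)
    )


states = {"/states1": [1, 2, 3, 4, 5]}
actions = {"/actions": [10, 20, 30, 40, 50]}
print(split_batch(2, states, actions))
# (({'/states1': [1, 2]}, {'/actions': [10, 20]}),
#  ({'/states1': [3, 4]}, {'/actions': [30, 40]}))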
Example #3
    def unflatten_output_ops(*ops):
        """
        Re-creates the originally nested input structure (as DataOpDict/DataOpTuple) of the given op-record column.
        Processes all FlattenedDataOps with auto-generated keys and leaves the others untouched.

        Args:
            ops (DataOp): The ops that need to be unflattened (only process the FlattenedDataOp
                amongst these and ignore all others).

        Returns:
            Tuple[DataOp]: A tuple containing the ops as they came in, except that all FlattenedDataOp
                have been un-flattened (re-nested) into their original structures.
        """
        # The returned sequence of output ops.
        ret = []

        for op in ops:
            # A FlattenedDataOp: Try to re-nest it.
            if isinstance(op, FlattenedDataOp):
                ret.append(unflatten_op(op))
            # All others are left as-is.
            else:
                ret.append(op)

        # Always return a tuple for indexing into the return values.
        return tuple(ret)
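A self-contained sketch of the same pass-through behavior on plain dicts, where "/"-prefixed string keys stand in for FlattenedDataOp's auto-generated keys (hypothetical helper, not the rlgraph API):

def unflatten_outputs(*ops):
    """Re-nest dict outputs that use "/"-style flat keys; leave all other ops as-is."""
    def unflatten(flat):
        nested = {}
        for path, value in flat.items():
            node = nested
            keys = path.strip("/").split("/")
            for key in keys[:-1]:
                node = node.setdefault(key, {})
            node[keys[-1]] = value
        return nested

    def looks_flat(op):
        return isinstance(op, dict) and all(isinstance(k, str) and k.startswith("/") for k in op)

    # Always return a tuple so callers can index into the results.
    return tuple(unflatten(op) if looks_flat(op) else op for op in ops)


print(unflatten_outputs({"/a/b": 1, "/a/c": 2}, 3.14, "untouched"))
# -> ({'a': {'b': 1, 'c': 2}}, 3.14, 'untouched')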
Example #4
    def test_capacity(self):
        """
        Tests if insert correctly blocks when capacity is reached.
        """
        fifo_queue = FIFOQueue(capacity=self.capacity,
                               record_space=self.record_space)
        test = ComponentTest(component=fifo_queue,
                             input_spaces=self.input_spaces)

        def run(expected_):
            # Wait 2 seconds so the main thread has time to block on the over-capacity insert.
            time.sleep(2)
            # Pull something out of the queue again to continue.
            test.test(("get_records", 2), expected_outputs=expected_)

        # Insert one more element than capacity
        records = self.record_space.sample(size=self.capacity + 1)

        expected = dict()
        for key, value in flatten_op(records).items():
            expected[key] = value[:2]
        expected = unflatten_op(expected)

        # Start thread to save this one from getting stuck due to capacity overflow.
        thread = threading.Thread(target=run, args=(expected, ))
        thread.start()

        print("Going over capacity: blocking ...")
        test.test(("insert_records", records), expected_outputs=None)
        print("Dequeued some items in another thread. Unblocked.")

        thread.join()
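The same producer/consumer blocking pattern can be reproduced with Python's standard queue.Queue, independent of rlgraph, which may make the test's threading setup easier to follow (illustrative only):

import queue
import threading
import time

q = queue.Queue(maxsize=4)


def consumer():
    # Wait so the producer hits the capacity limit first, then free two slots.
    time.sleep(2)
    print("dequeued:", q.get(), q.get())


thread = threading.Thread(target=consumer)
thread.start()

for i in range(5):
    # The 5th put() blocks until the consumer takes items out of the queue.
    q.put(i)
print("all items enqueued")
thread.join()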
Example #5
 def get_preprocessed_space(self, space):
     # Translate to corresponding FloatBoxes.
     ret = dict()
     for key, value in space.flatten().items():
         ret[key] = FloatBox(shape=value.shape, add_batch_rank=value.has_batch_rank,
                             add_time_rank=value.has_time_rank)
     return unflatten_op(ret)
Example #6
 def get_preprocessed_space(self, space):
     ret = {}
     for key, value in space.flatten().items():
         shape = list(value.shape)
         if self.keep_rank is True:
             shape[-1] = 1
         else:
             shape.pop(-1)
         ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
     return unflatten_op(ret)
Example #7
File: policy.py  Project: theSoenke/rlgraph
    def _graph_fn_get_action_components(self, logits, parameters,
                                        deterministic):
        ret = {}

        # TODO Clean up the checks in here wrt define-by-run processing.
        for flat_key, action_space_component in self.action_space.flatten(
        ).items():
            # Skip our distribution, iff discrete action-space and deterministic acting (greedy).
            # In that case, one does not need to create a distribution in the graph each act (only to get the argmax
            # over the logits, which is the same as the argmax over the probabilities (or log-probabilities)).
            if isinstance(action_space_component, IntBox) and \
                    (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
                if flat_key == "":
                    return self._graph_fn_get_deterministic_action_wo_distribution(
                        logits)
                else:
                    ret[flat_key] = self._graph_fn_get_deterministic_action_wo_distribution(
                        logits.flat_key_lookup(flat_key))
            elif isinstance(action_space_component, BoolBox) and \
                    (deterministic is True or (isinstance(deterministic, np.ndarray) and deterministic)):
                if get_backend() == "tf":
                    if flat_key == "":
                        return tf.greater(logits, 0.5)
                    else:
                        ret[flat_key] = tf.greater(
                            logits.flat_key_lookup(flat_key), 0.5)
                elif get_backend() == "pytorch":
                    if flat_key == "":
                        return torch.gt(logits, 0.5)
                    else:
                        ret[flat_key] = torch.gt(
                            logits.flat_key_lookup(flat_key), 0.5)
            else:
                if flat_key == "":
                    # Still wrapped as FlattenedDataOp.
                    if isinstance(parameters, FlattenedDataOp):
                        return self.distributions[flat_key].draw(
                            parameters[flat_key], deterministic)
                    else:
                        return self.distributions[flat_key].draw(
                            parameters, deterministic)

                if isinstance(parameters, ContainerDataOp) and not \
                        (isinstance(parameters, DataOpDict) and flat_key in parameters):
                    ret[flat_key] = self.distributions[flat_key].draw(
                        parameters.flat_key_lookup(flat_key), deterministic)
                else:
                    ret[flat_key] = self.distributions[flat_key].draw(
                        parameters[flat_key], deterministic)

        if get_backend() == "tf":
            return unflatten_op(ret)
        elif get_backend() == "pytorch":
            return define_by_run_unflatten(ret)
Example #8
 def get_preprocessed_space(self, space):
     ret = dict()
     for key, single_space in space.flatten().items():
         class_ = type(single_space)
         # We flip batch and time ranks.
         time_major = not single_space.time_major
         ret[key] = class_(shape=single_space.shape,
                           add_batch_rank=single_space.has_batch_rank,
                           add_time_rank=single_space.has_time_rank, time_major=time_major)
         self.output_time_majors[key] = time_major
     ret = unflatten_op(ret)
     return ret
Example #9
File: reshape.py  Project: ml-lab/rlgraph
    def get_preprocessed_space(self, space):
        ret = {}
        for key, single_space in space.flatten().items():
            class_ = type(single_space)

            # Determine the actual shape (not batch/time ranks).
            if self.flatten is True:
                if type(single_space
                        ) == IntBox and self.flatten_categories is not False:
                    assert self.flatten_categories is not None,\
                        "ERROR: `flatten_categories` must not be None if `flatten` is True and input is IntBox!"
                    new_shape = (self.get_num_categories(key, single_space), )
                    class_ = FloatBox
                else:
                    new_shape = (single_space.flat_dim, )
            else:
                new_shape = self.new_shape[key] if isinstance(
                    self.new_shape, dict) else self.new_shape

            # Check the batch/time rank options.
            if self.fold_time_rank is True:
                sanity_check_space(single_space,
                                   must_have_batch_rank=True,
                                   must_have_time_rank=True)
                ret[key] = class_(shape=single_space.shape
                                  if new_shape is None else new_shape,
                                  add_batch_rank=True,
                                  add_time_rank=False)
            # Time rank should be unfolded from batch rank with the given dimension.
            elif self.unfold_time_rank is True:
                sanity_check_space(single_space,
                                   must_have_batch_rank=True,
                                   must_have_time_rank=False)
                ret[key] = class_(shape=single_space.shape
                                  if new_shape is None else new_shape,
                                  add_batch_rank=True,
                                  add_time_rank=True,
                                  time_major=self.time_major
                                  if self.time_major is not None else False)
            # Only change the actual shape (leave batch/time ranks as is).
            else:
                time_major = single_space.time_major
                ret[key] = class_(shape=single_space.shape
                                  if new_shape is None else new_shape,
                                  add_batch_rank=single_space.has_batch_rank,
                                  add_time_rank=single_space.has_time_rank,
                                  time_major=time_major)
        ret = unflatten_op(ret)
        return ret
Example #10
 def get_preprocessed_space(self, space):
     ret = dict()
     for key, value in space.flatten().items():
         # Do some sanity checking.
         rank = value.rank
         assert rank == 2 or rank == 3, \
             "ERROR: Given image's rank (which is {}{}, not counting batch rank) must be either 2 or 3!".\
             format(rank, ("" if key == "" else " for key '{}'".format(key)))
         # Determine the output shape.
         shape = list(value.shape)
         shape[0] = self.width
         shape[1] = self.height
         ret[key] = value.__class__(shape=tuple(shape),
                                    add_batch_rank=value.has_batch_rank)
     return unflatten_op(ret)
Example #11
    def get_preprocessed_space(self, space):
        ret = dict()
        for key, single_space in space.flatten().items():
            class_ = type(single_space)

            # Determine the actual shape (not batch/time ranks).
            if self.flatten is True:
                if self.flatten_categories is not False and type(single_space) == IntBox:
                    if self.flatten_categories is True:
                        num_categories = single_space.flat_dim_with_categories
                    else:
                        num_categories = self.flatten_categories
                    new_shape = (num_categories,)
                    # IntBox categories get flattened into float values -> output space becomes a FloatBox.
                    class_ = FloatBox
                else:
                    new_shape = (single_space.flat_dim,)
            else:
                new_shape = self.new_shape[key] if isinstance(self.new_shape, dict) else self.new_shape

            # Check the batch/time rank options.
            if self.fold_time_rank is True:
                sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=True)
                ret[key] = class_(
                    shape=single_space.shape if new_shape is None else new_shape,
                    add_batch_rank=True, add_time_rank=False
                )
            # Time rank should be unfolded from batch rank with the given dimension.
            elif self.unfold_time_rank is True:
                sanity_check_space(single_space, must_have_batch_rank=True, must_have_time_rank=False)
                ret[key] = class_(
                    shape=single_space.shape if new_shape is None else new_shape,
                    add_batch_rank=True, add_time_rank=True,
                    time_major=self.time_major if self.time_major is not None else False
                )
            # Only change the actual shape (leave batch/time ranks as is).
            else:
                # Do we flip batch and time ranks?
                time_major = single_space.time_major if self.flip_batch_and_time_rank is False else \
                    not single_space.time_major

                ret[key] = class_(shape=single_space.shape if new_shape is None else new_shape,
                                  add_batch_rank=single_space.has_batch_rank,
                                  add_time_rank=single_space.has_time_rank, time_major=time_major)
        ret = unflatten_op(ret)
        return ret
Example #12
    def get_preprocessed_space(self, space):
        ret = dict()
        for key, value in space.flatten().items():
            shape = list(value.shape)
            if self.add_rank:
                shape.append(self.sequence_length)
            else:
                shape[-1] *= self.sequence_length

            # TODO move to transpose component.
            # Transpose (reverse the shape) when converting channels_last -> channels_first.
            if self.in_data_format == "channels_last" and self.out_data_format == "channels_first":
                shape.reverse()
            ret[key] = value.__class__(shape=tuple(shape), add_batch_rank=value.has_batch_rank)
        return unflatten_op(ret)
Example #13
    def get_preprocessed_space(self, space):
        """
        Returns the Space obtained after pushing the input through all layers of this Stack.

        Args:
            space (Dict): The incoming Space object.

        Returns:
            Space: The Space after preprocessing.
        """
        assert isinstance(space, ContainerSpace)
        dict_spec = dict()
        for flat_key, sub_space in space.flatten().items():
            if flat_key in self.flattened_preprocessors:
                dict_spec[flat_key] = self.flattened_preprocessors[flat_key].get_preprocessed_space(sub_space)
            else:
                dict_spec[flat_key] = sub_space
        dict_spec = unflatten_op(dict_spec)
        return Dict(dict_spec)
Example #14
    def _graph_fn_unstage(self):
        """
        Unstages (and unflattens) all staged data.

        Returns:
            Tuple[DataOp]: All previously staged ops.
        """
        unstaged_data = self.area.get()
        unflattened_data = list()
        idx = 0
        # Unflatten all data and return.
        for flat_key_list in self.flat_keys:
            flat_dict = FlattenedDataOp({
                flat_key: item
                for flat_key, item in zip(
                    flat_key_list, unstaged_data[idx:idx + len(flat_key_list)])
            })
            unflattened_data.append(unflatten_op(flat_dict))
            idx += len(flat_key_list)

        return tuple(unflattened_data)
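The bookkeeping above simply re-associates the flat list of unstaged tensors with the stored per-op flat-key lists, group by group. A pure-Python sketch of that step (illustrative helper, plain dicts instead of FlattenedDataOp):

def regroup(flat_keys_per_op, flat_values):
    """Re-associate a flat value list with per-op flat-key lists, in order."""
    out, idx = [], 0
    for key_list in flat_keys_per_op:
        out.append(dict(zip(key_list, flat_values[idx:idx + len(key_list)])))
        idx += len(key_list)
    return tuple(out)


print(regroup([["/s1", "/s2"], ["/a"]], ["s1_tensor", "s2_tensor", "a_tensor"]))
# -> ({'/s1': 's1_tensor', '/s2': 's2_tensor'}, {'/a': 'a_tensor'})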
Example #15
    def get_preprocessed_space(self, space):
        ## Test sending np samples to get number of return values and output spaces without having to call
        ## the tf graph_fn.
        #backend = self.backend
        #self.backend = "python"
        #sample = space.sample(size=1)
        #out = self._graph_fn_call(sample)
        #new_space = get_space_from_op(out)
        #self.backend = backend
        #return new_space

        ret = dict()
        for key, value in space.flatten().items():
            # Do some sanity checking.
            rank = value.rank
            if get_backend() == "tf":
                assert rank == 2 or rank == 3, \
                    "ERROR: Given image's rank (which is {}{}, not counting batch rank) must be either 2 or 3!".\
                    format(rank, ("" if key == "" else " for key '{}'".format(key)))
                # Determine the output shape.
                shape = list(value.shape)
                shape[0] = self.width
                shape[1] = self.height
            elif get_backend() == "pytorch":
                shape = list(value.shape)

                # Determine the output shape.
                if rank == 3:
                    shape[0] = self.width
                    shape[1] = self.height
                elif rank == 4:
                    # TODO PyTorch shape inference issue.
                    shape[1] = self.width
                    shape[2] = self.height
            ret[key] = value.__class__(shape=tuple(shape),
                                       add_batch_rank=value.has_batch_rank)
        return unflatten_op(ret)
Example #16
            def scan_func(accum, time_delta):
                # Not needed: preprocessed-previous-states (tuple!)
                # `state` is a tuple as well. See comment in ctor for why tf cannot use ContainerSpaces here.
                internal_states = None
                state = accum[1]
                if self.has_rnn:
                    internal_states = accum[-1]

                state = tuple(tf.convert_to_tensor(value=s) for s in state)

                flat_state = OrderedDict()
                for i, flat_key in enumerate(
                        self.state_space_actor_flattened.keys()):
                    # Add a simple (size 1) batch rank to the state so it'll pass through the NN.
                    # - Also have to add a time-rank for RNN processing.
                    expanded = state[i]
                    for _ in range(1 if self.has_rnn is False else 2):
                        expanded = tf.expand_dims(input=expanded, axis=0)
                    # Make None so it'll be recognized as batch-rank by the auto-Space detector.
                    flat_state[flat_key] = tf.placeholder_with_default(
                        input=expanded,
                        shape=(None, ) + ((None, ) if self.has_rnn is True else
                                          ()) +
                        self.state_space_actor_list[i].shape)

                # Recreate state as the original Space to pass it into the actor-component.
                state = unflatten_op(flat_state)

                # Get action and preprocessed state (as batch-size 1).
                out = (self.actor_component.get_preprocessed_state_and_action
                       if self.add_action_probs is False else
                       self.actor_component.
                       get_preprocessed_state_action_and_action_probs)(
                           state,
                           # Add simple batch rank to internal_states.
                           None if internal_states is None else DataOpTuple(
                               internal_states),  # <- None for non-RNN systems
                           time_step=self.time_step + time_delta,
                           return_ops=True)

                # Get output depending on whether it contains internal_states or not.
                a = out["action"]
                action_probs = out.get("action_probs")
                current_internal_states = out.get("last_internal_states")

                # Strip the batch (and maybe time) ranks again from the action in case the Env doesn't like it.
                a_no_extra_ranks = a[0, 0] if self.has_rnn is True else a[0]
                # Step through the Env and collect next state (tuple!), reward and terminal as single values
                # (not batched).
                out = self.environment_server.step_for_env_stepper(
                    a_no_extra_ranks)
                s_, r, t_ = out[:-2], out[-2], out[-1]
                r = tf.cast(r, dtype="float32")

                # Add a and/or r to next_state?
                if self.add_previous_action_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous action to non tuple!"
                    s_ = s_ + (a_no_extra_ranks, )
                if self.add_previous_reward_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous reward to non tuple!"
                    s_ = s_ + (r, )

                # Note: s_ is packed as tuple.
                ret = [t_, s_] + \
                    ([a_no_extra_ranks] if self.add_action else []) + \
                    ([r] if self.add_reward else []) + \
                    ([(action_probs[0][0] if self.has_rnn is True else action_probs[0])] if
                     self.add_action_probs is True else []) + \
                    ([tuple(current_internal_states)] if self.has_rnn is True else [])

                return tuple(ret)
Example #17
    def _graph_fn_step(self):
        if get_backend() == "tf":

            def scan_func(accum, time_delta):
                # Not needed: preprocessed-previous-states (tuple!)
                # `state` is a tuple as well. See comment in ctor for why tf cannot use ContainerSpaces here.
                internal_states = None
                state = accum[1]
                if self.has_rnn:
                    internal_states = accum[-1]

                state = tuple(tf.convert_to_tensor(value=s) for s in state)

                flat_state = OrderedDict()
                for i, flat_key in enumerate(
                        self.state_space_actor_flattened.keys()):
                    # Add a simple (size 1) batch rank to the state so it'll pass through the NN.
                    # - Also have to add a time-rank for RNN processing.
                    expanded = state[i]
                    for _ in range(1 if self.has_rnn is False else 2):
                        expanded = tf.expand_dims(input=expanded, axis=0)
                    # Make None so it'll be recognized as batch-rank by the auto-Space detector.
                    flat_state[flat_key] = tf.placeholder_with_default(
                        input=expanded,
                        shape=(None, ) + ((None, ) if self.has_rnn is True else
                                          ()) +
                        self.state_space_actor_list[i].shape)

                # Recreate state as the original Space to pass it into the actor-component.
                state = unflatten_op(flat_state)

                # Get action and preprocessed state (as batch-size 1).
                out = (self.actor_component.get_preprocessed_state_and_action
                       if self.add_action_probs is False else
                       self.actor_component.
                       get_preprocessed_state_action_and_action_probs)(
                           state,
                           # Add simple batch rank to internal_states.
                           None if internal_states is None else DataOpTuple(
                               internal_states),  # <- None for non-RNN systems
                           time_step=self.time_step + time_delta,
                           return_ops=True)

                # Get output depending on whether it contains internal_states or not.
                a = out["action"]
                action_probs = out.get("action_probs")
                current_internal_states = out.get("last_internal_states")

                # Strip the batch (and maybe time) ranks again from the action in case the Env doesn't like it.
                a_no_extra_ranks = a[0, 0] if self.has_rnn is True else a[0]
                # Step through the Env and collect next state (tuple!), reward and terminal as single values
                # (not batched).
                out = self.environment_server.step_for_env_stepper(
                    a_no_extra_ranks)
                s_, r, t_ = out[:-2], out[-2], out[-1]
                r = tf.cast(r, dtype="float32")

                # Add a and/or r to next_state?
                if self.add_previous_action_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous action to non tuple!"
                    s_ = s_ + (a_no_extra_ranks, )
                if self.add_previous_reward_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous reward to non tuple!"
                    s_ = s_ + (r, )

                # Note: s_ is packed as tuple.
                ret = [t_, s_] + \
                    ([a_no_extra_ranks] if self.add_action else []) + \
                    ([r] if self.add_reward else []) + \
                    ([(action_probs[0][0] if self.has_rnn is True else action_probs[0])] if
                     self.add_action_probs is True else []) + \
                    ([tuple(current_internal_states)] if self.has_rnn is True else [])

                return tuple(ret)

            # Initialize the tf.scan run.
            initializer = [
                self.current_terminal.read_value(
                ),  # whether the current state is terminal
                # current (raw) state (flattened components if ContainerSpace).
                tuple(
                    map(lambda x: x.read_value(), self.current_state.values()))
            ]
            # Append actions and rewards if needed.
            if self.add_action:
                initializer.append(self.current_action.read_value())
            if self.add_reward:
                initializer.append(self.current_reward.read_value())
            # Append action probs if needed.
            if self.add_action_probs is True:
                initializer.append(self.current_action_probs.read_value())
            # Append internal states if needed.
            if self.current_internal_states is not None:
                initializer.append(
                    tuple(
                        tf.placeholder_with_default(
                            internal_s.read_value(),
                            shape=(None, ) +
                            tuple(internal_s.shape.as_list()[1:])) for
                        internal_s in self.current_internal_states.values()))

            # Scan over n time-steps (tf.range produces the time_delta with respect to the current time_step).
            # NOTE: Changed parallel to 1, to resolve parallel issues.
            step_results = list(
                tf.scan(fn=scan_func,
                        elems=tf.range(self.num_steps, dtype="int32"),
                        initializer=tuple(initializer),
                        back_prop=False))

            # Store the time-step increment, return so far, current terminal and current state.
            assigns = [
                tf.assign_add(self.time_step, self.num_steps),
                self.assign_variable(self.current_terminal,
                                     step_results[0][-1])
            ]

            # Concatenate first and rest.
            full_results = []
            for first_values, rest_values in zip(initializer, step_results):
                full_results.append(
                    nest.map_structure(
                        lambda first, rest: tf.concat([[first], rest], axis=0),
                        first_values, rest_values))

            # Re-build DataOpDicts from preprocessed-states and states (from tuple right now).
            rebuild_s = DataOpDict()
            for flat_key, var_ref, s_comp in zip(
                    self.state_space_actor_flattened.keys(),
                    self.current_state.values(), full_results[1]):
                assigns.append(self.assign_variable(
                    var_ref, s_comp[-1]))  # -1: current state (last observed)
                rebuild_s[flat_key] = s_comp
            rebuild_s = unflatten_op(rebuild_s)
            full_results[1] = rebuild_s

            # Remove batch rank from internal states again.
            if self.current_internal_states is not None:
                # TODO: What if internal states is not the last item in the list anymore due to some change.
                slot = -1  # if self.add_action_probs is True else 2
                # TODO: What if internal states is a dict? Right now assume some tuple.
                internal_states_wo_batch = list()
                for i in range(len(full_results[slot])):
                    # 1=batch axis (which is 1); 0=time axis.
                    internal_states_wo_batch.append(
                        tf.squeeze(full_results[-1][i], axis=1))
                full_results[slot] = DataOpTuple(internal_states_wo_batch)

            with tf.control_dependencies(control_inputs=assigns):
                # Let the auto-infer system know, what time rank we have.
                full_results = DataOpTuple(full_results)
                for o in flatten_op(full_results).values():
                    o._time_rank = 0  # which position in the shape is the time-rank?
                step_op = tf.no_op()

            return step_op, full_results
Example #18
 def map(self, mapping):
     flattened_self = self.flatten(mapping=mapping)
     return Tuple(
         unflatten_op(flattened_self),
         add_batch_rank=self.has_batch_rank, add_time_rank=self.has_time_rank, time_major=self.time_major
     )
Example #19
                def opt_body(index_, loss_, loss_per_item_, vf_loss_,
                             vf_loss_per_item_):
                    start = tf.random_uniform(shape=(),
                                              minval=0,
                                              maxval=batch_size - 1,
                                              dtype=tf.int32)
                    indices = tf.range(
                        start=start,
                        limit=start + agent.sample_size) % batch_size
                    sample_states = tf.gather(params=preprocessed_states,
                                              indices=indices)
                    if isinstance(actions, ContainerDataOp):
                        sample_actions = FlattenedDataOp()
                        for name, action in flatten_op(actions).items():
                            sample_actions[name] = tf.gather(params=action,
                                                             indices=indices)
                        sample_actions = unflatten_op(sample_actions)
                    else:
                        sample_actions = tf.gather(params=actions,
                                                   indices=indices)

                    sample_prior_log_probs = tf.gather(params=prev_log_probs,
                                                       indices=indices)
                    sample_rewards = tf.gather(params=rewards, indices=indices)
                    sample_terminals = tf.gather(params=terminals,
                                                 indices=indices)
                    sample_sequence_indices = tf.gather(
                        params=sequence_indices, indices=indices)
                    sample_advantages = tf.gather(params=advantages,
                                                  indices=indices)
                    sample_advantages.set_shape((self.sample_size, ))

                    sample_baseline_values = value_function.value_output(
                        sample_states)
                    sample_prior_baseline_values = tf.gather(
                        params=prior_baseline_values, indices=indices)

                    # If we are a multi-GPU root:
                    # Simply feeds everything into the multi-GPU sync optimizer's method and return.
                    if multi_gpu_sync_optimizer is not None:
                        main_policy_vars = agent.policy.variables()
                        main_vf_vars = agent.value_function.variables()
                        all_vars = agent.vars_merger.merge(
                            main_policy_vars, main_vf_vars)
                        # grads_and_vars, loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
                            all_vars,
                            sample_states,
                            sample_actions,
                            sample_rewards,
                            sample_terminals,
                            sample_sequence_indices,
                            apply_postprocessing=apply_postprocessing)
                        avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.call(
                            out["avg_grads_and_vars_by_component"])
                        policy_step_op = agent.optimizer.apply_gradients(
                            avg_grads_and_vars_policy)
                        vf_step_op = agent.value_function_optimizer.apply_gradients(
                            avg_grads_and_vars_vf)
                        step_op = root._graph_fn_group(policy_step_op,
                                                       vf_step_op)
                        step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(
                            step_op, all_vars)
                        loss_vf, loss_per_item_vf = out[
                            "additional_return_0"], out["additional_return_1"]

                        # Have to set all shapes here due to strict loop-var shape requirements.
                        out["loss"].set_shape(())
                        loss_vf.set_shape(())
                        loss_per_item_vf.set_shape((agent.sample_size, ))
                        out["loss_per_item"].set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_and_sync_op]):
                            if index_ == 0:
                                # Increase the global training step counter.
                                out["loss"] = root._graph_fn_training_step(
                                    out["loss"])
                            return index_ + 1, out["loss"], out[
                                "loss_per_item"], loss_vf, loss_per_item_vf

                    policy_probs = policy.get_log_likelihood(
                        sample_states, sample_actions)["log_likelihood"]
                    baseline_values = value_function.value_output(
                        tf.stop_gradient(sample_states))
                    sample_rewards = tf.cond(
                        pred=apply_postprocessing,
                        true_fn=lambda: gae_function.calc_gae_values(
                            baseline_values, sample_rewards, sample_terminals,
                            sample_sequence_indices),
                        false_fn=lambda: sample_rewards)
                    sample_rewards.set_shape((agent.sample_size, ))
                    entropy = policy.get_entropy(sample_states)["entropy"]

                    loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        loss_function.loss(
                            policy_probs, sample_prior_log_probs,
                            sample_baseline_values, sample_prior_baseline_values, sample_advantages, entropy
                        )

                    if hasattr(root, "is_multi_gpu_tower"
                               ) and root.is_multi_gpu_tower is True:
                        policy_grads_and_vars = optimizer.calculate_gradients(
                            policy.variables(), loss)
                        vf_grads_and_vars = value_function_optimizer.calculate_gradients(
                            value_function.variables(), vf_loss)
                        grads_and_vars_by_component = vars_merger.merge(
                            policy_grads_and_vars, vf_grads_and_vars)
                        return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
                    else:
                        step_op, loss, loss_per_item = optimizer.step(
                            policy.variables(), loss, loss_per_item)
                        loss.set_shape(())
                        loss_per_item.set_shape((agent.sample_size, ))

                        vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
                            value_function.variables(), vf_loss,
                            vf_loss_per_item)
                        vf_loss.set_shape(())
                        vf_loss_per_item.set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_op, vf_step_op]):
                            return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item