Example 1
def define_by_run_flatten(container, key_scope="", tensor_tuple_list=None, scope_separator_at_start=True):
    """
    Flattens a native Python dict/tuple into a flat dict with auto-key generation. Run-time equivalent
    of the build-time flatten operation.

    Args:
        container (Union[dict,tuple]): Container to flatten.
        key_scope (str): The recursive scope for auto-key generation.
        tensor_tuple_list (list): Internal accumulator of (key, value) tuples used to build the final
            result. Leave as None on the initial (non-recursive) call.
        scope_separator_at_start (bool): Whether to prepend a scope separator before the first key in a
            recursive structure. Default: True.

    Returns:
        Dict: Flattened container.
    """
    is_outer_call = False

    # Are we in the non-recursive (first) call?
    if tensor_tuple_list is None:
        tensor_tuple_list = []
        if not isinstance(container, (dict, tuple)):
            return DataOpDict([("", container)])
        is_outer_call = True

    if isinstance(container, dict):
        if scope_separator_at_start:
            key_scope += FLATTEN_SCOPE_PREFIX
        else:
            key_scope = ""
        for key in sorted(container.keys()):
            # Make sure we have no double slashes from flattening an already FlattenedDataOp.
            scope = (key_scope[:-1] if len(key) == 0 or key[0] == "/" else key_scope) + key
            define_by_run_flatten(container[key], key_scope=scope, tensor_tuple_list=tensor_tuple_list, scope_separator_at_start=True)
    elif isinstance(container, tuple):
        if scope_separator_at_start:
            key_scope += FLATTEN_SCOPE_PREFIX + FLAT_TUPLE_OPEN
        else:
            key_scope += "" + FLAT_TUPLE_OPEN
        for i, c in enumerate(container):
            define_by_run_flatten(c, key_scope=key_scope + str(i) + FLAT_TUPLE_CLOSE, tensor_tuple_list=tensor_tuple_list,
                                  scope_separator_at_start=True)
    else:
        assert not isinstance(container, (dict, tuple))
        tensor_tuple_list.append((key_scope, container))

    # Non-recursive (first) call -> Return the final dict.
    if is_outer_call:
        return DataOpDict(tensor_tuple_list)
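For reference, here is a minimal sketch of the key scheme this produces, using plain Python containers and illustrative stand-ins for the constants (the real FLATTEN_SCOPE_PREFIX, FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE and DataOpDict are defined in the library; the values below are assumptions for illustration only):

from collections import OrderedDict

# Illustrative stand-ins (assumptions), not the library's actual constants.
FLATTEN_SCOPE_PREFIX = "/"
FLAT_TUPLE_OPEN = "_T"
FLAT_TUPLE_CLOSE = "_"


def toy_flatten(container, key_scope=""):
    # Simplified version of the key generation above (always prepends the separator).
    out = OrderedDict()
    if isinstance(container, dict):
        for key in sorted(container.keys()):
            out.update(toy_flatten(container[key], key_scope + FLATTEN_SCOPE_PREFIX + key))
    elif isinstance(container, tuple):
        for i, item in enumerate(container):
            out.update(toy_flatten(
                item, key_scope + FLATTEN_SCOPE_PREFIX + FLAT_TUPLE_OPEN + str(i) + FLAT_TUPLE_CLOSE))
    else:
        out[key_scope] = container
    return out


flat = toy_flatten(dict(a=1, b=(2, dict(c=3))))
print(dict(flat))  # {'/a': 1, '/b/_T0_': 2, '/b/_T1_/c': 3}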
Example 2
    def _graph_fn_merge(self, *inputs):
        """
        Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

        Args:
            *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

        Returns:
            ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
        """
        if self.is_tuple is True:
            ret = []
            for op in inputs:
                # Merge single items inside a DataOpTuple into resulting tuple.
                if self.merge_tuples_into_one and isinstance(op, DataOpTuple):
                    ret.extend(list(op))
                # Strict by-input merging.
                else:
                    ret.append(op)
            return DataOpTuple(ret)
        else:
            ret = DataOpDict()
            for i, op in enumerate(inputs):
                if get_backend() == "pytorch" and self.execution_mode == "define_by_run":
                    ret[FLATTEN_SCOPE_PREFIX + self.dict_keys[i]] = op
                else:
                    ret[self.dict_keys[i]] = op
            return ret
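To make the positional-to-key mapping concrete, here is a minimal sketch with plain Python containers standing in for DataOpDict/DataOpTuple and a hypothetical `dict_keys` list (illustration only, not the component's actual state):

# Hypothetical keys for this merger; at build-time these come from self.dict_keys.
dict_keys = ["states", "actions"]
inputs = ("states_op", "actions_op")  # would be DataOps/tensors at run-time

# Dict branch: positional inputs are re-keyed by position.
merged_dict = {dict_keys[i]: op for i, op in enumerate(inputs)}
print(merged_dict)  # {'states': 'states_op', 'actions': 'actions_op'}

# Tuple branch with merge_tuples_into_one=True: inputs that are already tuples
# get spliced into the result instead of being nested.
tuple_inputs = (("a", "b"), "c")
merged_tuple = []
for op in tuple_inputs:
    if isinstance(op, tuple):
        merged_tuple.extend(op)
    else:
        merged_tuple.append(op)
print(tuple(merged_tuple))  # ('a', 'b', 'c')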
Example 3
    def _graph_fn_merge(self, *inputs):
        """
        Merges the inputs into a single FlattenedDataOp with the flat keys given in `self.input_names`.

        Args:
            *inputs (FlattenedDataOp): The input items to be merged into a FlattenedDataOp.

        Returns:
            FlattenedDataOp: The FlattenedDataOp resulting from merging all *inputs.
        """
        ret = DataOpDict()
        for i, op in enumerate(inputs):
            ret[self.input_names[i]] = op
        return ret
Example 4
    def _average_grads_and_vars(self, variables_by_component,
                                grads_and_vars_all_gpus_by_component):
        """
        Utility to average gradients (per var) across towers.

        Args:
            variables_by_component (DataOpDict[Dict[str,DataOp]]): Dict of Dicts of variables.
            grads_and_vars_all_gpus_by_component (DataOpDict[str,list]): Dict mapping component keys to
                lists of grads_and_vars (one list per GPU tower).

        Returns:
            DataOpDict[str,list]: DataOpDict with keys=component keys, values=list of grads_and_vars tuples averaged
                across our GPUs.
        """
        if get_backend() == "tf":
            ret = dict()
            for component_key in variables_by_component.keys():
                gpu_grad_averages = []
                for grads_and_vars in zip(
                        *grads_and_vars_all_gpus_by_component[component_key]):
                    gpu_grads = []

                    for grad, var in grads_and_vars:
                        if grad is not None:
                            # Add a leading axis so the per-GPU grads can be stacked and averaged.
                            batch_grad = tf.expand_dims(input=grad, axis=0)
                            gpu_grads.append(batch_grad)

                    if not gpu_grads:
                        continue

                    aggregate_grads = tf.concat(axis=0, values=gpu_grads)
                    mean_grad = tf.reduce_mean(input_tensor=aggregate_grads,
                                               axis=0)
                    # Need the actual main policy vars, as these are the ones that should be updated.
                    # TODO: This is a hack and needs to be changed, but it works for now to look up main policy variables.
                    main_variable_key = re.sub(
                        r'{}/tower-0/'.format(self.global_scope), "",
                        grads_and_vars[0][1].op.name)
                    main_variable_key = re.sub(r'/', "-", main_variable_key)
                    var = variables_by_component[component_key][
                        main_variable_key]
                    gpu_grad_averages.append((mean_grad, var))

                ret[component_key] = DataOpTuple(gpu_grad_averages)

            return DataOpDict(ret)
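The per-variable averaging itself is just an expand/concat/reduce_mean over a leading GPU axis; here is a minimal numeric sketch (assuming TensorFlow 2.x eager execution and made-up gradient values):

import tensorflow as tf

# Hypothetical gradients for one variable, as computed on two GPU towers.
grad_gpu0 = tf.constant([1.0, 2.0, 3.0])
grad_gpu1 = tf.constant([3.0, 4.0, 5.0])

# Same recipe as above: add a leading axis per GPU, concatenate, then average it away.
stacked = tf.concat([tf.expand_dims(g, axis=0) for g in (grad_gpu0, grad_gpu1)], axis=0)
mean_grad = tf.reduce_mean(stacked, axis=0)
print(mean_grad.numpy())  # [2. 3. 4.]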
Example 5
    def _graph_fn_merge(self, *inputs):
        """
        Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

        Args:
            *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

        Returns:
            ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
        """
        if self.num_items is not None:
            ret = list()
            for op in inputs:
                ret.append(op)
            return DataOpTuple(ret)
        else:
            ret = DataOpDict()
            for i, op in enumerate(inputs):
                ret[self.dict_keys[i]] = op
            return ret
Example 6
    def _graph_fn_merge(self, *inputs):
        """
        Merges the inputs into a single DataOpDict OR DataOpTuple with the flat keys given in `self.dict_keys`.

        Args:
            *inputs (FlattenedDataOp): The input items to be merged into a ContainerDataOp.

        Returns:
            ContainerDataOp: The DataOpDict or DataOpTuple as a merger of all *inputs.
        """
        if self.num_items is not None:
            ret = []
            for op in inputs:
                ret.append(op)
            return DataOpTuple(ret)
        else:
            ret = DataOpDict()
            for i, op in enumerate(inputs):
                if get_backend() == "pytorch" and self.execution_mode == "define_by_run":
                    ret[FLATTEN_SCOPE_PREFIX + self.dict_keys[i]] = op
                else:
                    ret[self.dict_keys[i]] = op
            return ret
Example 7
    def dtype(self):
        return DataOpDict([(key, subspace.dtype) for key, subspace in self.items()])
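As a toy illustration of how this property composes recursively over a container of sub-spaces (ToySubSpace below is a made-up stand-in, not the library's Space API):

class ToySubSpace:
    def __init__(self, dtype):
        self.dtype = dtype


spaces = dict(observation=ToySubSpace("float32"), action=ToySubSpace("int32"))
print({key: subspace.dtype for key, subspace in spaces.items()})
# {'observation': 'float32', 'action': 'int32'}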
Example 8
    def _graph_fn_step(self):
        if get_backend() == "tf":

            def scan_func(accum, time_delta):
                # Not needed: preprocessed-previous-states (tuple!)
                # `state` is a tuple as well. See comment in ctor for why tf cannot use ContainerSpaces here.
                internal_states = None
                state = accum[1]
                if self.has_rnn:
                    internal_states = accum[-1]

                state = tuple(tf.convert_to_tensor(value=s) for s in state)

                flat_state = OrderedDict()
                for i, flat_key in enumerate(
                        self.state_space_actor_flattened.keys()):
                    # Add a simple (size 1) batch rank to the state so it'll pass through the NN.
                    # - Also have to add a time-rank for RNN processing.
                    expanded = state[i]
                    for _ in range(1 if self.has_rnn is False else 2):
                        expanded = tf.expand_dims(input=expanded, axis=0)
                    # Make None so it'll be recognized as batch-rank by the auto-Space detector.
                    flat_state[flat_key] = tf.placeholder_with_default(
                        input=expanded,
                        shape=(None, ) + ((None, ) if self.has_rnn is True else
                                          ()) +
                        self.state_space_actor_list[i].shape)

                # Recreate state as the original Space to pass it into the actor-component.
                state = unflatten_op(flat_state)

                # Get action and preprocessed state (as batch-size 1).
                out = (self.actor_component.get_preprocessed_state_and_action
                       if self.add_action_probs is False else
                       self.actor_component.
                       get_preprocessed_state_action_and_action_probs)(
                           state,
                           # Add simple batch rank to internal_states.
                           None if internal_states is None else DataOpTuple(
                               internal_states),  # <- None for non-RNN systems
                           time_step=self.time_step + time_delta,
                           return_ops=True)

                # Get output depending on whether it contains internal_states or not.
                a = out["action"]
                action_probs = out.get("action_probs")
                current_internal_states = out.get("last_internal_states")

                # Strip the batch (and maybe time) ranks again from the action in case the Env doesn't like it.
                a_no_extra_ranks = a[0, 0] if self.has_rnn is True else a[0]
                # Step through the Env and collect next state (tuple!), reward and terminal as single values
                # (not batched).
                out = self.environment_server.step_for_env_stepper(
                    a_no_extra_ranks)
                s_, r, t_ = out[:-2], out[-2], out[-1]
                r = tf.cast(r, dtype="float32")

                # Add a and/or r to next_state?
                if self.add_previous_action_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous action to non tuple!"
                    s_ = s_ + (a_no_extra_ranks, )
                if self.add_previous_reward_to_state is True:
                    assert isinstance(
                        s_, tuple
                    ), "ERROR: Cannot add previous reward to non tuple!"
                    s_ = s_ + (r, )

                # Note: s_ is packed as tuple.
                ret = [t_, s_] + \
                    ([a_no_extra_ranks] if self.add_action else []) + \
                    ([r] if self.add_reward else []) + \
                    ([(action_probs[0][0] if self.has_rnn is True else action_probs[0])] if
                     self.add_action_probs is True else []) + \
                    ([tuple(current_internal_states)] if self.has_rnn is True else [])

                return tuple(ret)

            # Initialize the tf.scan run.
            initializer = [
                self.current_terminal.read_value(
                ),  # whether the current state is terminal
                # current (raw) state (flattened components if ContainerSpace).
                tuple(
                    map(lambda x: x.read_value(), self.current_state.values()))
            ]
            # Append actions and rewards if needed.
            if self.add_action:
                initializer.append(self.current_action.read_value())
            if self.add_reward:
                initializer.append(self.current_reward.read_value())
            # Append action probs if needed.
            if self.add_action_probs is True:
                initializer.append(self.current_action_probs.read_value())
            # Append internal states if needed.
            if self.current_internal_states is not None:
                initializer.append(
                    tuple(
                        tf.placeholder_with_default(
                            internal_s.read_value(),
                            shape=(None, ) +
                            tuple(internal_s.shape.as_list()[1:])) for
                        internal_s in self.current_internal_states.values()))

            # Scan over n time-steps (tf.range produces the time_delta with respect to the current time_step).
            step_results = list(
                tf.scan(fn=scan_func,
                        elems=tf.range(self.num_steps, dtype="int32"),
                        initializer=tuple(initializer),
                        back_prop=False))

            # Store the time-step increment, return so far, current terminal and current state.
            assigns = [
                tf.assign_add(self.time_step, self.num_steps),
                self.assign_variable(self.current_terminal,
                                     step_results[0][-1])
            ]

            # Concatenate first and rest.
            full_results = []
            for first_values, rest_values in zip(initializer, step_results):
                full_results.append(
                    nest.map_structure(
                        lambda first, rest: tf.concat([[first], rest], axis=0),
                        first_values, rest_values))

            # Re-build DataOpDicts from preprocessed-states and states (from tuple right now).
            rebuild_s = DataOpDict()
            for flat_key, var_ref, s_comp in zip(
                    self.state_space_actor_flattened.keys(),
                    self.current_state.values(), full_results[1]):
                assigns.append(self.assign_variable(
                    var_ref, s_comp[-1]))  # -1: current state (last observed)
                rebuild_s[flat_key] = s_comp
            rebuild_s = unflatten_op(rebuild_s)
            full_results[1] = rebuild_s

            # Remove batch rank from internal states again.
            if self.current_internal_states is not None:
                # TODO: What if internal states is not the last item in the list anymore due to some change.
                slot = -1  # if self.add_action_probs is True else 2
                # TODO: What if internal states is a dict? Right now assume some tuple.
                internal_states_wo_batch = list()
                for i in range(len(full_results[slot])):
                    # Squeeze out the size-1 batch axis (axis 1; axis 0 is the time axis).
                    internal_states_wo_batch.append(
                        tf.squeeze(full_results[slot][i], axis=1))
                full_results[slot] = DataOpTuple(internal_states_wo_batch)

            with tf.control_dependencies(control_inputs=assigns):
                # Let the auto-infer system know what time rank we have.
                full_results = DataOpTuple(full_results)
                for o in flatten_op(full_results).values():
                    o._time_rank = 0  # which position in the shape is the time-rank?
                step_op = tf.no_op()

            return step_op, full_results
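The core mechanism here is tf.scan carrying the (terminal, state, ...) accumulator through time. Here is a minimal sketch of that pattern (assuming TensorFlow 2.x eager execution and a toy "environment" that just adds the time-delta to the state):

import tensorflow as tf

num_steps = 4


def toy_scan_func(accum, time_delta):
    # accum mirrors the initializer structure: (terminal, state).
    terminal, state = accum
    next_state = state + tf.cast(time_delta, tf.float32)  # toy "environment step"
    next_terminal = next_state > 5.0
    return next_terminal, next_state


initializer = (tf.constant(False), tf.constant(0.0))
terminals, states = tf.scan(
    fn=toy_scan_func,
    elems=tf.range(num_steps, dtype=tf.int32),
    initializer=initializer,
    back_prop=False)
print(states.numpy())     # [0. 1. 3. 6.]
print(terminals.numpy())  # [False False False  True]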
Example 9
    def _graph_fn_merge_actions(self, a, b, c):
        return DataOpDict(a=a, b=b, c=c)
Example 10
def define_by_run_unflatten(result_dict):
    """
    Takes a dict with auto-generated keys and returns the corresponding
    unflattened dict.
    If the only key in the input dict is "", it returns the value under
    that key.

    Args:
        result_dict (dict): The item to be unflattened (re-nested).

    Returns:
        Dict: The unflattened (re-nested) item.
    """
    # Special case: Dict with only 1 value (key="")
    if len(result_dict) == 1 and "" in result_dict:
        return result_dict[""]

    # Normal case: OrderedDict that came from a ContainerItem.
    base_structure = None

    op_names = sorted(result_dict.keys())
    for op_name in op_names:
        op_val = result_dict[op_name]
        parent_structure = None
        parent_key = None
        current_structure = None
        op_type = None

        # Strip the leading scope separator (if any) before splitting into sub-keys.
        if op_name.startswith(FLATTEN_SCOPE_PREFIX):
            op_name = op_name[1:]
        op_key_list = op_name.split(FLATTEN_SCOPE_PREFIX)
        for sub_key in op_key_list:
            mo = re.match(r'^{}(\d+){}$'.format(FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE), sub_key)
            if mo:
                op_type = list
                idx = int(mo.group(1))
            else:
                op_type = OrderedDict
                idx = sub_key

            if current_structure is None:
                if base_structure is None:
                    base_structure = [None] if op_type == list else DataOpDict()
                current_structure = base_structure
            elif parent_key is not None:
                if (isinstance(parent_structure, list) and (parent_structure[parent_key] is None)) or \
                        (isinstance(parent_structure, DataOpDict) and parent_key not in parent_structure):
                    current_structure = [None] if op_type == list else DataOpDict()
                    parent_structure[parent_key] = current_structure
                else:
                    current_structure = parent_structure[parent_key]
                    if op_type == list and len(current_structure) == idx:
                        current_structure.append(None)

            parent_structure = current_structure
            parent_key = idx
            if isinstance(parent_structure, list) and len(parent_structure) == parent_key:
                parent_structure.append(None)

        if op_type == list and len(current_structure) == parent_key:
            current_structure.append(None)
        current_structure[parent_key] = op_val

    # Deep conversion from list to tuple.
    # TODO necessary in define by run?
    return deep_tuple(base_structure)
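A minimal round-trip sketch of the inverse mapping, again with the illustrative separator from the flatten sketch above (an assumption, not the library's actual value) and plain dicts in place of DataOpDict; unlike the real unflatten, tuple indices are kept as dict keys for brevity:

from collections import OrderedDict

FLATTEN_SCOPE_PREFIX = "/"  # illustrative stand-in (assumption)


def toy_unflatten(flat):
    # Re-nest keys produced by the toy flatten sketch above.
    root = {}
    for key, value in flat.items():
        parts = key.lstrip(FLATTEN_SCOPE_PREFIX).split(FLATTEN_SCOPE_PREFIX)
        node = root
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return root


flat = OrderedDict([("/a", 1), ("/b/_T0_", 2), ("/b/_T1_/c", 3)])
print(toy_unflatten(flat))
# {'a': 1, 'b': {'_T0_': 2, '_T1_': {'c': 3}}}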
Example 11
        def _graph_fn_update_from_external_batch(root, preprocessed_states,
                                                 actions, rewards, terminals,
                                                 sequence_indices,
                                                 apply_postprocessing):
            """
            Performs iterative optimization by repeatedly sub-sampling the given external batch.
            """
            multi_gpu_sync_optimizer = root.sub_components.get(
                "multi-gpu-synchronizer")

            # Return values.
            loss, loss_per_item, vf_loss, vf_loss_per_item = None, None, None, None

            policy = root.get_sub_component_by_name(agent.policy.scope)
            value_function = root.get_sub_component_by_name(
                agent.value_function.scope)
            optimizer = root.get_sub_component_by_name(agent.optimizer.scope)
            loss_function = root.get_sub_component_by_name(
                agent.loss_function.scope)
            value_function_optimizer = root.get_sub_component_by_name(
                agent.value_function_optimizer.scope)
            vars_merger = root.get_sub_component_by_name(
                agent.vars_merger.scope)
            gae_function = root.get_sub_component_by_name(
                agent.gae_function.scope)
            prev_log_probs = policy.get_log_likelihood(
                preprocessed_states, actions)["log_likelihood"]

            if get_backend() == "tf":
                # Log probs before update.
                prev_log_probs = tf.stop_gradient(prev_log_probs)
                batch_size = tf.shape(preprocessed_states)[0]
                prior_baseline_values = tf.stop_gradient(
                    value_function.value_output(preprocessed_states))

                # Advantages are based on prior baseline values.
                advantages = tf.cond(
                    pred=apply_postprocessing,
                    true_fn=lambda: gae_function.calc_gae_values(
                        prior_baseline_values, rewards, terminals,
                        sequence_indices),
                    false_fn=lambda: rewards)

                if self.standardize_advantages:
                    # Note: tf.nn.moments returns (mean, variance), not the std-dev.
                    mean, variance = tf.nn.moments(x=advantages, axes=[0])
                    advantages = (advantages - mean) / tf.sqrt(variance)

                def opt_body(index_, loss_, loss_per_item_, vf_loss_,
                             vf_loss_per_item_):
                    start = tf.random_uniform(shape=(),
                                              minval=0,
                                              maxval=batch_size - 1,
                                              dtype=tf.int32)
                    indices = tf.range(
                        start=start,
                        limit=start + agent.sample_size) % batch_size
                    sample_states = tf.gather(params=preprocessed_states,
                                              indices=indices)
                    if isinstance(actions, ContainerDataOp):
                        sample_actions = FlattenedDataOp()
                        for name, action in flatten_op(actions).items():
                            sample_actions[name] = tf.gather(params=action,
                                                             indices=indices)
                        sample_actions = unflatten_op(sample_actions)
                    else:
                        sample_actions = tf.gather(params=actions,
                                                   indices=indices)

                    sample_prior_log_probs = tf.gather(params=prev_log_probs,
                                                       indices=indices)
                    sample_rewards = tf.gather(params=rewards, indices=indices)
                    sample_terminals = tf.gather(params=terminals,
                                                 indices=indices)
                    sample_sequence_indices = tf.gather(
                        params=sequence_indices, indices=indices)
                    sample_advantages = tf.gather(params=advantages,
                                                  indices=indices)
                    sample_advantages.set_shape((self.sample_size, ))

                    sample_baseline_values = value_function.value_output(
                        sample_states)
                    sample_prior_baseline_values = tf.gather(
                        params=prior_baseline_values, indices=indices)

                    # If we are a multi-GPU root:
                    # Simply feed everything into the multi-GPU sync optimizer's method and return.
                    if multi_gpu_sync_optimizer is not None:
                        main_policy_vars = agent.policy.variables()
                        main_vf_vars = agent.value_function.variables()
                        all_vars = agent.vars_merger.merge(
                            main_policy_vars, main_vf_vars)
                        # grads_and_vars, loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
                            all_vars,
                            sample_states,
                            sample_actions,
                            sample_rewards,
                            sample_terminals,
                            sample_sequence_indices,
                            apply_postprocessing=apply_postprocessing)
                        avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.call(
                            out["avg_grads_and_vars_by_component"])
                        policy_step_op = agent.optimizer.apply_gradients(
                            avg_grads_and_vars_policy)
                        vf_step_op = agent.value_function_optimizer.apply_gradients(
                            avg_grads_and_vars_vf)
                        step_op = root._graph_fn_group(policy_step_op,
                                                       vf_step_op)
                        step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(
                            step_op, all_vars)
                        loss_vf, loss_per_item_vf = out[
                            "additional_return_0"], out["additional_return_1"]

                        # Have to set all shapes here due to strict loop-var shape requirements.
                        out["loss"].set_shape(())
                        loss_vf.set_shape(())
                        loss_per_item_vf.set_shape((agent.sample_size, ))
                        out["loss_per_item"].set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_and_sync_op]):
                            if index_ == 0:
                                # Increase the global training step counter.
                                out["loss"] = root._graph_fn_training_step(
                                    out["loss"])
                            return index_ + 1, out["loss"], out[
                                "loss_per_item"], loss_vf, loss_per_item_vf

                    policy_probs = policy.get_log_likelihood(
                        sample_states, sample_actions)["log_likelihood"]
                    baseline_values = value_function.value_output(
                        tf.stop_gradient(sample_states))
                    sample_rewards = tf.cond(
                        pred=apply_postprocessing,
                        true_fn=lambda: gae_function.calc_gae_values(
                            baseline_values, sample_rewards, sample_terminals,
                            sample_sequence_indices),
                        false_fn=lambda: sample_rewards)
                    sample_rewards.set_shape((agent.sample_size, ))
                    entropy = policy.get_entropy(sample_states)["entropy"]

                    loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        loss_function.loss(
                            policy_probs, sample_prior_log_probs,
                            sample_baseline_values, sample_prior_baseline_values, sample_advantages, entropy
                        )

                    if hasattr(root, "is_multi_gpu_tower"
                               ) and root.is_multi_gpu_tower is True:
                        policy_grads_and_vars = optimizer.calculate_gradients(
                            policy.variables(), loss)
                        vf_grads_and_vars = value_function_optimizer.calculate_gradients(
                            value_function.variables(), vf_loss)
                        grads_and_vars_by_component = vars_merger.merge(
                            policy_grads_and_vars, vf_grads_and_vars)
                        return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
                    else:
                        step_op, loss, loss_per_item = optimizer.step(
                            policy.variables(), loss, loss_per_item)
                        loss.set_shape(())
                        loss_per_item.set_shape((agent.sample_size, ))

                        vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
                            value_function.variables(), vf_loss,
                            vf_loss_per_item)
                        vf_loss.set_shape(())
                        vf_loss_per_item.set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_op, vf_step_op]):
                            return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item

                def cond(index_, loss_, loss_per_item_, v_loss_,
                         v_loss_per_item_):
                    return index_ < agent.iterations

                init_loop_vars = [
                    0,
                    tf.zeros(shape=(), dtype=tf.float32),
                    tf.zeros(shape=(agent.sample_size, )),
                    tf.zeros(shape=(), dtype=tf.float32),
                    tf.zeros(shape=(agent.sample_size, ))
                ]

                if hasattr(root, "is_multi_gpu_tower"
                           ) and root.is_multi_gpu_tower is True:
                    return opt_body(*init_loop_vars)
                else:
                    index, loss, loss_per_item, vf_loss, vf_loss_per_item = tf.while_loop(
                        cond=cond,
                        body=opt_body,
                        loop_vars=init_loop_vars,
                        parallel_iterations=1)
                    # Increase the global training step counter.
                    loss = root._graph_fn_training_step(loss)
                    return loss, loss_per_item, vf_loss, vf_loss_per_item

            elif get_backend() == "pytorch":
                if isinstance(prev_log_probs, dict):
                    for name in actions.keys():
                        prev_log_probs[name] = prev_log_probs[name].detach()
                else:
                    prev_log_probs = prev_log_probs.detach()
                batch_size = preprocessed_states.shape[0]
                sample_size = min(batch_size, agent.sample_size)
                prior_baseline_values = value_function.value_output(
                    preprocessed_states).detach()
                if apply_postprocessing:
                    advantages = gae_function.calc_gae_values(
                        prior_baseline_values, rewards, terminals,
                        sequence_indices)
                else:
                    advantages = rewards
                if self.standardize_advantages:
                    advantages = (advantages - torch.mean(advantages)
                                  ) / torch.std(advantages)

                for _ in range(agent.iterations):
                    start = int(torch.rand(1) * (batch_size - 1))
                    indices = torch.arange(start=start,
                                           end=start + sample_size,
                                           dtype=torch.long) % batch_size
                    sample_states = torch.index_select(preprocessed_states, 0,
                                                       indices)

                    if isinstance(actions, dict):
                        sample_actions = DataOpDict()
                        sample_prior_log_probs = DataOpDict()
                        for name, action in define_by_run_flatten(
                                actions,
                                scope_separator_at_start=False).items():
                            sample_actions[name] = torch.index_select(
                                action, 0, indices)
                            sample_prior_log_probs[name] = torch.index_select(
                                prev_log_probs[name], 0, indices)
                    else:
                        sample_actions = torch.index_select(
                            actions, 0, indices)
                        sample_prior_log_probs = torch.index_select(
                            prev_log_probs, 0, indices)

                    sample_advantages = torch.index_select(
                        advantages, 0, indices)
                    sample_prior_baseline_values = torch.index_select(
                        prior_baseline_values, 0, indices)

                    policy_probs = policy.get_log_likelihood(
                        sample_states, sample_actions)["log_likelihood"]
                    sample_baseline_values = value_function.value_output(
                        sample_states)

                    entropy = policy.get_entropy(sample_states)["entropy"]
                    loss, loss_per_item, vf_loss, vf_loss_per_item = loss_function.loss(
                        policy_probs, sample_prior_log_probs,
                        sample_baseline_values, sample_prior_baseline_values,
                        sample_advantages, entropy)

                    # Do not need step op.
                    _, loss, loss_per_item = optimizer.step(
                        policy.variables(), loss, loss_per_item)
                    _, vf_loss, vf_loss_per_item = \
                        value_function_optimizer.step(value_function.variables(), vf_loss, vf_loss_per_item)
                return loss, loss_per_item, vf_loss, vf_loss_per_item
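The surrounding TF structure is a tf.while_loop whose body gathers a random contiguous window (wrapped modulo the batch size) as the sub-sample. Here is a minimal sketch of that pattern (assuming TensorFlow 2.x, where tf.random.uniform is the current spelling of the tf.random_uniform used above, and with made-up sizes):

import tensorflow as tf

batch_size, sample_size, iterations = 8, 3, 4
data = tf.range(batch_size, dtype=tf.float32) * 10.0  # toy "batch"


def cond(index_, running_sum_):
    return index_ < iterations


def body(index_, running_sum_):
    # Random contiguous window, wrapped modulo the batch size (same trick as above).
    start = tf.random.uniform(shape=(), minval=0, maxval=batch_size - 1, dtype=tf.int32)
    indices = tf.range(start=start, limit=start + sample_size) % batch_size
    sample = tf.gather(params=data, indices=indices)
    # Loop vars must keep fixed shapes/dtypes across iterations.
    return index_ + 1, running_sum_ + tf.reduce_sum(sample)


index, total = tf.while_loop(
    cond=cond, body=body,
    loop_vars=[tf.constant(0), tf.constant(0.0)],
    parallel_iterations=1)
print(index.numpy(), total.numpy())  # e.g. "4 390.0" (sum depends on the random windows)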
Example 12
                def opt_body(index_, loss_, loss_per_item_, vf_loss_,
                             vf_loss_per_item_):
                    start = tf.random_uniform(shape=(),
                                              minval=0,
                                              maxval=batch_size - 1,
                                              dtype=tf.int32)
                    indices = tf.range(
                        start=start,
                        limit=start + agent.sample_size) % batch_size
                    sample_states = tf.gather(params=preprocessed_states,
                                              indices=indices)
                    if isinstance(actions, dict):
                        sample_actions = DataOpDict()
                        sample_prior_log_probs = DataOpDict()
                        for name, action in flatten_op(
                                actions,
                                scope_separator_at_start=False).items():
                            sample_actions[name] = tf.gather(params=action,
                                                             indices=indices)
                            sample_prior_log_probs[name] = tf.gather(
                                params=prev_log_probs[name], indices=indices)
                    else:
                        sample_actions = tf.gather(params=actions,
                                                   indices=indices)
                        sample_prior_log_probs = tf.gather(
                            params=prev_log_probs, indices=indices)

                    sample_rewards = tf.gather(params=rewards, indices=indices)
                    sample_terminals = tf.gather(params=terminals,
                                                 indices=indices)
                    sample_sequence_indices = tf.gather(
                        params=sequence_indices, indices=indices)

                    # If we are a multi-GPU root:
                    # Simply feed everything into the multi-GPU sync optimizer's method and return.
                    if multi_gpu_sync_optimizer is not None:
                        main_policy_vars = agent.policy.variables()
                        main_vf_vars = agent.value_function.variables()
                        all_vars = agent.vars_merger.merge(
                            main_policy_vars, main_vf_vars)
                        # grads_and_vars, loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        out = multi_gpu_sync_optimizer.calculate_update_from_external_batch(
                            all_vars,
                            sample_states,
                            sample_actions,
                            sample_rewards,
                            sample_terminals,
                            sample_sequence_indices,
                            apply_postprocessing=apply_postprocessing)
                        avg_grads_and_vars_policy, avg_grads_and_vars_vf = agent.vars_splitter.split(
                            out["avg_grads_and_vars_by_component"])
                        policy_step_op = agent.optimizer.apply_gradients(
                            avg_grads_and_vars_policy)
                        vf_step_op = agent.value_function_optimizer.apply_gradients(
                            avg_grads_and_vars_vf)
                        step_op = root._graph_fn_group(policy_step_op,
                                                       vf_step_op)
                        step_and_sync_op = multi_gpu_sync_optimizer.sync_variables_to_towers(
                            step_op, all_vars)
                        loss_vf, loss_per_item_vf = out[
                            "additional_return_0"], out["additional_return_1"]

                        # Have to set all shapes here due to strict loop-var shape requirements.
                        out["loss"].set_shape(())
                        loss_vf.set_shape(())
                        loss_per_item_vf.set_shape((agent.sample_size, ))
                        out["loss_per_item"].set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_and_sync_op]):
                            if index_ == 0:
                                # Increase the global training step counter.
                                out["loss"] = root._graph_fn_training_step(
                                    out["loss"])
                            return index_ + 1, out["loss"], out[
                                "loss_per_item"], loss_vf, loss_per_item_vf

                    policy_probs = policy.get_action_log_probs(
                        sample_states, sample_actions)
                    baseline_values = value_function.value_output(
                        tf.stop_gradient(sample_states))
                    sample_rewards = tf.cond(
                        pred=apply_postprocessing,
                        true_fn=lambda: gae_function.calc_gae_values(
                            baseline_values, sample_rewards, sample_terminals,
                            sample_sequence_indices),
                        false_fn=lambda: sample_rewards)
                    sample_rewards.set_shape((agent.sample_size, ))
                    entropy = policy.get_entropy(sample_states)["entropy"]

                    loss, loss_per_item, vf_loss, vf_loss_per_item = \
                        loss_function.loss(
                            policy_probs["action_log_probs"], sample_prior_log_probs,
                            baseline_values, sample_rewards, entropy
                        )

                    if hasattr(root, "is_multi_gpu_tower"
                               ) and root.is_multi_gpu_tower is True:
                        policy_grads_and_vars = optimizer.calculate_gradients(
                            policy.variables(), loss)
                        vf_grads_and_vars = value_function_optimizer.calculate_gradients(
                            value_function.variables(), vf_loss)
                        grads_and_vars_by_component = vars_merger.merge(
                            policy_grads_and_vars, vf_grads_and_vars)
                        return grads_and_vars_by_component, loss, loss_per_item, vf_loss, vf_loss_per_item
                    else:
                        step_op, loss, loss_per_item = optimizer.step(
                            policy.variables(), loss, loss_per_item)
                        loss.set_shape(())
                        loss_per_item.set_shape((agent.sample_size, ))

                        vf_step_op, vf_loss, vf_loss_per_item = value_function_optimizer.step(
                            value_function.variables(), vf_loss,
                            vf_loss_per_item)
                        vf_loss.set_shape(())
                        vf_loss_per_item.set_shape((agent.sample_size, ))

                        with tf.control_dependencies([step_op, vf_step_op]):
                            return index_ + 1, loss, loss_per_item, vf_loss, vf_loss_per_item