Example #1
    def __init__(self, component, num_op_records=None, args=None, kwargs=None):
        """
        Args:
            component (Component): The Component to which this column belongs.
        """
        self.id = self.get_id()

        if num_op_records is None:
            self.op_records = []
            if args is not None:
                for i in range(len(args)):
                    if args[i] is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, position=i)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(args[i], DataOpRecord):
                        op_rec.previous = args[i]
                        op = args[i].op
                        if op is not None:
                            op_rec.op = op
                            op_rec.space = get_space_from_op(op)
                        args[i].next.add(op_rec)
                    # Do constant value assignment here.
                    elif args[i] is not None:
                        op = args[i]
                        if is_constant(op) and not isinstance(op, np.ndarray):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)

            if kwargs is not None:
                for key in sorted(kwargs.keys()):
                    value = kwargs[key]
                    if value is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, kwarg=key)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(value, DataOpRecord):
                        op_rec.previous = value
                        op_rec.op = value.op  # assign op if any
                        value.next.add(op_rec)
                    # Do constant value assignment here.
                    elif value is not None:
                        op = value
                        if is_constant(op):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)
        else:
            self.op_records = [DataOpRecord(op=None, column=self, position=i) for i in range(num_op_records)]

        # For __str__ purposes.
        self.op_id_list = [o.id for o in self.op_records]

        # The component this column belongs to.
        self.component = component
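The constant-value branch above wraps plain Python values into numpy arrays before a Space is inferred for them. Below is a minimal, standalone sketch of that idea using only numpy; the helper name `wrap_constant` is hypothetical and not part of RLgraph.

import numpy as np

def wrap_constant(value):
    # Hypothetical standalone helper: turn a Python constant into a numpy array
    # so that a dtype/shape ("Space") can later be inferred from it.
    if isinstance(value, np.ndarray):
        return value
    return np.asarray(value)

print(wrap_constant(1.0).dtype)   # float64
print(wrap_constant(True).dtype)  # bool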
Example #2
    def __init__(self, file_name=None, worker_id=0, base_port=5005, seed=0, docker_training=False, no_graphics=False,
                 timeout_wait=30, train_mode=True, **kwargs):
        """
        Args:
            file_name (Optional[str]): Name of Unity environment binary.
            base_port (int): Port number to connect to Unity environment. `worker_id` increments on top of this.
            worker_id (int): Number to add to `base_port`. Used for asynchronous agent scenarios.
            docker_training (bool): Informs this class whether the process is being run within a container.
                Default: False.
            no_graphics (bool): Whether to run the Unity simulator in no-graphics mode. Default: False.
            timeout_wait (int): Time (in seconds) to wait for connection from environment.
            train_mode (bool): Whether to run in training mode, speeding up the simulation. Default: True.
        """
        # First create the UnityMLAgentsEnvironment to get state and action spaces, then create RLgraph Environment
        # instance.
        self.mlagents_env = UnityEnvironment(
            file_name, worker_id, base_port, seed, docker_training, no_graphics
        )
        all_brain_info = self.mlagents_env.reset()
        # Get all possible information from AllBrainInfo.
        # TODO: Which scene do we pick?
        self.scene_key = next(iter(all_brain_info))
        first_brain_info = all_brain_info[self.scene_key]
        num_environments = len(first_brain_info.agents)

        state_space = {}
        if len(first_brain_info.vector_observations[0]) > 0:
            state_space["vector"] = get_space_from_op(first_brain_info.vector_observations[0])
            # TODO: This is a hack.
            if state_space["vector"].dtype == np.float64:
                state_space["vector"].dtype = np.float32
        if len(first_brain_info.visual_observations) > 0:
            state_space["visual"] = get_space_from_op(first_brain_info.visual_observations[0])
        if first_brain_info.text_observations[0]:
            state_space["text"] = get_space_from_op(first_brain_info.text_observations[0])

        if len(state_space) == 1:
            self.state_key = next(iter(state_space))
            state_space = state_space[self.state_key]
        else:
            self.state_key = None
            state_space = Dict(state_space)
        action_space = get_space_from_op(first_brain_info.action_masks[0])
        if action_space.dtype == np.float64:
            action_space.dtype = np.float32

        super(MLAgentsEnv, self).__init__(
            num_environments=num_environments, state_space=state_space, action_space=action_space, **kwargs
        )

        # Caches the last observation we made (after stepping or resetting).
        self.last_state = []
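Assuming the class above is RLgraph's MLAgentsEnv wrapper, constructing it might look roughly like the sketch below. The import path and the binary name are assumptions for illustration only; the keyword arguments follow the constructor signature shown above.

from rlgraph.environments import MLAgentsEnv  # import path is an assumption

env = MLAgentsEnv(
    file_name="3DBall",   # placeholder name of a Unity environment binary
    worker_id=0,          # added to base_port for parallel workers
    base_port=5005,
    no_graphics=True,     # headless simulation
    train_mode=True       # faster simulation speed
)
# state_space/action_space are derived from the first BrainInfo in __init__ above.
print(env.state_space, env.action_space)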
Example #3
    def _graph_fn_call(self, inputs):
        if self.backend == "python" or get_backend() == "python":
            if isinstance(inputs, list):
                inputs = np.asarray(inputs)
            return inputs.astype(
                dtype=util.convert_dtype(self.to_dtype, to="np"))
        elif get_backend() == "pytorch":
            torch_dtype = util.convert_dtype(self.to_dtype, to="pytorch")
            if torch_dtype in (torch.float, torch.float32):
                return inputs.float()
            elif torch_dtype in (torch.int, torch.int32):
                return inputs.int()
            elif torch_dtype == torch.uint8:
                return inputs.byte()
        elif get_backend() == "tf":
            in_space = get_space_from_op(inputs)
            to_dtype = util.convert_dtype(self.to_dtype, to="tf")
            if inputs.dtype != to_dtype:
                ret = tf.cast(x=inputs, dtype=to_dtype)
                if in_space.has_batch_rank is True:
                    ret._batch_rank = 0 if in_space.time_major is False else 1
                if in_space.has_time_rank is True:
                    ret._time_rank = 0 if in_space.time_major is True else 1
                return ret
            else:
                return inputs
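The python-backend branch of `_graph_fn_call` reduces to a plain dtype cast. A standalone equivalent (without the backend dispatch or RLgraph's dtype utilities) could look like this sketch:

import numpy as np

def cast_np(inputs, to_dtype=np.float32):
    # Lists become arrays, then everything is cast to the target dtype,
    # mirroring the python-backend branch above.
    if isinstance(inputs, list):
        inputs = np.asarray(inputs)
    return inputs.astype(to_dtype)

print(cast_np([1, 2, 3]).dtype)  # float32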
Example #4
def define_api_method(component, api_method_record, copy_record=True):
    """
    Registers an API-method with a Component instance.

    Args:
        component (Component): The Component object to register the API method with.
        api_method_record (APIMethodRecord): The APIMethodRecord describing the to-be-registered API-method.
        copy_record (bool): Whether to deepcopy the APIMethodRecord prior to handing it to the Component for storing.
    """
    # Deep copy the record (in case this was registered the normal way, via decorating a class method).
    if copy_record:
        api_method_record = copy.deepcopy(api_method_record)
    api_method_record.component = component

    # Raise errors if `name` already taken in this Component.
    if not api_method_record.ok_to_overwrite:
        # There already is an API-method with that name.
        if api_method_record.name in component.api_methods:
            raise RLGraphError(
                "API-method with name '{}' already defined!".format(
                    api_method_record.name))
        # There already is another object property with that name (avoid accidental overriding).
        elif not api_method_record.is_class_method and getattr(
                component, api_method_record.name, None) is not None:
            raise RLGraphError(
                "Component '{}' already has a property called '{}'. Cannot define an API-method with "
                "the same name!".format(component.name,
                                        api_method_record.name))

    # Do not build this API as per ctor instructions.
    if api_method_record.name in component.switched_off_apis:
        return

    component.synthetic_methods.add(api_method_record.name)
    setattr(
        component, api_method_record.name,
        api_method_record.wrapper_func.__get__(component, component.__class__))
    setattr(api_method_record.wrapper_func, "__name__", api_method_record.name)

    component.api_methods[api_method_record.name] = api_method_record

    # Direct callable for eager/define by run.
    component.api_fn_by_name[
        api_method_record.name] = api_method_record.wrapper_func

    # Update the api_method_inputs dict (with empty Spaces if not defined yet).
    skip_args = 1  # self
    skip_args += (api_method_record.is_graph_fn_wrapper
                  and api_method_record.add_auto_key_as_first_param)
    param_list = list(
        inspect.signature(
            api_method_record.func).parameters.values())[skip_args:]

    for param in param_list:
        component.api_methods[api_method_record.name].input_names.append(
            param.name)
        if param.name not in component.api_method_inputs:
            # This param has a default value.
            if param.default != inspect.Parameter.empty:
                # Default is None. Set to "flex" (to signal that this Space is not needed for input-completeness)
                # and wait for first call using this parameter (only then set it to that Space).
                if param.default is None:
                    component.api_method_inputs[param.name] = "flex"
                # Default is some python value (e.g. a bool). Use that as the assigned Space.
                else:
                    space = get_space_from_op(param.default)
                    component.api_method_inputs[param.name] = space
            # This param is an *args param. Store as "*flex". Then with upcoming API calls, we determine the Spaces
            # for the single items in *args and set them under "param[0]", "param[1]", etc.
            elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                component.api_method_inputs[param.name] = "*flex"
            # This param is a **kwargs param. Store as "**flex". Then with upcoming API calls, we determine the Spaces
            # for the single items in **kwargs and set them under "param[some-key]", "param[some-other-key]", etc.
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                component.api_method_inputs[param.name] = "**flex"
            # Normal parameter without a default value. Store as None for now (Space still needed).
            else:
                component.api_method_inputs[param.name] = None
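To make the parameter classification at the end of `define_api_method` concrete, the standalone sketch below applies the same rules to a toy signature. It is illustration only: real RLgraph stores the Space inferred from a default value, whereas this sketch just stores the default itself.

import inspect

def classify_params(func):
    # "flex": default is None; "*flex"/"**flex": var-args / var-kwargs;
    # None: required parameter whose Space is not yet known.
    result = {}
    for param in list(inspect.signature(func).parameters.values())[1:]:  # skip `self`
        if param.default is not inspect.Parameter.empty:
            result[param.name] = "flex" if param.default is None else param.default
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            result[param.name] = "*flex"
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            result[param.name] = "**flex"
        else:
            result[param.name] = None
    return result

def my_api(self, states, use_exploration=True, time_percentage=None, *extra, **kw):
    pass

print(classify_params(my_api))
# {'states': None, 'use_exploration': True, 'time_percentage': 'flex', 'extra': '*flex', 'kw': '**flex'}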
Example #5
        def api_method_wrapper(self, *args, **kwargs):
            api_fn_name = name or re.sub(r'^_graph_fn_', "",
                                         wrapped_func.__name__)
            # Direct evaluation of function.
            if self.execution_mode == "define_by_run":
                type(self).call_count += 1

                start = time.perf_counter()
                # Check with owner if extra args needed.
                if api_fn_name in self.api_methods and self.api_methods[
                        api_fn_name].add_auto_key_as_first_param:
                    output = wrapped_func(self, "", *args, **kwargs)
                else:
                    output = wrapped_func(self, *args, **kwargs)

                # Store runtime for this method.
                type(self).call_times.append(  # Component.call_times
                    (self.name, wrapped_func.__name__,
                     time.perf_counter() - start))
                return output

            api_method_rec = self.api_methods[api_fn_name]

            # Sanity-check input args for accidental dict return values being passed into the next API call as
            # supposed DataOpRecords.
            dict_args = [
                next(iter(a.values())) for a in args if isinstance(a, dict)
            ]
            if len(dict_args) > 0 and isinstance(dict_args[0], DataOpRecord):
                raise RLGraphError(
                    "One of your input args to API-method '{}.{}()' is a dict of DataOpRecords! This is probably "
                    "coming from a previous call to an API-method (returning a dict) and the DataOpRecord should be "
                    "extracted by string-key and passed into '{}' "
                    "directly.".format(api_method_rec.component.global_scope,
                                       api_fn_name, api_fn_name))
            # Create op-record column to call API method with. Ignore None input params. These should not be sent
            # to the API-method.
            in_op_column = DataOpRecordColumnIntoAPIMethod(
                component=self,
                api_method_rec=api_method_rec,
                args=args,
                kwargs=kwargs)
            # Add the column to the API-method record.
            api_method_rec.in_op_columns.append(in_op_column)

            # Check minimum number of passed args.
            minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
                len(in_op_column.api_method_rec.default_args)
            if len(in_op_column.op_records) < minimum_num_call_params:
                raise RLGraphAPICallParamError(
                    "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
                    "params!".format(len(in_op_column.op_records),
                                     api_method_rec.name,
                                     minimum_num_call_params))

            # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
            # if this call was made from within a graph_fn such that ops and Spaces are already known).
            all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
                       [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
            flex = None
            build_when_done = False
            for i, (key, value) in enumerate(all_args):
                # Named arg/kwarg -> get input_name from that and peel op_rec.
                if isinstance(key, str):
                    param_name = key
                # Positional arg -> get input_name from input_names list.
                else:
                    slot = key if flex is None else flex
                    if slot >= len(api_method_rec.input_names):
                        raise RLGraphAPICallParamError(
                            "Too many input args given in call to API-method '{}'!"
                            .format(api_method_rec.name))
                    param_name = api_method_rec.input_names[slot]

                # Var-positional arg, attach the actual position to input_name string.
                if self.api_method_inputs.get(param_name, "") == "*flex":
                    if flex is None:
                        flex = i
                    param_name += "[{}]".format(i - flex)
                # Actual kwarg (not in list of api_method_inputs).
                elif api_method_rec.kwargs_name is not None and param_name not in self.api_method_inputs:
                    param_name = api_method_rec.kwargs_name + "[{}]".format(
                        param_name)

                # We are already in building phase (params may be coming from inside graph_fn).
                if self.graph_builder is not None and self.graph_builder.phase == "building":
                    # If Space not stored yet, determine it from op.
                    assert in_op_column.op_records[i].op is not None
                    if in_op_column.op_records[i].space is None:
                        in_op_column.op_records[i].space = get_space_from_op(
                            in_op_column.op_records[i].op)
                    self.api_method_inputs[
                        param_name] = in_op_column.op_records[i].space
                    # Check input-completeness of Component (but not strict as we are only calling API, not a graph_fn).
                    if self.input_complete is False:
                        # Build right after this loop in case more Space information comes in through next args/kwargs.
                        build_when_done = True

                # A DataOpRecord from the meta-graph.
                elif isinstance(value, DataOpRecord):
                    # Create entry with unknown Space if it doesn't exist yet.
                    if param_name not in self.api_method_inputs:
                        self.api_method_inputs[param_name] = None

                # Fixed value (instead of op-record): Store the fixed value directly in the op.
                else:
                    if self.api_method_inputs.get(param_name) is None:
                        self.api_method_inputs[
                            param_name] = in_op_column.op_records[i].space

            if build_when_done:
                # Check Spaces and create variables.
                self.graph_builder.build_component_when_input_complete(self)

            # Regular API-method: Call it here.
            api_fn_args, api_fn_kwargs = in_op_column.get_args_and_kwargs()

            if api_method_rec.is_graph_fn_wrapper is False:
                return_values = wrapped_func(self, *api_fn_args,
                                             **api_fn_kwargs)
            # Wrapped graph_fn: Call it through yet another wrapper.
            else:
                return_values = graph_fn_wrapper(
                    self, wrapped_func, returns,
                    dict(
                        flatten_ops=flatten_ops,
                        split_ops=split_ops,
                        add_auto_key_as_first_param=add_auto_key_as_first_param,
                        requires_variable_completeness=
                        requires_variable_completeness), *api_fn_args,
                    **api_fn_kwargs)

            # Process the results (push into a column).
            out_op_column = DataOpRecordColumnFromAPIMethod(
                component=self,
                api_method_name=api_fn_name,
                args=util.force_tuple(return_values)
                if type(return_values) != dict else None,
                kwargs=return_values if type(return_values) == dict else None)

            # If we already have actual op(s) and Space(s), push them already into the
            # DataOpRecordColumnFromAPIMethod's records.
            if self.graph_builder is not None and self.graph_builder.phase == "building":
                # Link the returned ops to that new out-column.
                for i, rec in enumerate(out_op_column.op_records):
                    out_op_column.op_records[i].op = rec.op
                    out_op_column.op_records[i].space = rec.space
            # And append the new out-column to the api-method-rec.
            api_method_rec.out_op_columns.append(out_op_column)

            # Do we need to return the raw ops or the op-recs?
            # Default to op-recs; the stack walk below switches to raw ops if the call came through a graph_fn.
            return_ops = False
            stack = inspect.stack()
            f_locals = stack[1][0].f_locals
            # We may be in a list comprehension, try next frame.
            if f_locals.get(".0"):
                f_locals = stack[2][0].f_locals
            # Check whether the caller component is a parent of this one.
            caller_component = f_locals.get(
                "root", f_locals.get("self_", f_locals.get("self")))

            # Potential call from a lambda.
            if caller_component is None and "fn" in stack[2][0].f_locals:
                # This is the component.
                prev_caller_component = TraceContext.PREV_CALLER
                lambda_obj = stack[2][0].f_locals["fn"]
                if "lambda" in inspect.getsource(lambda_obj):
                    # Try to reconstruct caller by using parent of prior caller.
                    caller_component = prev_caller_component.parent_component

            if caller_component is None:
                raise RLGraphError(
                    "API-method '{}' must have as 1st parameter (the component) either `root` or `self`. Other names "
                    "are not allowed!".format(api_method_rec.name))
            # The call is coming from some caller Component, but that component is neither this component
            # nor one of its parents -> Error (unless it is an auto-helper-component API call made from
            # op_records.py, which is exempted below).
            elif caller_component is not None and \
                    type(caller_component).__name__ != "MetaGraphBuilder" and \
                    caller_component not in [self] + self.get_parents():
                if not (stack[1][3] == "__init__"
                        and re.search(r'op_records\.py$', stack[1][1])):
                    raise RLGraphError(
                        "The component '{}' is not a child (or grand-child) of the caller ({})! Maybe you forgot to "
                        "add it as a sub-component via `add_components()`.".
                        format(self.global_scope,
                               caller_component.global_scope))

            # Update trace context.
            TraceContext.PREV_CALLER = caller_component

            for stack_item in stack[1:]:  # skip current frame
                # If we hit an API-method call -> return op-recs.
                if stack_item[3] == "api_method_wrapper" and re.search(
                        r'decorators\.py$', stack_item[1]):
                    break
                # If we hit a graph_fn call -> return ops.
                elif stack_item[3] == "run_through_graph_fn" and re.search(
                        r'graph_builder\.py$', stack_item[1]):
                    return_ops = True
                    break

            if return_ops is True:
                if type(return_values) == dict:
                    return {
                        key: value.op
                        for key, value in out_op_column.get_args_and_kwargs()
                        [1].items()
                    }
                else:
                    tuple_returns = tuple(
                        map(lambda x: x.op,
                            out_op_column.get_args_and_kwargs()[0]))
                    return tuple_returns[0] if len(
                        tuple_returns) == 1 else tuple_returns
            # Parent caller is non-graph_fn: Return op-recs.
            else:
                if type(return_values) == dict:
                    return return_values
                else:
                    tuple_returns = out_op_column.get_args_and_kwargs()[0]
                    return tuple_returns[0] if len(
                        tuple_returns) == 1 else tuple_returns
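The final return dispatch above walks the call stack with `inspect.stack()` to see whether the call came through a graph_fn. The tiny standalone sketch below shows only that stack-walking pattern, detached from RLgraph:

import inspect

def called_from(function_name):
    # Walk the caller frames (skipping the current one) and report whether any of
    # them belongs to a function with the given name.
    return any(frame.function == function_name for frame in inspect.stack()[1:])

def run_through_graph_fn():
    return helper()

def helper():
    return called_from("run_through_graph_fn")

print(helper())                # False: not reached via run_through_graph_fn
print(run_through_graph_fn())  # True: run_through_graph_fn is on the stack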
Example #6
    def _graph_fn_apply(self,
                        key,
                        preprocessing_inputs,
                        input_before_time_rank_folding=None):
        """
        Reshapes the input to the specified new shape.

        Args:
            preprocessing_inputs (SingleDataOp): The input to reshape.
            input_before_time_rank_folding (Optional[SingleDataOp]): The original input from before the time rank
                was folded (this was done by a different ReShape Component). Used, if `self.unfold_time_rank` is
                True, to figure out the exact time-rank dimension to unfold.

        Returns:
            SingleDataOp: The reshaped input.
        """
        assert self.unfold_time_rank is False or input_before_time_rank_folding is not None

        if self.backend == "python" or get_backend() == "python":
            # Create a one-hot axis for the categories at the end?
            num_categories = self.get_num_categories(
                key, get_space_from_op(preprocessing_inputs))
            if num_categories and num_categories > 1:
                preprocessing_inputs = one_hot(preprocessing_inputs,
                                               depth=num_categories)

            if self.unfold_time_rank:
                new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
            elif self.fold_time_rank:
                new_shape = (-1, ) + preprocessing_inputs.shape[2:]
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(preprocessing_inputs.shape
                   ) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            return np.reshape(preprocessing_inputs, newshape=new_shape)

        elif get_backend() == "pytorch":
            # Create a one-hot axis for the categories at the end?
            num_categories = self.get_num_categories(
                key, get_space_from_op(preprocessing_inputs))
            if num_categories and num_categories > 1:
                preprocessing_inputs = pytorch_one_hot(preprocessing_inputs,
                                                       depth=num_categories)

            if self.unfold_time_rank:
                new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
            elif self.fold_time_rank:
                new_shape = (-1, ) + preprocessing_inputs.shape[2:]
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(new_shape
                   ) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            # print("Reshaping input of shape {} to new shape {} (flatten = {})".format(preprocessing_inputs.shape,
            #                                                                           new_shape, self.flatten))

            old_size = np.prod(list(preprocessing_inputs.shape))
            new_size = np.prod(new_shape)

            # The problem here is the following: Input has dim e.g. [4, 256, 1, 1]
            # -> If shape inference in spaces failed, output dim is not correct -> reshape will attempt
            # something like reshaping to [256].
            if self.flatten and preprocessing_inputs.dim() > 1:
                flattened_shape_without_batchrank = np.prod(
                    preprocessing_inputs.shape[1:])
                flattened_shape = (preprocessing_inputs.shape[0], ) + (
                    flattened_shape_without_batchrank, )
                return torch.reshape(preprocessing_inputs, flattened_shape)
            # If new shape does not fit into old shape, batch inference failed -> try to restore:
            # Equal except batch rank -> return as is:
            elif old_size != new_size:
                if tuple(preprocessing_inputs.shape[1:]) == new_shape:
                    return preprocessing_inputs
                else:
                    # Attempt to rescue reshape by combining new shape with batch dim.
                    full_new_shape = (
                        preprocessing_inputs.shape[0], ) + new_shape
                    return torch.reshape(preprocessing_inputs, full_new_shape)
            else:
                return torch.reshape(preprocessing_inputs, new_shape)

        elif get_backend() == "tf":
            # Create a one-hot axis for the categories at the end?
            space = get_space_from_op(preprocessing_inputs)
            num_categories = self.get_num_categories(key, space)
            if num_categories and num_categories > 1:
                preprocessing_inputs_ = tf.one_hot(preprocessing_inputs,
                                                   depth=num_categories,
                                                   axis=-1,
                                                   dtype="float32")
                if hasattr(preprocessing_inputs, "_batch_rank"):
                    preprocessing_inputs_._batch_rank = preprocessing_inputs._batch_rank
                if hasattr(preprocessing_inputs, "_time_rank"):
                    preprocessing_inputs_._time_rank = preprocessing_inputs._time_rank
                preprocessing_inputs = preprocessing_inputs_

            if self.unfold_time_rank:
                list_shape = preprocessing_inputs.shape.as_list()
                assert len(list_shape) == 1 or list_shape[1] is not None,\
                    "ERROR: Cannot unfold. `preprocessing_inputs` (with shape {}) " \
                    "already seems to be unfolded!".format(list_shape)
                new_shape = (-1, -1) + tuple(list_shape[1:])
            elif self.fold_time_rank:
                new_shape = (-1, ) + tuple(
                    preprocessing_inputs.shape.as_list()[2:])
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(new_shape
                   ) >= 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = tf.shape(input_before_time_rank_folding)
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = tf.shape(preprocessing_inputs)
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            reshaped = tf.reshape(tensor=preprocessing_inputs,
                                  shape=new_shape,
                                  name="reshaped")

            # Have to place the time rank back in as unknown (for the auto Space inference).
            if type(self.unfold_time_rank) == int:
                # TODO: replace placeholder with default value by _batch_rank/_time_rank properties.
                return tf.placeholder_with_default(reshaped,
                                                   shape=(None, None) +
                                                   new_shape[2:])
            else:
                # TODO: add other cases of reshaping and fix batch/time rank hints.
                if self.fold_time_rank:
                    reshaped._batch_rank = 0
                elif self.unfold_time_rank:
                    reshaped._batch_rank = 1 if self.time_major is True else 0
                    reshaped._time_rank = 0 if self.time_major is True else 1
                else:
                    if space.has_batch_rank is True:
                        if space.time_major is False:
                            reshaped._batch_rank = 0
                        else:
                            reshaped._time_rank = 0
                            reshaped._batch_rank = 1
                    if space.has_time_rank is True:
                        reshaped._time_rank = 0 if space.time_major is True else 1

                return reshaped
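The fold/unfold handling above amounts to merging or splitting the two leading (batch and time) ranks. The following plain-numpy sketch shows those two reshapes outside of any Component:

import numpy as np

batch, time, feature = 4, 5, 3
x = np.zeros((batch, time, feature))

# Fold: merge batch and time into one leading rank (the fold_time_rank branch).
folded = np.reshape(x, (-1,) + x.shape[2:])
print(folded.shape)  # (20, 3)

# Unfold: restore batch and time from the shape of the original, pre-fold input,
# which is what `input_before_time_rank_folding` is used for above.
unfolded = np.reshape(folded, (batch, time) + folded.shape[1:])
print(unfolded.shape)  # (4, 5, 3)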
Example #7
    def __init__(self, component, num_op_records=None, args=None, kwargs=None):
        """
        Args:
            component (Component): The Component to which this column belongs.
        """
        self.id = self.get_id()

        if num_op_records is None:
            self.op_records = []
            if args is not None:
                args = list(args)
                for i in range(len(args)):
                    if args[i] is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, position=i)

                    # Dict instead of a DataOpRecord -> Translate on the fly into a DataOpRec held by a
                    # ContainerMerger Component.
                    if isinstance(args[i], dict):
                        items = args[i].items()
                        keys = [k for k, _ in items]
                        values = [v for _, v in items]
                        if isinstance(values[0], DataOpRecord):
                            merger_component = values[0].column.component.get_helper_component(
                                "container-merger", _args=list(keys)
                            )
                            args[i] = merger_component.merge(*list(values))
                    # Tuple instead of a DataOpRecord -> Translate on the fly into a DataOpRec held by a
                    # ContainerMerger Component.
                    elif isinstance(args[i], tuple) and isinstance(args[i][0], DataOpRecord):
                        merger_component = args[i][0].column.component.get_helper_component(
                            "container-merger", _args=len(args[i])
                        )
                        args[i] = merger_component.merge(*args[i])

                    # If incoming is an op-rec -> Link them.
                    if isinstance(args[i], DataOpRecord):
                        args[i].connect_to(op_rec)
                    # Do constant value assignment here.
                    elif args[i] is not None:
                        op = args[i]
                        if is_constant(op) and not isinstance(op, np.ndarray):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)

                    self.op_records.append(op_rec)

            if kwargs is not None:
                for key in sorted(kwargs.keys()):
                    value = kwargs[key]
                    if value is None:
                        continue
                    op_rec = DataOpRecord(op=None, column=self, kwarg=key)
                    # If incoming is an op-rec -> Link them.
                    if isinstance(value, DataOpRecord):
                        op_rec.previous = value
                        op_rec.op = value.op  # assign op if any
                        value.next.add(op_rec)
                    # Do constant value assignment here.
                    elif value is not None:
                        op = value
                        if is_constant(op):
                            op = np.array(op, dtype=convert_dtype(type(op), "np"))
                        op_rec.op = op
                        op_rec.space = get_space_from_op(op)
                        component.constant_op_records.add(op_rec)
                    self.op_records.append(op_rec)
        else:
            self.op_records = [DataOpRecord(op=None, column=self, position=i) for i in range(num_op_records)]

        # For __str__ purposes.
        self.op_id_list = [o.id for o in self.op_records]

        # The component this column belongs to.
        self.component = component
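Example #7 differs from Example #1 mainly in two ways: dict/tuple args are routed through a helper container-merger component, and linking is delegated to `DataOpRecord.connect_to()` instead of setting `previous`/`next`/`op` by hand. The minimal stand-in class below only illustrates what such a `connect_to` would have to do; it is not the library's implementation.

class MiniOpRecord:
    # Minimal stand-in for DataOpRecord, just to illustrate the linking pattern.
    def __init__(self, op=None):
        self.op = op
        self.previous = None
        self.next = set()

    def connect_to(self, other):
        # Link self -> other and propagate an already-known op, mirroring the
        # manual previous/next/op assignments in Example #1.
        other.previous = self
        if self.op is not None:
            other.op = self.op
        self.next.add(other)

a, b = MiniOpRecord(op="some_op"), MiniOpRecord()
a.connect_to(b)
print(b.previous is a, b.op)  # True some_op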