コード例 #1
0
    def _get_template_values(self, replicator):
        """Returns the values which will be passed to the replicated
        protocols, evaluating any protocol paths to retrieve the referenced
        values.

        Parameters
        ----------
        replicator: ProtocolReplicator
            The replicator which is replicating the protocols.

        Returns
        -------
        Any
            The template values. When the replicator's template values are a
            single global ``ProtocolPath``, this is the result of looking that
            path's property name up in the global metadata. When they are a
            list, this is a new list in which every ``ProtocolPath`` entry has
            been replaced by the value looked up for it, and every other entry
            is kept as-is.

        Raises
        ------
        ValueError
            If a ``ProtocolPath`` template value does not point to the
            global scope.
        NotImplementedError
            If the template values are neither a ``ProtocolPath`` nor a list.
        """

        def resolve(value):
            # Constants are passed through untouched; only protocol paths
            # need to be looked up in the global metadata.
            if not isinstance(value, ProtocolPath):
                return value

            if not value.is_global:
                # The message is only built when it is actually needed, and it
                # names the offending path rather than every template value.
                raise ValueError(
                    f"Template values must either be a constant or come "
                    f"from the global scope (and not from {value})"
                )

            return get_nested_attribute(self._global_metadata, value.property_name)

        template_values = replicator.template_values

        # A single path stands in for the whole collection of values.
        if isinstance(template_values, ProtocolPath):
            return resolve(template_values)

        if not isinstance(template_values, list):
            raise NotImplementedError(
                f"Template values must either be a ProtocolPath or a list, "
                f"not a {type(template_values).__name__}."
            )

        return [resolve(value) for value in template_values]
コード例 #2
0
    def _build_protocols(self, schema):
        """Creates a set of protocols based on a WorkflowSchema.

        Parameters
        ----------
        schema: WorkflowSchema
            The schema to use when creating the protocols
        """
        self._apply_replicators(schema)

        for protocol_schema in schema.protocol_schemas:

            protocol = protocol_schema.to_protocol()

            # Try to set global properties on each of the protocols
            for input_path in protocol.required_inputs:

                value_references = protocol.get_value_references(input_path)

                for source_path, value_reference in value_references.items():

                    if not value_reference.is_global:
                        continue

                    value = get_nested_attribute(self._global_metadata,
                                                 value_reference.property_name)
                    protocol.set_value(source_path, value)

            protocol.set_uuid(self.uuid)
            self._protocols.append(protocol)
コード例 #3
0
    def _build_output_to_store(self, output_to_store):
        """Replaces those storage attributes of a `BaseStoredData` object
        which are global protocol paths with the values they reference in
        the global metadata.

        Attributes which are not a `ProtocolPath`, or which are a path that
        is not global, are left unchanged. The object is modified in place.

        Parameters
        ----------
        output_to_store: BaseStoredData
            The output to set the inputs of.

        Returns
        -------
        BaseStoredData
            The same object which was passed in, with its global inputs set.
        """

        for name in output_to_store.get_attributes(StorageAttribute):

            current_value = getattr(output_to_store, name)

            if isinstance(current_value, ProtocolPath) and current_value.is_global:

                resolved_value = get_nested_attribute(
                    self._global_metadata, current_value.property_name
                )
                setattr(output_to_store, name, resolved_value)

        return output_to_store
コード例 #4
0
def test_get_nested_attribute():
    """Checks that `get_nested_attribute` can follow paths made up of plain
    attribute names, dictionary keys and list indices."""

    def make_dummy(object_a, object_b):
        # Small factory so the nested structure below reads as a tree.
        dummy = DummyNestedClass()
        dummy.object_a = object_a
        dummy.object_b = object_b
        return dummy

    list_entry = make_dummy("a", "b")

    root = make_dummy(
        "a",
        {
            "a": make_dummy(1, [0]),
            "b": make_dummy(2, [list_entry]),
        },
    )

    expected_by_path = {
        "object_a": "a",
        "object_b[a].object_a": 1,
        "object_b[a].object_b[0]": 0,
        "object_b[b].object_a": 2,
        "object_b[b].object_b[0].object_a": "a",
        "object_b[b].object_b[0].object_b": "b",
    }

    for attribute_path, expected in expected_by_path.items():
        assert get_nested_attribute(root, attribute_path) == expected
コード例 #5
0
    def _store_output_data(
        data_object_path,
        data_directory,
        output_to_store,
        results_by_id,
    ):
        """Resolves the storage attributes of a data object, copies any files
        they refer to into a directory, and serializes the object to JSON.

        Each storage attribute whose value is a `ProtocolPath` is replaced by
        the corresponding entry of `results_by_id`. Attributes whose type hint
        is a `FilePath` have their file copied into `data_directory` and are
        replaced by the base name of that file. The object is modified in
        place before being written to `data_object_path`.

        Parameters
        ----------
        data_object_path: str
            The file path to serialize the data object to.
        data_directory: str
            The path of the directory to store ancillary data in.
        output_to_store: BaseStoredData
            An object which contains `ProtocolPath`s pointing to the
            data to store.
        results_by_id: dict of ProtocolPath and any
            The results of the protocols which formed the property
            estimation workflow.
        """

        makedirs(data_directory, exist_ok=True)

        stored_type = output_to_store.__class__

        for name in output_to_store.get_attributes(StorageAttribute):

            descriptor = getattr(stored_type, name)
            value = getattr(output_to_store, name)

            if isinstance(value, ProtocolPath):

                # The results are keyed by the bare property name, so drop any
                # trailing `.attribute` or `[index]` accessors before the lookup.
                base_name = value.property_name.partition(".")[0].partition("[")[0]
                base_path = ProtocolPath(base_name, *value.protocol_ids)

                resolved = results_by_id[base_path]

                if base_path != value:
                    # The path pointed inside of the result, so walk down to
                    # the nested value which was actually requested.
                    resolved = get_nested_attribute(
                        {base_name: resolved}, value.property_name
                    )

                value = resolved

                # Gradient information is very workflow / context specific, so
                # it is not stored alongside observables.
                observable_types = (Observable, ObservableArray, ObservableFrame)

                if isinstance(value, observable_types):
                    value.clear_gradients()

            if issubclass(descriptor.type_hint, FilePath):
                # Store only the file name; the file itself now lives in the
                # data directory.
                file_copy(value, data_directory)
                value = path.basename(value)

            setattr(output_to_store, name, value)

        with open(data_object_path, "w") as file:
            json.dump(output_to_store, file, cls=TypedJSONEncoder)