예제 #1
0
파일: mt_end_model.py 프로젝트: seast/metal
    def __init__(self, task_graph=None, input_module=None, seed=None,
                 **kwargs):
        """Construct the multi-task end model and build its network.

        Args:
            task_graph: Task structure exposing ``K_t`` (cardinalities by
                task) and ``T`` (total number of tasks). When None, a
                TaskHierarchy with no edges and a single task of cardinality
                2 (one binary task) is used.
            input_module: Passed to ``self._build()``. When None, an
                IdentityModule sized by ``config['layer_output_dims'][0]``
                is used.
            seed: Forwarded to ``MTClassifier.__init__``.
            **kwargs: Config overrides, merged on top of the combined
                em_default_config / mt_em_default_config defaults.
        """
        base_config = recursive_merge_dicts(
            em_default_config, mt_em_default_config, misses='insert')
        self.config = recursive_merge_dicts(base_config, kwargs)

        # Without an explicit graph, model a single binary task.
        self.task_graph = (
            TaskHierarchy(edges=[], cardinalities=[2])
            if task_graph is None else task_graph
        )
        graph = self.task_graph
        self.K_t = graph.K_t  # cardinalities by task
        self.T = graph.T  # total number of tasks

        MTClassifier.__init__(self, cardinalities=self.K_t, seed=seed)

        if input_module is None:
            first_layer_dim = self.config['layer_output_dims'][0]
            input_module = IdentityModule(first_layer_dim)
        self._build(input_module)

        if not self.config['verbose']:
            return
        # Show the assembled network to the user.
        print("\nNetwork architecture:")
        self._print()
        print()
예제 #2
0
    def __init__(
        self,
        layer_out_dims,
        input_modules=None,
        middle_modules=None,
        head_modules=None,
        K=None,
        task_graph=None,
        **kwargs,
    ):
        """Construct the multi-task end model and build its network.

        Args:
            layer_out_dims: Output dimensions of the network layers; stored
                into the config under the "layer_out_dims" key.
            input_modules: Passed through to ``self._build()``.
            middle_modules: Passed through to ``self._build()``.
            head_modules: Passed through to ``self._build()``.
            K: A t-length list of task cardinalities. Required (and must be
                non-empty) when ``task_graph`` is not given; ignored for
                ``self.K`` when ``task_graph`` is given.
            task_graph: A TaskGraph; when provided, ``self.K`` and ``self.t``
                are taken from it instead of from ``K``.
            **kwargs: Config overrides, merged recursively on top of the
                em_default_config / mt_em_default_config defaults.

        Raises:
            ValueError: If ``task_graph`` is None and ``K`` is None or empty.
        """
        kwargs["layer_out_dims"] = layer_out_dims
        config = recursive_merge_dicts(em_default_config,
                                       mt_em_default_config,
                                       misses="insert")
        config = recursive_merge_dicts(config, kwargs)
        MTClassifier.__init__(self, K, config)

        if task_graph is None:
            # With no graph, K is the only source of task structure; a
            # missing (None) or empty list gives nothing to build from.
            if not K:
                raise ValueError("You must supply either a list of "
                                 "cardinalities (K) or a TaskGraph.")
            task_graph = TaskGraph(K)
        self.task_graph = task_graph
        self.K = self.task_graph.K  # Cardinalities by task
        self.t = self.task_graph.t  # Total number of tasks
        assert len(self.K) == self.t

        self._build(input_modules, middle_modules, head_modules)

        # Show network (self.config is presumably set by
        # MTClassifier.__init__ above — it is not assigned in this method).
        if self.config["verbose"]:
            print("\nNetwork architecture:")
            self._print()
            print()
예제 #3
0
    def __init__(self, K=None, task_graph=None, **kwargs):
        """Merge config defaults, initialize the MTClassifier base, and
        attach the task graph.

        Args:
            K: A t-length list of task cardinalities (overridden by task_graph
                if task_graph is not None)
            task_graph: TaskGraph: A TaskGraph which defines a feasible set of
                task label vectors; overrides K if provided
            **kwargs: Config overrides, merged recursively on top of
                lm_default_config.
        """
        # User overrides from kwargs are layered on top of the defaults.
        config = recursive_merge_dicts(lm_default_config, kwargs)
        MTClassifier.__init__(self, K, config)

        # Build a TaskGraph from K only when the caller did not supply one.
        # NOTE(review): if both K and task_graph are omitted this calls
        # TaskGraph(None); confirm TaskGraph handles None deliberately
        # (e.g. a default single task) rather than failing later.
        if task_graph is None:
            task_graph = TaskGraph(K)
        self.task_graph = task_graph

        # Note: While K is a list of the cardinalities of the tasks, k is the
        # cardinality of the feasible set. These are always the same for a
        # single-task model, but rarely the same for a multi-task model.
        self.k = self.task_graph.k