Ejemplo n.º 1
0
 def test_multitask_two_modules(self):
     """Accept a different input representation for each task.

     Each split's X is divided column-wise into a two-element list (one
     entry per task), and input_modules supplies one IdentityModule per
     task. The trained model must score above 0.95 on the third split.
     """
     edges = []
     cards = [2, 2]
     tg = TaskGraph(cards, edges)
     em = MTEndModel(
         layer_out_dims=[2, 8, 4],
         task_graph=tg,
         seed=1,
         verbose=False,
         input_modules=[IdentityModule(),
                        IdentityModule()],
         task_head_layers="top",
     )
     # Split every data split's X into per-task column inputs.
     Xs = [[X[:, 0], X[:, 1]] for X in self.Xs]
     em.train_model(
         (Xs[0], self.Ys[0]),
         valid_data=(Xs[1], self.Ys[1]),
         verbose=False,
         n_epochs=10,
         checkpoint=False,
     )
     score = em.score((Xs[2], self.Ys[2]), reduce="mean", verbose=False)
     self.assertGreater(score, 0.95)
Ejemplo n.º 2
0
 def test_multitask_top(self):
     """With task_head_layers="top", every task head sits on the last layer.

     Also checks that the resulting model trains to a mean score above 0.95
     on the held-out split.
     """
     task_graph = TaskGraph([2, 2], [])
     model = MTEndModel(
         layer_out_dims=[2, 8, 4],
         task_graph=task_graph,
         seed=1,
         verbose=False,
         task_head_layers="top",
     )

     # Index of the final layer; all model.t task heads should map there.
     last_layer_idx = len(model.config["layer_out_dims"]) - 1
     self.assertEqual(len(model.task_map[last_layer_idx]), model.t)

     model.train_model(
         (self.Xs[0], self.Ys[0]),
         valid_data=(self.Xs[1], self.Ys[1]),
         verbose=False,
         n_epochs=10,
         checkpoint=False,
     )
     test_score = model.score(
         (self.Xs[2], self.Ys[2]), reduce="mean", verbose=False
     )
     self.assertGreater(test_score, 0.95)
 def test_binary_tree(self):
     """A root (task 0) with two leaf children yields the expected links."""
     cardinalities = [2, 2, 2]
     edges = [(0, 1), (0, 2)]
     tg = TaskGraph(cardinalities, edges)
     # assertEqual (unlike assertTrue(a == b)) shows both values on failure.
     self.assertEqual(tg.parents[0], [])
     self.assertEqual(tg.parents[1], [0])
     self.assertEqual(tg.parents[2], [0])
     self.assertEqual(tg.children[0], [1, 2])
     self.assertEqual(tg.children[1], [])
     self.assertEqual(tg.children[2], [])
Ejemplo n.º 4
0
    def __init__(self, K=None, task_graph=None, **kwargs):
        """
        Args:
            K: A t-length list of task cardinalities. It is always passed to
                MTClassifier.__init__, but it is only used to build the task
                graph when task_graph is None.
            task_graph: TaskGraph: A TaskGraph which defines a feasible set of
                task label vectors; overrides K if provided. When None, a
                TaskGraph is built from K.
            **kwargs: Config overrides merged over lm_default_config.
        """
        merged_config = recursive_merge_dicts(lm_default_config, kwargs)
        MTClassifier.__init__(self, K, merged_config)

        # Without an explicit task graph, derive one from K.
        self.task_graph = (
            TaskGraph(K) if task_graph is None else task_graph
        )

        # K lists the cardinality of each task, while k is the cardinality of
        # the feasible set of label vectors. They coincide for a single task
        # but rarely do for a multi-task model.
        self.k = self.task_graph.k
Ejemplo n.º 5
0
    def __init__(
        self,
        layer_out_dims,
        input_modules=None,
        middle_modules=None,
        head_modules=None,
        K=None,
        task_graph=None,
        **kwargs,
    ):
        """Configure and build a multi-task end model.

        Args:
            layer_out_dims: List of layer output dimensions (presumably one
                per layer); stored in the config under "layer_out_dims".
            input_modules: Optional input module(s), forwarded to self._build.
            middle_modules: Optional middle module(s), forwarded to
                self._build.
            head_modules: Optional head module(s), forwarded to self._build.
            K: Optional list of per-task cardinalities. It is only used to
                build a TaskGraph when task_graph is None. Omitting it is
                equivalent to passing an empty list.
            task_graph: Optional TaskGraph. self.K and self.t are always
                taken from the task graph actually used.
            **kwargs: Config overrides, merged over the combined default
                end-model and multi-task end-model configs.

        Raises:
            ValueError: If task_graph is None and K is empty or omitted.
        """
        # A None default avoids sharing one mutable list across calls.
        # MTClassifier still receives a (fresh) empty list when K is omitted.
        if K is None:
            K = []

        kwargs["layer_out_dims"] = layer_out_dims

        # Start from the end-model defaults, insert the multi-task-specific
        # defaults, then apply the caller's overrides on top.
        config = recursive_merge_dicts(em_default_config,
                                       mt_em_default_config,
                                       misses="insert")
        config = recursive_merge_dicts(config, kwargs)
        MTClassifier.__init__(self, K, config)

        if task_graph is None:
            if len(K) == 0:
                raise ValueError("You must supply either a list of "
                                 "cardinalities (K) or a TaskGraph.")
            task_graph = TaskGraph(K)
        self.task_graph = task_graph
        self.K = self.task_graph.K  # Cardinalities by task
        self.t = self.task_graph.t  # Total number of tasks
        assert len(self.K) == self.t

        self._build(input_modules, middle_modules, head_modules)

        # Show network
        if self.config["verbose"]:
            print("\nNetwork architecture:")
            self._print()
            print()
Ejemplo n.º 6
0
 def test_multitask_custom_heads(self):
     """Accept user-supplied head modules attached at different layers.

     task_head_layers=[1, 2] places the heads on layers 1 and 2, whose output
     widths are 8 and 4 per layer_out_dims. Each nn.Linear's input size
     matches one of those widths (presumably the i-th head serves task i).
     The trained model must score above 0.95 on the held-out split.
     """
     edges = []
     cards = [2, 2]
     tg = TaskGraph(cards, edges)
     em = MTEndModel(
         layer_out_dims=[2, 8, 4],
         task_graph=tg,
         seed=1,
         verbose=False,
         head_modules=[nn.Linear(8, 2), nn.Linear(4, 2)],
         task_head_layers=[1, 2],
     )
     em.train_model(
         (self.Xs[0], self.Ys[0]),
         valid_data=(self.Xs[1], self.Ys[1]),
         verbose=False,
         n_epochs=10,
         checkpoint=False,
     )
     score = em.score((self.Xs[2], self.Ys[2]),
                      reduce="mean",
                      verbose=False)
     self.assertGreater(score, 0.95)
 def test_unbalanced_tree(self):
     """A chain 0 -> 1 -> 2 gives task 1 parent [0] and child [2]."""
     cardinalities = [2, 2, 2]
     edges = [(0, 1), (1, 2)]
     tg = TaskGraph(cardinalities, edges)
     # assertEqual (unlike assertTrue(a == b)) shows both values on failure.
     self.assertEqual(tg.parents[1], [0])
     self.assertEqual(tg.children[1], [2])
 def test_binary_tree_depth3(self):
     """In a depth-3 binary tree, task 1 has parent [0] and children [3, 4]."""
     cardinalities = [2, 2, 2, 2, 2, 2, 2]
     edges = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6)]
     tg = TaskGraph(cardinalities, edges)
     # assertEqual (unlike assertTrue(a == b)) shows both values on failure.
     self.assertEqual(tg.parents[1], [0])
     self.assertEqual(tg.children[1], [3, 4])