示例#1
0
    def _build(self, input_module, middle_modules, head_module):
        """Assemble the network and the loss module for this model.

        The input layer, any middle layers and (unless the ``skip_head``
        config option is set) the task head are collected in order and stored
        on ``self.network``. A ``SoftCrossEntropyLoss`` configured from
        ``train_config`` is stored on ``self.criteria``.

        Args:
            input_module: forwarded to ``self._build_input_layer``.
            middle_modules: forwarded to ``self._build_middle_layers``.
            head_module: forwarded to ``self._build_task_head``; ignored when
                ``skip_head`` is set.
        """
        # Gather the pieces of the network in execution order
        stack = [self._build_input_layer(input_module)]
        hidden = self._build_middle_layers(middle_modules)
        if hidden is not None:
            stack.extend(hidden)
        if not self.config["skip_head"]:
            stack.append(self._build_task_head(head_module))

        # A lone layer is used as-is rather than wrapped in a Sequential
        self.network = stack[0] if len(stack) == 1 else nn.Sequential(*stack)

        # Build the loss from the training configuration
        train_config = self.config["train_config"]
        class_weights = train_config["loss_weights"]
        if class_weights is not None and self.config["verbose"]:
            print(f"Using class weight vector {class_weights}...")
        weight_tensor = self._to_torch(class_weights, dtype=torch.FloatTensor)
        self.criteria = SoftCrossEntropyLoss(
            weight=weight_tensor,
            reduction=train_config["loss_fn_reduction"],
        )
示例#2
0
    def test_sce_equals_ce(self):
        """SoftCrossEntropyLoss must agree with nn.CrossEntropyLoss.

        Both losses are evaluated on the same random scores for each
        reduction mode ("none", "sum", "mean"), ten draws per mode.
        """
        Y = torch.tensor([1, 2, 3], dtype=torch.long)
        Y_s = pred_to_prob(Y, k=4).float()
        # nn.CrossEntropyLoss is given the labels shifted down by one
        shifted = Y - 1

        def draw_scores():
            # Random non-negative scores, each row rescaled to sum to 1
            raw = torch.rand_like(Y_s)
            return raw / raw.sum(dim=1).reshape(-1, 1)

        # Unreduced losses are compared elementwise for exact equality
        soft_loss = SoftCrossEntropyLoss(reduction="none")
        hard_loss = nn.CrossEntropyLoss(reduction="none")
        for _ in range(10):
            scores = draw_scores()
            same = soft_loss(scores, Y_s) == hard_loss(scores, shifted)
            self.assertTrue(same.all())

        # Reduced losses are compared to five decimal places
        for mode in ("sum", "mean"):
            soft_loss = SoftCrossEntropyLoss(reduction=mode)
            hard_loss = nn.CrossEntropyLoss(reduction=mode)
            for _ in range(10):
                scores = draw_scores()
                self.assertAlmostEqual(
                    soft_loss(scores, Y_s).numpy(),
                    hard_loss(scores, shifted).numpy(),
                    places=5,
                )
示例#3
0
    def _build(self, input_modules, middle_modules, head_modules):
        """Create the network components and the loss module.

        The input layer, middle layers and task heads are produced by the
        matching ``_build_*`` helpers (in that order) and stored on the
        instance; the loss is a sum-reduced ``SoftCrossEntropyLoss``.
        """
        # (attribute name, builder, user-supplied modules), in build order
        components = (
            ("input_layer", self._build_input_layer, input_modules),
            ("middle_layers", self._build_middle_layers, middle_modules),
            ("heads", self._build_task_heads, head_modules),
        )
        for attr_name, builder, modules in components:
            setattr(self, attr_name, builder(modules))

        # Loss shared by all task heads
        self.criteria = SoftCrossEntropyLoss(reduction="sum")
示例#4
0
    def _build(self, input_modules, middle_modules, head_modules):
        """Create the network components and the loss module.

        The input layer, middle layers and task heads come from the matching
        ``_build_*`` helpers. The loss is a ``SoftCrossEntropyLoss`` whose
        reduction mode is read from ``train_config["loss_fn_reduction"]``.
        """
        # Network components, built in input -> middle -> heads order
        self.input_layer = self._build_input_layer(input_modules)
        self.middle_layers = self._build_middle_layers(middle_modules)
        self.heads = self._build_task_heads(head_modules)

        # Loss module, configured from the training settings
        train_config = self.config["train_config"]
        self.criteria = SoftCrossEntropyLoss(
            reduction=train_config["loss_fn_reduction"])
示例#5
0
 def test_prob_labels(self):
     """Check loss ordering and symmetry of SoftCrossEntropyLoss on soft targets."""
     targets = torch.tensor([[0.1, 0.9], [0.5, 0.5]])
     # Three score matrices that differ only in the entries compared below
     scores_low = torch.tensor([[0.1, 0.2], [1.0, 0.0]])
     scores_high = torch.tensor([[0.1, 0.3], [1.0, 0.0]])
     scores_flipped = torch.tensor([[0.1, 0.3], [0.0, 1.0]])
     sce = SoftCrossEntropyLoss()

     # Raising the second score of row 0 must lower the loss
     self.assertLess(sce(scores_high, targets), sce(scores_low, targets))
     # Swapping the scores of the 50/50 row must leave the loss unchanged
     loss_high = sce(scores_high, targets)
     loss_flipped = sce(scores_flipped, targets)
     self.assertEqual(loss_high, loss_flipped)
示例#6
0
    def test_loss_weights(self):
        """Class weights must match nn.CrossEntropyLoss and scale linearly."""
        Y = torch.tensor([1, 1, 2], dtype=torch.long)
        Y_s = pred_to_prob(Y, k=3)
        # Every row puts an overwhelming score on column 1, while the shifted
        # targets passed to nn.CrossEntropyLoss below are [0, 0, 1]
        confident_row = [-100.0, 100.0, -100.0]
        Y_ps = torch.tensor([confident_row, confident_row, confident_row])

        base_weight = torch.tensor([1, 2, 1], dtype=torch.float)
        tenfold_weight = torch.tensor([10, 20, 10], dtype=torch.float)
        hard_loss = nn.CrossEntropyLoss(weight=base_weight, reduction="none")
        soft_loss = SoftCrossEntropyLoss(weight=base_weight)
        soft_loss_x10 = SoftCrossEntropyLoss(weight=tenfold_weight)

        # Weighted soft loss equals the mean of the weighted hard losses
        expected = float(hard_loss(Y_ps, Y - 1).mean())
        self.assertAlmostEqual(expected, float(soft_loss(Y_ps, Y_s)), places=3)

        # Multiplying every weight by 10 multiplies the loss by 10
        scaled_up = float(soft_loss(Y_ps, Y_s)) * 10
        self.assertAlmostEqual(scaled_up,
                               float(soft_loss_x10(Y_ps, Y_s)),
                               places=3)
示例#7
0
    def test_perfect_predictions(self):
        """Near-certain scores on the labeled entries should give ~zero loss."""
        Y = torch.tensor([1, 2, 3], dtype=torch.long)
        Y_s = pred_to_prob(Y, k=4)
        sce = SoftCrossEntropyLoss()

        # Turn the label matrix into extreme scores: +100 wherever it holds a
        # 1 and -100 wherever it holds a 0
        Y_ps = Y_s.clone().float()
        ones_mask = Y_ps == 1
        zeros_mask = Y_ps == 0
        Y_ps[ones_mask] = 100
        Y_ps[zeros_mask] = -100

        loss = sce(Y_ps, Y_s)
        self.assertAlmostEqual(loss.numpy(), 0)
示例#8
0
    def test_perfect_predictions(self):
        """Near-certain scores on the labeled entries should give ~zero loss."""
        Y_h = torch.tensor([1, 2, 3], dtype=torch.long)
        Y = hard_to_soft(Y_h, k=4)
        sce = SoftCrossEntropyLoss()

        # Turn the soft-label matrix into extreme scores: +100 wherever it
        # holds a 1 and -100 wherever it holds a 0
        Y_p = Y.clone()
        ones_mask = Y_p == 1
        zeros_mask = Y_p == 0
        Y_p[ones_mask] = 100
        Y_p[zeros_mask] = -100

        loss = sce(Y_p, Y)
        self.assertAlmostEqual(loss.numpy(), 0)
示例#9
0
    def test_loss_weights(self):
        """Class weights must match nn.CrossEntropyLoss and scale linearly."""
        Y_h = torch.tensor([1,1,2], dtype=torch.long)
        target = Y_h
        K_t = 3
        Y = hard_to_soft(Y_h, k=K_t)
        # Every row puts an overwhelming score on column 2, while the hard
        # targets are [1, 1, 2]
        confident_row = [0., -100.,  100., -100.]
        Y_p = torch.tensor([confident_row, confident_row, confident_row])

        base_weight = torch.tensor([0,1,2,1], dtype=torch.float)
        tenfold_weight = torch.tensor([0,10,20,10], dtype=torch.float)
        hard_loss = nn.CrossEntropyLoss(weight=base_weight, reduction='none')
        soft_loss = SoftCrossEntropyLoss(weight=base_weight)
        soft_loss_x10 = SoftCrossEntropyLoss(weight=tenfold_weight)

        # Weighted soft loss equals the mean of the weighted hard losses
        expected = float(hard_loss(Y_p, target).mean())
        self.assertAlmostEqual(expected, float(soft_loss(Y_p, Y)), places=3)

        # Multiplying every weight by 10 multiplies the loss by 10
        scaled_up = float(soft_loss(Y_p, Y)) * 10
        self.assertAlmostEqual(
            scaled_up, float(soft_loss_x10(Y_p, Y)), places=3)
示例#10
0
    def _build(self, input_module, middle_modules, head_module):
        """Assemble the network and the loss module for this model.

        The input layer, any middle layers and the task head are chained into
        an ``nn.Sequential`` stored on ``self.network``; a sum-reduced
        ``SoftCrossEntropyLoss`` is stored on ``self.criteria``.
        """
        first = self._build_input_layer(input_module)
        hidden = self._build_middle_layers(middle_modules)
        last = self._build_task_head(head_module)

        # Absent middle layers simply contribute nothing to the chain
        between = [] if hidden is None else list(hidden)
        self.network = nn.Sequential(first, *between, last)

        # Loss used during training
        self.criteria = SoftCrossEntropyLoss(reduction="sum")
示例#11
0
    def test_sce_equals_ce(self):
        """SoftCrossEntropyLoss must agree with nn.CrossEntropyLoss.

        Both losses are evaluated on the same random logits for each
        reduction mode, with fresh logits drawn on every iteration.
        """
        Y_h = torch.tensor([1, 2, 3], dtype=torch.long)
        Y = hard_to_soft(Y_h, k=4)

        # Unreduced losses are compared elementwise for exact equality
        sce = SoftCrossEntropyLoss(reduction='none')
        ce = nn.CrossEntropyLoss(reduction='none')
        for _ in range(10):
            Y_p = torch.randn(Y.shape)
            self.assertTrue((sce(Y_p, Y) == ce(Y_p, Y_h)).all())

        # Reduced losses are compared to five decimal places.
        # 'elementwise_mean' was marked as the default in the original test.
        # NOTE(review): newer PyTorch releases spell this reduction 'mean' --
        # confirm against the PyTorch version this project pins.
        for reduction in ('sum', 'elementwise_mean'):
            sce = SoftCrossEntropyLoss(reduction=reduction)
            ce = nn.CrossEntropyLoss(reduction=reduction)
            for _ in range(10):
                # Draw new logits each time; reusing the last tensor from the
                # previous loop would repeat one identical check ten times
                Y_p = torch.randn(Y.shape)
                self.assertAlmostEqual(
                    sce(Y_p, Y).numpy(), ce(Y_p, Y_h).numpy(), places=5)
示例#12
0
    def _build(self, input_module, middle_modules, head_module):
        """Assemble the network and the loss module for this model.

        The input layer, any middle layers and (unless the ``skip_head``
        config option is set) the task head are collected in order and stored
        on ``self.network``; a sum-reduced ``SoftCrossEntropyLoss`` is stored
        on ``self.criteria``.

        Args:
            input_module: forwarded to ``self._build_input_layer``.
            middle_modules: forwarded to ``self._build_middle_layers``.
            head_module: forwarded to ``self._build_task_head``; ignored when
                ``skip_head`` is set.
        """
        input_layer = self._build_input_layer(input_module)
        middle_layers = self._build_middle_layers(middle_modules)

        # Construct list of layers
        layers = [input_layer]
        if middle_layers is not None:
            layers += middle_layers
        if not self.config["skip_head"]:
            # Only build the head when it is actually used, so a headless
            # model neither allocates nor depends on a task head
            layers.append(self._build_task_head(head_module))

        # Construct network; a lone layer is used as-is rather than wrapped
        if len(layers) > 1:
            self.network = nn.Sequential(*layers)
        else:
            self.network = layers[0]

        # Construct loss module
        self.criteria = SoftCrossEntropyLoss(reduction="sum")
示例#13
0
 def test_soft_labels(self):
     """Check loss ordering and symmetry of SoftCrossEntropyLoss on soft targets."""
     targets = torch.tensor([[0.1, 0.9], [0.5, 0.5]])
     # Three score matrices that differ only in the entries compared below
     scores_low = torch.tensor([[0.1, 0.2], [1.0, 0.0]])
     scores_high = torch.tensor([[0.1, 0.3], [1.0, 0.0]])
     scores_flipped = torch.tensor([[0.1, 0.3], [0.0, 1.0]])
     sce = SoftCrossEntropyLoss()

     # Raising the second score of row 0 must lower the loss
     self.assertLess(sce(scores_high, targets), sce(scores_low, targets))
     # Swapping the scores of the 50/50 row must leave the loss unchanged
     loss_high = sce(scores_high, targets)
     loss_flipped = sce(scores_flipped, targets)
     self.assertEqual(loss_high, loss_flipped)
示例#14
0
文件: end_model.py 项目: seast/metal
    def _build(self, input_module):
        """Construct the layer stack, attach the task heads and create the loss.

        The number of layers is taken from ``config['layer_output_dims']``.
        Layer 0 wraps ``input_module``; every later layer ``i`` starts with an
        ``nn.Linear`` from ``layer_dims[i-1]`` to ``layer_dims[i]``. Each
        layer is optionally followed by ReLU, BatchNorm1d and Dropout, and is
        registered on ``self.layers`` as ``layer{i}``.

        Args:
            input_module: module used as layer 0; its ``get_output_dim()``
                must equal the first entry of ``layer_output_dims``.

        Raises:
            ValueError: if the input module's output size differs from the
                first layer output size.
        """
        # The number of layers is inferred from the specified layer_output_dims
        layer_dims = self.config['layer_output_dims']
        num_layers = len(layer_dims)

        if not input_module.get_output_dim() == layer_dims[0]:
            msg = (f"Input module output size != the first layer output size: "
                f"({input_module.get_output_dim()} != {layer_dims[0]})")
            raise ValueError(msg)

        # Set dropout probabilities for all layers: a list is used as given,
        # any other value is broadcast so `dropouts` is always defined
        dropout = self.config['dropout']
        if isinstance(dropout, list):
            dropouts = dropout
        else:
            dropouts = [dropout] * num_layers

        # Construct layers
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Input module or Linear
            if i == 0:
                core = input_module
            else:
                core = nn.Linear(*layer_dims[i-1:i+1])
            layer = [core]
            # Skip the non-linearity only for an identity layer itself; an
            # identity input module must not strip ReLU from the Linear
            # layers that follow it
            if not isinstance(core, IdentityModule):
                layer.append(nn.ReLU())
            if self.config['batchnorm']:
                layer.append(nn.BatchNorm1d(layer_dims[i]))
            if dropout:
                layer.append(nn.Dropout(dropouts[i]))
            self.layers.add_module(f'layer{i}', nn.Sequential(*layer))

        self._attach_task_heads(num_layers)

        # Construct loss module
        self.criteria = SoftCrossEntropyLoss()
示例#15
0
class EndModel(Classifier):
    """A dynamically constructed discriminative classifier.

    Args:
        layer_out_dims: a list of integers corresponding to the output sizes
            of the layers of your network. The first element is the
            dimensionality of the input layer, the last element is the
            dimensionality of the head layer (equal to the cardinality of the
            task), and all other elements dictate the sizes of middle layers.
            The number of middle layers will be inferred from this list.
        input_module: (nn.Module) a module that converts the user-provided
            model inputs to torch.Tensors. Defaults to IdentityModule.
        middle_modules: (nn.Module) a list of modules to execute between the
            input_module and task head. Defaults to nn.Linear.
        head_module: (nn.Module) a module to execute right before the final
            softmax that outputs a prediction for the task.
        **kwargs: configuration overrides, merged over ``em_default_config``.

    Raises:
        ValueError: if ``layer_out_dims`` has fewer than two elements and
            ``skip_head`` was not requested.
    """
    def __init__(
        self,
        layer_out_dims,
        input_module=None,
        middle_modules=None,
        head_module=None,
        **kwargs,
    ):

        # .get() keeps a missing "skip_head" kwarg from surfacing as a
        # KeyError instead of the descriptive ValueError below
        if len(layer_out_dims) < 2 and not kwargs.get("skip_head", False):
            raise ValueError(
                "Arg layer_out_dims must have at least two "
                "elements corresponding to the output dim of the input module "
                "and the cardinality of the task. If the input module is the "
                "IdentityModule, then the output dim of the input module will "
                "be equal to the dimensionality of your input data points")

        # Add layer_out_dims to kwargs so it will be merged into the config dict
        kwargs["layer_out_dims"] = layer_out_dims
        config = recursive_merge_dicts(em_default_config,
                                       kwargs,
                                       misses="insert")
        super().__init__(k=layer_out_dims[-1], config=config)

        self._build(input_module, middle_modules, head_module)

        # Show network
        if self.config["verbose"]:
            print("\nNetwork architecture:")
            self._print()
            print()

    def _build(self, input_module, middle_modules, head_module):
        """Assemble ``self.network`` and the loss module ``self.criteria``.

        The input layer, any middle layers and (unless the ``skip_head``
        config option is set) the task head are collected in order. The loss
        is a ``SoftCrossEntropyLoss`` configured from ``train_config``.
        """
        input_layer = self._build_input_layer(input_module)
        middle_layers = self._build_middle_layers(middle_modules)

        # Construct list of layers
        layers = [input_layer]
        if middle_layers is not None:
            layers += middle_layers
        if not self.config["skip_head"]:
            head = self._build_task_head(head_module)
            layers.append(head)

        # Construct network; a lone layer is used as-is rather than wrapped
        if len(layers) > 1:
            self.network = nn.Sequential(*layers)
        else:
            self.network = layers[0]

        # Construct loss module
        loss_weights = self.config["train_config"]["loss_weights"]
        if loss_weights is not None and self.config["verbose"]:
            print(f"Using class weight vector {loss_weights}...")
        reduction = self.config["train_config"]["loss_fn_reduction"]
        self.criteria = SoftCrossEntropyLoss(
            weight=self._to_torch(loss_weights, dtype=torch.FloatTensor),
            reduction=reduction,
        )

    def _build_input_layer(self, input_module):
        """Wrap the input module (IdentityModule by default) as the input layer."""
        if input_module is None:
            input_module = IdentityModule()
        output_dim = self.config["layer_out_dims"][0]
        input_layer = self._make_layer(
            input_module,
            "input",
            self.config["input_layer_config"],
            output_dim=output_dim,
        )
        return input_layer

    def _build_middle_layers(self, middle_modules):
        """Build the layers between the input layer and the task head.

        Returns:
            An ``nn.ModuleList`` with ``len(layer_out_dims) - 2`` layers, or
            None when there are no middle layers. When ``middle_modules`` is
            None each layer starts with an ``nn.Linear`` between consecutive
            entries of ``layer_out_dims``.
        """
        layer_out_dims = self.config["layer_out_dims"]
        num_mid_layers = len(layer_out_dims) - 2
        if num_mid_layers == 0:
            return None

        middle_layers = nn.ModuleList()
        for i in range(num_mid_layers):
            if middle_modules is None:
                module = nn.Linear(*layer_out_dims[i:i + 2])
                output_dim = layer_out_dims[i + 1]
            else:
                # The output size of a user-supplied module is not known
                # here, so no output_dim is passed on to _make_layer
                module = middle_modules[i]
                output_dim = None
            layer = self._make_layer(
                module,
                "middle",
                self.config["middle_layer_config"],
                output_dim=output_dim,
            )
            middle_layers.add_module(f"layer{i+1}", layer)
        return middle_layers

    def _build_task_head(self, head_module):
        """Return the task head: ``head_module`` or a default ``nn.Linear``."""
        if head_module is None:
            head = nn.Linear(self.config["layer_out_dims"][-2], self.k)
        else:
            # Note that if head module is provided, it must have input dim of
            # the last middle module and output dim of self.k, the cardinality
            head = head_module
        return head

    def _make_layer(self, module, prefix, layer_config, output_dim=None):
        """Wrap ``module`` with the extras enabled in ``layer_config``.

        An IdentityModule is returned untouched. Otherwise ReLU, BatchNorm1d
        (only when ``output_dim`` is given) and Dropout are appended according
        to the ``{prefix}_relu`` / ``{prefix}_batchnorm`` /
        ``{prefix}_dropout`` entries of ``layer_config``. A module with no
        extras is returned unwrapped.
        """
        if isinstance(module, IdentityModule):
            return module
        layer = [module]
        if layer_config[f"{prefix}_relu"]:
            layer.append(nn.ReLU())
        if layer_config[f"{prefix}_batchnorm"] and output_dim:
            layer.append(nn.BatchNorm1d(output_dim))
        if layer_config[f"{prefix}_dropout"]:
            layer.append(nn.Dropout(layer_config[f"{prefix}_dropout"]))
        if len(layer) > 1:
            return nn.Sequential(*layer)
        else:
            return layer[0]

    def _print(self):
        """Print the network architecture."""
        print(self.network)

    def forward(self, x):
        """Returns the output of the network for a batch of inputs.

        Args:
            x: a [batch_size, ...] batch from X
        """
        return self.network(x)

    @staticmethod
    def _reset_module(m):
        """A method for resetting the parameters of any module in the network

        Calls ``m.reset_parameters()`` when the module provides that method;
        modules without it are left untouched (nothing is reported).

        NOTE(review): presumably applied to every submodule by the caller
        (e.g. via ``nn.Module.apply``), so it does not recurse -- confirm.
        """
        if callable(getattr(m, "reset_parameters", None)):
            m.reset_parameters()

    def update_config(self, update_dict):
        """Updates self.config with the values in a given update dictionary"""
        self.config = recursive_merge_dicts(self.config, update_dict)

    def _preprocess_Y(self, Y, k):
        """Convert Y to prob labels if necessary"""
        Y = Y.clone()

        # If preds (1-dim, or a single column), convert to probs
        if Y.dim() == 1 or Y.shape[1] == 1:
            Y = pred_to_prob(Y.long(), k=k)
        return Y

    def _create_dataset(self, *data):
        """Wrap the given data in a MetalDataset."""
        return MetalDataset(*data)

    def _get_loss_fn(self):
        """Returns the loss function to use in the train_model routine"""
        criteria = self.criteria.to(self.config["device"])

        # This self.preprocess_Y allows us to not handle preprocessing
        # in a custom dataloader, but decreases speed a bit
        def loss_fn(X, Y):
            return criteria(self.forward(X), self._preprocess_Y(Y, self.k))

        return loss_fn

    def train_model(self,
                    train_data,
                    valid_data=None,
                    log_writer=None,
                    **kwargs):
        """Train the model on ``train_data``.

        Args:
            train_data: a (X, Y) tuple/list, or another object accepted by
                ``self._create_data_loader``.
            valid_data: forwarded to ``self._train_model``.
            log_writer: forwarded to ``self._train_model``.
            **kwargs: configuration overrides merged into ``self.config``.
        """
        self.config = recursive_merge_dicts(self.config, kwargs)

        # If train_data is provided as a tuple (X, Y), we can make sure Y is in
        # the correct format
        # NOTE: Better handling for if train_data is Dataset or DataLoader...?
        if isinstance(train_data, (tuple, list)):
            X, Y = train_data
            Y = self._preprocess_Y(self._to_torch(Y, dtype=torch.FloatTensor),
                                   self.k)
            train_data = (X, Y)

        # Convert input data to data loaders
        train_loader = self._create_data_loader(train_data, shuffle=True)

        # Create loss function
        loss_fn = self._get_loss_fn()

        # Execute training procedure
        self._train_model(train_loader,
                          loss_fn,
                          valid_data=valid_data,
                          log_writer=log_writer)

    def predict_proba(self, X):
        """Returns the softmax (over dim 1) of the network output as a numpy
        array of probabilistic labels."""
        return F.softmax(self.forward(X), dim=1).data.cpu().numpy()
示例#16
0
class MTEndModel(MTClassifier, EndModel):
    """A multi-task discriminative model.

    Note that when looking up methods, MTEndModel will first search in
    MTClassifier, followed by EndModel.

    Args:
        layer_out_dims: a list of integers corresponding to the output sizes
            of the layers of your network. The first element is the
            dimensionality of the input layer, and all other elements dictate
            the sizes of middle layers. The number of middle layers will be
            inferred from this list. The output dimensions of the task heads
            will be inferred from the cardinalities pulled from K or the
            task_graph.
        input_modules: (nn.Module) a module, or a list of modules, that
            converts the user-provided model inputs to torch.Tensors.
            Defaults to IdentityModule.
        middle_modules: (nn.Module) a list of modules to execute between the
            input_module and task head. Defaults to nn.Linear.
        head_modules: (nn.Module) a module, or a t-length list of modules, to
            execute right before the final softmax that outputs a prediction
            for each task.
        K: A t-length list of task cardinalities (overridden by task_graph
            if task_graph is not None)
        task_graph: TaskGraph: A TaskGraph which defines a feasible set of
            task label vectors; overrides K
        **kwargs: configuration overrides merged over the default configs.

    Raises:
        ValueError: if neither K nor task_graph is supplied.
    """
    def __init__(
        self,
        layer_out_dims,
        input_modules=None,
        middle_modules=None,
        head_modules=None,
        K=None,
        task_graph=None,
        **kwargs,
    ):
        # K defaults to None (not a shared mutable list) so that omitting
        # both K and task_graph is actually detected
        if task_graph is None and K is None:
            raise ValueError("You must supply either a list of "
                             "cardinalities (K) or a TaskGraph.")
        if K is None:
            # A task_graph was given; keep handing the parent class an empty
            # list, as the previous default did
            K = []

        kwargs["layer_out_dims"] = layer_out_dims
        kwargs["input_modules"] = input_modules
        kwargs["middle_modules"] = middle_modules
        kwargs["head_modules"] = head_modules

        config = recursive_merge_dicts(em_default_config,
                                       mt_em_default_config,
                                       misses="insert")
        config = recursive_merge_dicts(config, kwargs)
        MTClassifier.__init__(self, K, config)

        if task_graph is None:
            task_graph = TaskGraph(K)
        self.task_graph = task_graph
        self.K = self.task_graph.K  # Cardinalities by task
        self.t = self.task_graph.t  # Total number of tasks
        assert len(self.K) == self.t

        self._build(input_modules, middle_modules, head_modules)

        # Show network
        if self.config["verbose"]:
            print("\nNetwork architecture:")
            self._print()
            print()

    def _build(self, input_modules, middle_modules, head_modules):
        """Create the input layer, middle layers, task heads and loss module."""
        self.input_layer = self._build_input_layer(input_modules)
        self.middle_layers = self._build_middle_layers(middle_modules)
        self.heads = self._build_task_heads(head_modules)

        # Construct loss module
        self.criteria = SoftCrossEntropyLoss(reduction="sum")

    def _build_input_layer(self, input_modules):
        """Build the input layer from one module or a list of modules.

        Returns:
            A plain list of layers when ``input_modules`` is a list,
            otherwise a single layer (IdentityModule when nothing was given).
        """
        if input_modules is None:
            input_modules = IdentityModule()

        if isinstance(input_modules, list):
            # NOTE(review): a plain Python list does not register these
            # modules as submodules of the model -- confirm this is intended.
            input_layer = [
                self._make_layer(mod, "input",
                                 self.config["input_layer_config"])
                for mod in input_modules
            ]
        else:
            # Resolve the output size for every single-module case, so a
            # user-supplied module no longer hits an unbound `output_dim`
            output_dim = self.config["layer_out_dims"][0]
            input_layer = self._make_layer(
                input_modules,
                "input",
                self.config["input_layer_config"],
                output_dim=output_dim,
            )

        return input_layer

    def _build_middle_layers(self, middle_modules):
        """Build the middle layers.

        Returns:
            An ``nn.ModuleList`` with ``len(layer_out_dims) - 1`` layers, or
            None when there are no middle layers. When ``middle_modules`` is
            None each layer starts with an ``nn.Linear`` between consecutive
            entries of ``layer_out_dims``.
        """
        layer_out_dims = self.config["layer_out_dims"]
        num_mid_layers = len(layer_out_dims) - 1
        if num_mid_layers == 0:
            return None

        middle_layers = nn.ModuleList()
        for i in range(num_mid_layers):
            if middle_modules is None:
                module = nn.Linear(*layer_out_dims[i:i + 2])
                layer = self._make_layer(
                    module,
                    "middle",
                    self.config["middle_layer_config"],
                    output_dim=layer_out_dims[i + 1],
                )
            else:
                # The output size of a user-supplied module is not known
                # here, so no output_dim is passed on to _make_layer
                module = middle_modules[i]
                layer = self._make_layer(module, "middle",
                                         self.config["middle_layer_config"])
            middle_layers.add_module(f"layer{i+1}", layer)
        return middle_layers

    def _build_task_heads(self, head_modules):
        """Creates and attaches task_heads to the appropriate network layers"""
        # Make task head layer assignments
        num_layers = len(self.config["layer_out_dims"])
        task_head_layers = self._set_task_head_layers(num_layers)

        # task_head_layers stores the layer whose output is input to task head t
        # task_map stores the task heads that appear at each layer
        self.task_map = defaultdict(list)
        for t, layer_idx in enumerate(task_head_layers):
            self.task_map[layer_idx].append(t)

        attached_to_input = any(idx == 0 for idx in task_head_layers)
        if attached_to_input and head_modules is None:
            raise Exception(
                "If any task head is being attached to layer 0 "
                "(the input modules), then you must provide a t-length list of "
                "head_modules, since the output dimension of each input_module "
                "cannot be inferred.")

        # Construct heads
        head_dims = [self.K[t] for t in range(self.t)]

        heads = nn.ModuleList()
        for t in range(self.t):
            input_dim = self.config["layer_out_dims"][task_head_layers[t]]
            if self.config["pass_predictions"]:
                # Parent predictions are appended to the head input, so the
                # head must be wide enough to take them
                for p in self.task_graph.parents[t]:
                    input_dim += head_dims[p]
            output_dim = head_dims[t]

            if head_modules is None:
                head = nn.Linear(input_dim, output_dim)
            elif isinstance(head_modules, list):
                head = head_modules[t]
            else:
                # A single module is copied so each task gets its own head
                head = copy.deepcopy(head_modules)
            heads.append(head)
        return heads

    def _set_task_head_layers(self, num_layers):
        """Resolve and validate the layer index each task head attaches to.

        Args:
            num_layers: total number of layers in the network.

        Returns:
            A list holding, for each task, the index of the layer whose
            output feeds that task's head.

        Raises:
            ValueError: if the ``task_head_layers`` option is invalid, if
                trailing layers have no head attached, or (when predictions
                are passed between tasks) if a parent's layer is not strictly
                before its child's layer.
        """
        head_layers = self.config["task_head_layers"]
        if isinstance(head_layers, list):
            task_head_layers = head_layers
        elif head_layers == "top":
            task_head_layers = [num_layers - 1] * self.t
        else:
            msg = f"Invalid option to 'head_layers' parameter: '{head_layers}'"
            raise ValueError(msg)

        # Confirm that the network does not extend beyond the latest task head
        if max(task_head_layers) < num_layers - 1:
            unused = num_layers - 1 - max(task_head_layers)
            msg = (f"The last {unused} layer(s) of your network have no task "
                   "heads attached to them")
            raise ValueError(msg)

        # Confirm that parents come b/f children if predictions are passed
        # between tasks
        if self.config["pass_predictions"]:
            for t, layer_idx in enumerate(task_head_layers):
                for p in self.task_graph.parents[t]:
                    if task_head_layers[p] >= layer_idx:
                        p_layer = task_head_layers[p]
                        msg = (
                            f"Task {t}'s layer ({layer_idx}) must be larger than its "
                            f"parent task {p}'s layer ({p_layer})")
                        raise ValueError(msg)

        return task_head_layers

    def _print(self):
        """Print the input layer, middle layers and the heads attached to each."""
        print("\n--Input Layer--")
        if isinstance(self.input_layer, list):
            for mod in self.input_layer:
                print(mod)
        else:
            print(self.input_layer)

        for t in self.task_map[0]:
            print(f"(head{t})")
            print(self.heads[t])

        print("\n--Middle Layers--")
        # _build_middle_layers returns None when there are no middle layers
        middle_layers = (
            self.middle_layers if self.middle_layers is not None else [])
        for i, layer in enumerate(middle_layers, start=1):
            print(f"(layer{i}):")
            print(layer)
            for t in self.task_map[i]:
                print(f"(head{t})")
                print(self.heads[t])
            print()

    def forward(self, x):
        """Returns a list of outputs for tasks 0,...t-1

        Args:
            x: a [batch_size, ...] batch from X
        """
        head_outputs = [None] * self.t

        # Execute input layer
        if isinstance(self.input_layer, list):  # One input_module per task
            input_outputs = [mod(x) for mod, x in zip(self.input_layer, x)]
            x = torch.stack(input_outputs, dim=1)

            # Execute level-0 task heads from their respective input modules
            for t in self.task_map[0]:
                head = self.heads[t]
                head_outputs[t] = head(input_outputs[t])
        else:  # One input_module for all tasks
            x = self.input_layer(x)

            # Execute level-0 task heads from the single input module
            for t in self.task_map[0]:
                head = self.heads[t]
                head_outputs[t] = head(x)

        # Execute middle layers (None when the network has no middle layers)
        middle_layers = (
            self.middle_layers if self.middle_layers is not None else [])
        for i, layer in enumerate(middle_layers, start=1):
            x = layer(x)

            # Attach level-i task heads from the ith middle module
            for t in self.task_map[i]:
                head = self.heads[t]
                # Optionally include as input the predictions of parent tasks
                if self.config["pass_predictions"] and bool(
                        self.task_graph.parents[t]):
                    task_input = [x]
                    for p in self.task_graph.parents[t]:
                        task_input.append(head_outputs[p])
                    # Concatenate along the feature dimension: the head's
                    # input size is the layer width plus the parents' widths
                    # (see _build_task_heads), which stacking cannot produce
                    task_input = torch.cat(task_input, dim=1)
                else:
                    task_input = x
                head_outputs[t] = head(task_input)
        return head_outputs

    def _preprocess_Y(self, Y, k=None):
        """Convert Y to t-length list of soft labels if necessary

        Raises:
            ValueError: if Y is not a list while t > 1, or if the list does
                not have exactly t entries.
        """
        # If not a list, convert to a singleton list
        if not isinstance(Y, list):
            if self.t != 1:
                msg = "For t > 1, Y must be a list of n-dim or [n, K_t] tensors"
                raise ValueError(msg)
            Y = [Y]

        if not len(Y) == self.t:
            msg = f"Expected Y to be a t-length list (t={self.t}), not {len(Y)}"
            raise ValueError(msg)

        # Each task is converted with its own cardinality; the k argument is
        # not used
        return [
            EndModel._preprocess_Y(self, Y_t, self.K[t])
            for t, Y_t in enumerate(Y)
        ]

    def _get_loss_fn(self):
        """Returns the loss function to use in the train_model routine"""
        if self.config["use_cuda"]:
            criteria = self.criteria.cuda()
        else:
            criteria = self.criteria

        def loss_fn(X, Y):
            # Total loss is the sum of the per-task losses
            return sum(
                criteria(Y_tp, Y_t) for Y_tp, Y_t in zip(self.forward(X), Y))

        return loss_fn

    def predict_proba(self, X):
        """Returns a list of t numpy arrays of soft (float) predictions, one
        per task, each the softmax over dim 1 of that task's output."""
        return [
            F.softmax(output, dim=1).data.cpu().numpy()
            for output in self.forward(X)
        ]

    def predict_task_proba(self, X, t):
        """Returns an n x k matrix of probabilities for each label of task t"""
        # NOTE(review): predict_tasks_proba is not defined in this class;
        # presumably inherited from MTClassifier -- confirm.
        return self.predict_tasks_proba(X)[t]