Example 1
    def _create_dataset(self, *data):
        # Simply wrap the provided arrays/tensors in a MetalDataset.
        return MetalDataset(*data)
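For orientation, a minimal usage sketch of the wrapped dataset (the import path metal.utils.MetalDataset is an assumption about where the class lives, and the data is made up):

    import numpy as np
    from metal.utils import MetalDataset  # assumed import path

    # Toy data: 10 examples with 5 features each, hard labels in {1, 2}.
    X = np.random.rand(10, 5)
    Y = np.random.randint(1, 3, size=10)

    # _create_dataset(*data) simply forwards its arguments, so this mirrors
    # model._create_dataset(X, Y).
    dataset = MetalDataset(X, Y)
    print(len(dataset))   # 10 aligned (X, Y) pairs
    print(dataset[0])     # features and label for the first example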
Example 2
    def train_model(
        self,
        L_train,
        Y_dev=None,
        deps=[],
        class_balance=None,
        log_writer=None,
        **kwargs,
    ):
        """Train the model (i.e. estimate mu) in one of two ways, depending on
        whether source dependencies are provided or not:

        Args:
            L_train: An [n,m] scipy.sparse matrix with values in {0,1,...,k}
                corresponding to labels from supervision sources on the
                training set
            Y_dev: Target labels for the dev set, for estimating class_balance
            deps: (list of tuples) known dependencies between supervision
                sources. If not provided, sources are assumed to be independent.
                TODO: add automatic dependency-learning code
            class_balance: (np.array) each class's percentage of the population

        (1) No dependencies (conditionally independent sources): Estimate mu
        subject to constraints:
            (1a) O_{B(i,j)} - (mu P mu.T)_{B(i,j)} = 0, for i != j, where B(i,j)
                is the block of entries corresponding to sources i,j
            (1b) np.sum( mu P, 1 ) = diag(O)

        (2) Source dependencies:
            - First, estimate Z subject to the inverse form
            constraint:
                (2a) O_\Omega + (ZZ.T)_\Omega = 0, \Omega is the deps mask
            - Then, compute Q = mu P mu.T
            - Finally, estimate mu subject to mu P mu.T = Q and (1b)
        """
        self.config = recursive_merge_dicts(self.config, kwargs, misses="ignore")
        train_config = self.config["train_config"]

        # TODO: Implement logging for label model?
        if log_writer is not None:
            raise NotImplementedError("Logging for LabelModel.")

        # Note that the LabelModel class implements its own (centered) L2 reg.
        l2 = train_config.get("l2", 0)

        self._set_class_balance(class_balance, Y_dev)
        self._set_constants(L_train)
        self._set_dependencies(deps)
        self._check_L(L_train)

        # Whether to take the simple conditionally independent approach, or the
        # "inverse form" approach for handling dependencies
        # This flag allows us to eg test the latter even with no deps present
        self.inv_form = len(self.deps) > 0

        # Creating this faux dataset is necessary for now because the LabelModel
        # loss functions do not accept inputs, but Classifier._train_model()
        # expects training data to feed to the loss functions.
        dataset = MetalDataset([0], [0])
        train_loader = DataLoader(dataset)
        if self.inv_form:
            # Compute O, O^{-1}, and initialize params
            if self.config["verbose"]:
                print("Computing O^{-1}...")
            self._generate_O_inv(L_train)
            self._init_params()

            # Estimate Z, compute Q = \mu P \mu^T
            if self.config["verbose"]:
                print("Estimating Z...")
            self._train_model(train_loader, self.loss_inv_Z)
            self.Q = torch.from_numpy(self.get_Q()).float()

            # Estimate \mu
            if self.config["verbose"]:
                print("Estimating \mu...")
            self._train_model(train_loader, partial(self.loss_inv_mu, l2=l2))
        else:
            # Compute O and initialize params
            if self.config["verbose"]:
                print("Computing O...")
            self._generate_O(L_train)
            self._init_params()

            # Estimate \mu
            if self.config["verbose"]:
                print("Estimating \mu...")
            self._train_model(train_loader, partial(self.loss_mu, l2=l2))
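For context, a hypothetical end-to-end call of train_model (the label matrix is synthetic; the import path, the k/seed constructor arguments, and the n_epochs keyword are assumptions about the surrounding LabelModel API rather than facts established by this snippet):

    import numpy as np
    import scipy.sparse as sparse
    from metal.label_model import LabelModel  # assumed import path

    # Synthetic label matrix: 1000 examples, 5 sources, values in {0, 1, 2}
    # with 0 conventionally meaning "abstain" (k = 2 classes).
    L_train = sparse.csr_matrix(np.random.randint(0, 3, size=(1000, 5)))

    label_model = LabelModel(k=2, seed=123)

    # Route (1): no deps passed, so sources are treated as conditionally
    # independent and mu is estimated directly.
    label_model.train_model(L_train, class_balance=np.array([0.6, 0.4]), n_epochs=100)

    # Route (2): passing known dependencies switches to the inverse-form
    # objective (estimate Z first, then mu).
    label_model.train_model(L_train, deps=[(0, 1)], n_epochs=100)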
Example 3
    def _make_data_loader(self, X, Y, data_loader_config):
        # Pair the features with the preprocessed labels and return a
        # shuffling DataLoader configured by data_loader_config.
        dataset = MetalDataset(X, self._preprocess_Y(Y, self.k))
        data_loader = DataLoader(dataset, shuffle=True, **data_loader_config)
        return data_loader
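A standalone sketch of what this produces (the config keys and values are illustrative rather than library defaults, metal.utils.MetalDataset is again an assumed import path, and the real method additionally runs Y through self._preprocess_Y first):

    import numpy as np
    from torch.utils.data import DataLoader
    from metal.utils import MetalDataset  # assumed import path

    X = np.random.rand(100, 10).astype(np.float32)
    Y = np.random.randint(1, 3, size=100)  # hard labels in {1, ..., k}; _preprocess_Y skipped here

    # Mirrors _make_data_loader: wrap the aligned (X, Y) pair and shuffle
    # the examples each epoch.
    data_loader_config = {"batch_size": 32, "num_workers": 0}
    data_loader = DataLoader(MetalDataset(X, Y), shuffle=True, **data_loader_config)

    for X_batch, Y_batch in data_loader:
        print(X_batch.shape, Y_batch.shape)  # e.g. torch.Size([32, 10]) torch.Size([32])
        break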