    def fit(self,
            encoded_data: EncodedData,
            label: Label,
            cores_for_training: int = 2):
        self.feature_names = encoded_data.feature_names
        X = encoded_data.examples
        assert X.shape[1] == 2, "ProbabilisticBinaryClassifier: the shape of the input is not compatible with the classifier. " \
                                "The classifier is defined when examples are encoded by two counts: the number of successful trials " \
                                "and the total number of trials. If this is not the intended use case and encoding, please consider " \
                                "using another classifier."

        self.class_mapping = Util.make_binary_class_mapping(
            encoded_data.labels[label.name])
        self.label = label

        # number of examples in each of the two classes
        self.N_0 = int(np.sum(np.array(encoded_data.labels[label.name]) == self.class_mapping[0]))
        self.N_1 = int(np.sum(np.array(encoded_data.labels[label.name]) == self.class_mapping[1]))

        # fit one beta distribution per class from that class's (successful trials, total trials) counts
        self.alpha_0, self.beta_0 = self._find_beta_distribution_parameters(
            X[np.nonzero(np.array(encoded_data.labels[label.name]) == self.class_mapping[0])], self.N_0)
        self.alpha_1, self.beta_1 = self._find_beta_distribution_parameters(
            X[np.nonzero(np.array(encoded_data.labels[label.name]) == self.class_mapping[1])], self.N_1)
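The assertion at the top of this method expects encoded_data.examples to hold one row per example with exactly two columns: the number of successful trials and the total number of trials. A minimal, standalone NumPy sketch of that layout and of the per-class split performed before estimating the two beta distributions is shown below; it is not the immuneML API, and all names and values are illustrative.

import numpy as np

# one row per example: (successful trials, total trials)
X = np.array([[3, 10],
              [7, 10],
              [1, 12],
              [9, 11]])
y = np.array(["healthy", "diseased", "healthy", "diseased"])

# stands in for Util.make_binary_class_mapping: integer key per class value
class_mapping = {0: "healthy", 1: "diseased"}

N_0 = int(np.sum(y == class_mapping[0]))        # examples in class 0
N_1 = int(np.sum(y == class_mapping[1]))        # examples in class 1
X_0 = X[np.nonzero(y == class_mapping[0])]      # rows used to fit (alpha_0, beta_0)
X_1 = X[np.nonzero(y == class_mapping[1])]      # rows used to fit (alpha_1, beta_1)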
Example #2
    def fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):

        self.feature_names = encoded_data.feature_names

        Util.setup_pytorch(self.number_of_threads, self.random_seed)
        if "chain_names" in encoded_data.info and encoded_data.info["chain_names"] is not None and len(encoded_data.info["chain_names"]) == 2:
            self.chain_names = encoded_data.info["chain_names"]
        else:
            self.chain_names = ["chain_1", "chain_2"]

        self._make_CNN()
        self.CNN.to(device=self.device)

        self.class_mapping = Util.make_binary_class_mapping(encoded_data.labels[label_name])
        self.label_name = label_name

        self.CNN.train()

        iteration = 0
        loss_function = nn.BCEWithLogitsLoss().to(device=self.device)
        optimizer = torch.optim.Adam(self.CNN.parameters(), lr=self.learning_rate, weight_decay=self.l2_weight_decay, eps=1e-4)
        state = dict(model=copy.deepcopy(self.CNN).state_dict(), optimizer=optimizer, iteration=iteration, best_validation_loss=np.inf)
        train_data, validation_data = self._prepare_and_split_data(encoded_data)

        logging.info("ReceptorCNN: starting training.")
        while iteration < self.iteration_count:
            for examples, labels, example_ids in self._get_data_batch(train_data, self.label_name):

                # Reset gradients
                optimizer.zero_grad()

                # Calculate predictions
                logit_outputs = self.CNN(examples)

                # Calculate losses
                loss = self._compute_loss(loss_function, logit_outputs, labels)

                # Perform update
                loss.backward()
                optimizer.step()

                self.CNN.rescale_weights_for_IGM()

                iteration += 1

                # Calculate scores and loss on training set and validation set
                if iteration % self.evaluate_at == 0 or iteration == self.iteration_count or iteration == 1:
                    logging.info(f"ReceptorCNN: training - iteration {iteration}.")
                    state = self._evaluate_state(state, iteration, loss_function, validation_data)

                if iteration >= self.iteration_count:
                    self.CNN.load_state_dict(state["model"])
                    break

        logging.info("ReceptorCNN: finished training.")
Example #3
    def fit(self,
            encoded_data: EncodedData,
            label: Label,
            cores_for_training: int = 2):
        self.feature_names = encoded_data.feature_names

        Util.setup_pytorch(self.number_of_threads, self.random_seed)
        self.input_size = encoded_data.examples.shape[1]

        self._make_log_reg()

        self.label = label
        self.class_mapping = Util.make_binary_class_mapping(
            encoded_data.labels[self.label.name])

        loss = np.inf

        state = {"loss": loss, "model": None}
        loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
        optimizer = torch.optim.SGD(self.logistic_regression.parameters(),
                                    lr=self.learning_rate)

        for iteration in range(self.iteration_count):

            # reset gradients
            optimizer.zero_grad()

            # compute predictions only for k-mers with max score
            max_logit_indices = self._get_max_logits_indices(
                encoded_data.examples)
            example_count = encoded_data.examples.shape[0]
            examples = torch.from_numpy(encoded_data.examples).float()[
                torch.arange(example_count).long(), :, max_logit_indices]
            logits = self.logistic_regression(examples)

            # compute the loss
            loss = loss_func(
                logits,
                torch.tensor(encoded_data.labels[self.label.name]).float())

            # perform update
            loss.backward()
            optimizer.step()

            # log current score and keep model for early stopping if specified
            if iteration % self.evaluate_at == 0 or iteration == self.iteration_count - 1:
                logging.info(
                    f"AtchleyKmerMILClassifier: log loss at iteration {iteration+1}/{self.iteration_count}: {loss}."
                )
                if state["loss"] < loss and self.use_early_stopping:
                    state = {
                        "loss": loss.numpy(),
                        "model": copy.deepcopy(self.logistic_regression)
                    }

            if loss < self.threshold:
                break

        if loss >= self.threshold:
            logging.warning(
                "AtchleyKmerMILClassifier: the logistic regression model did not converge."
            )

        if loss > state['loss'] and self.use_early_stopping:
            self.logistic_regression.load_state_dict(state["model"])
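This example trains a logistic regression under a multiple-instance learning scheme: every k-mer of an example is scored, and only the highest-scoring k-mer is used to compute the loss. The sketch below illustrates that selection step in isolation; it assumes examples shaped (n_examples, n_features, n_kmers) with placeholder data, and it does not reproduce _get_max_logits_indices.

import numpy as np
import torch
import torch.nn as nn

n_examples, n_features, n_kmers = 5, 8, 20
examples_np = np.random.rand(n_examples, n_features, n_kmers).astype(np.float32)
labels = torch.randint(0, 2, (n_examples,)).float()

logistic_regression = nn.Linear(n_features, 1)
examples = torch.from_numpy(examples_np).float()

with torch.no_grad():
    # score every k-mer: logits of shape (n_examples, n_kmers), then argmax per example
    all_logits = logistic_regression(examples.transpose(1, 2)).squeeze(-1)
    max_logit_indices = all_logits.argmax(dim=1)

# keep only each example's highest-scoring k-mer: shape (n_examples, n_features)
selected = examples[torch.arange(n_examples).long(), :, max_logit_indices]
logits = logistic_regression(selected).squeeze(-1)
loss = nn.BCEWithLogitsLoss(reduction="mean")(logits, labels)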