Example #1
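A PyTorch training loop for a paired-chain receptor CNN classifier: the network is built and moved to the target device, the label's two classes are mapped to binary values, and training runs with BCE-with-logits loss and Adam, periodically evaluating on a held-out validation split and restoring the best recorded state at the end. The method assumes module-level imports of copy, logging, numpy as np, torch, and torch.nn as nn.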
    def fit(self, encoded_data: EncodedData, label_name: str, cores_for_training: int = 2):

        self.feature_names = encoded_data.feature_names

        Util.setup_pytorch(self.number_of_threads, self.random_seed)
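        # use chain names from the encoding when both are present, otherwise fall back to generic names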
        if "chain_names" in encoded_data.info and encoded_data.info["chain_names"] is not None and len(encoded_data.info["chain_names"]) == 2:
            self.chain_names = encoded_data.info["chain_names"]
        else:
            self.chain_names = ["chain_1", "chain_2"]

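        # build the CNN and move it to the configured device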
        self._make_CNN()
        self.CNN.to(device=self.device)

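        # map the two classes of the target label to binary values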
        self.class_mapping = Util.make_binary_class_mapping(encoded_data.labels[label_name])
        self.label_name = label_name

        self.CNN.train()

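        # binary cross-entropy on logits, optimized with Adam; 'state' tracks the best model found on the validation data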
        iteration = 0
        loss_function = nn.BCEWithLogitsLoss().to(device=self.device)
        optimizer = torch.optim.Adam(self.CNN.parameters(), lr=self.learning_rate, weight_decay=self.l2_weight_decay, eps=1e-4)
        state = dict(model=copy.deepcopy(self.CNN).state_dict(), optimizer=optimizer, iteration=iteration, best_validation_loss=np.inf)
        train_data, validation_data = self._prepare_and_split_data(encoded_data)

        logging.info("ReceptorCNN: starting training.")
        while iteration < self.iteration_count:
            for examples, labels, example_ids in self._get_data_batch(train_data, self.label_name):

                # Reset gradients
                optimizer.zero_grad()

                # Calculate predictions
                logit_outputs = self.CNN(examples)

                # Calculate losses
                loss = self._compute_loss(loss_function, logit_outputs, labels)

                # Perform update
                loss.backward()
                optimizer.step()

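                # model-specific rescaling of the kernel weights after each update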
                self.CNN.rescale_weights_for_IGM()

                iteration += 1

                # Calculate scores and loss on training set and validation set
                if iteration % self.evaluate_at == 0 or iteration == self.iteration_count or iteration == 1:
                    logging.info(f"ReceptorCNN: training - iteration {iteration}.")
                    state = self._evaluate_state(state, iteration, loss_function, validation_data)

                if iteration >= self.iteration_count:
                    self.CNN.load_state_dict(state["model"])
                    break

        logging.info("ReceptorCNN: finished training.")
Example #2
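A MIL-style (multiple instance learning) classifier over k-mer encodings: at each iteration only the k-mer with the highest logit per example is passed to a logistic regression layer, trained with BCE-with-logits loss and SGD, with optional early stopping that restores the best model seen. The method assumes module-level imports of copy, logging, numpy as np, and torch.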
    def fit(self,
            encoded_data: EncodedData,
            label: Label,
            cores_for_training: int = 2):
        self.feature_names = encoded_data.feature_names

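        # fix the number of PyTorch threads and the random seed; the input size comes from the encoded examples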
        Util.setup_pytorch(self.number_of_threads, self.random_seed)
        self.input_size = encoded_data.examples.shape[1]

        self._make_log_reg()

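        # remember the label and map its two classes to binary values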
        self.label = label
        self.class_mapping = Util.make_binary_class_mapping(
            encoded_data.labels[self.label.name])

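        # track the best loss and model for optional early stopping; optimize with plain SGD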
        loss = np.inf

        state = {"loss": loss, "model": None}
        loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')
        optimizer = torch.optim.SGD(self.logistic_regression.parameters(),
                                    lr=self.learning_rate)

        for iteration in range(self.iteration_count):

            # reset gradients
            optimizer.zero_grad()

            # compute predictions only for k-mers with max score
            max_logit_indices = self._get_max_logits_indices(
                encoded_data.examples)
            example_count = encoded_data.examples.shape[0]
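            # for each example, keep only the k-mer at the position with the maximal logit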
            examples = torch.from_numpy(encoded_data.examples).float()[
                torch.arange(example_count).long(), :, max_logit_indices]
            logits = self.logistic_regression(examples)

            # compute the loss
            loss = loss_func(
                logits,
                torch.tensor(encoded_data.labels[self.label.name]).float())

            # perform update
            loss.backward()
            optimizer.step()

            # log current score and keep model for early stopping if specified
            if iteration % self.evaluate_at == 0 or iteration == self.iteration_count - 1:
                logging.info(
                    f"AtchleyKmerMILClassifier: log loss at iteration {iteration + 1}/{self.iteration_count}: {loss.item()}."
                )
                if self.use_early_stopping and loss.item() < state["loss"]:
                    state = {
                        "loss": loss.item(),
                        "model": copy.deepcopy(self.logistic_regression.state_dict())
                    }

            if loss < self.threshold:
                break

        else:
            # the loop ran to completion without the loss dropping below the threshold
            logging.warning(
                "AtchleyKmerMILClassifier: the logistic regression model did not converge."
            )

        if self.use_early_stopping and state["model"] is not None and loss.item() > state["loss"]:
            self.logistic_regression.load_state_dict(state["model"])