Example #1
0
        count += output.size(0)
        preds = output.max(1)[1].type_as(y_batch)
        correct += torch.sum(preds.eq(y_batch).double())
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print('Epoch: {:04d}'.format(epoch+1),
              'loss_train: {:.4f}'.format(train_loss / count),
              'acc_train: {:.4f}'.format(correct / count),
              'time: {:.4f}s'.format(time.time() - t))
        
print('Optimization finished!')

# Evaluates the model
model.eval()
test_loss = 0
correct = 0
count = 0
for i in range(0, N_test, batch_size):
    adj_batch = list()
    idx_batch = list()
    y_batch = list()

    ############## Task 8
    
    ##################
    # your code here #
    ##################
    for j in range(i, min(N_test, i + batch_size)):
        n = G_test[j].number_of_nodes()
Example #2
0
class Trainer:
    """Train a ``GNN`` classifier on a DGL graph and checkpoint the best model.

    All configuration is read from ``params`` (e.g. ``species``, ``tissue``,
    ``gpu``, ``dense_dim``, ``hidden_dim``, ``n_layers``, ``dropout``, ``lr``,
    ``weight_decay``, ``num_neighbors``, ``n_epochs``, ``batch_size``,
    ``unsure_rate``).
    """

    def __init__(self, params):
        self.params = params
        self.prj_path = Path(__file__).parent.resolve()
        self.save_path = self.prj_path / 'pretrained' / f'{self.params.species}' / 'models'
        # exist_ok makes this safe if the directory already exists (or appears concurrently).
        self.save_path.mkdir(parents=True, exist_ok=True)
        # gpu == -1 selects the CPU; any other value is a CUDA device index.
        self.device = torch.device('cpu' if self.params.gpu == -1 else f'cuda:{params.gpu}')
        (self.num_cells, self.num_genes, self.num_labels, self.graph,
         self.train_ids, self.test_ids, self.labels) = load_data_internal(params)
        self.labels = self.labels.to(self.device)
        self.model = GNN(in_feats=self.params.dense_dim,
                         n_hidden=self.params.hidden_dim,
                         n_classes=self.num_labels,
                         n_layers=self.params.n_layers,
                         gene_num=self.num_genes,
                         activation=F.relu,
                         dropout=self.params.dropout).to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.params.lr,
            weight_decay=self.params.weight_decay)
        # Summed (not averaged) loss; fit() divides by the train-set size for reporting.
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
        # num_neighbors == 0 means "use num_cells + num_genes" as the sampling
        # expansion factor (presumably every node in the graph).
        if self.params.num_neighbors == 0:
            self.num_neighbors = self.num_cells + self.num_genes
        else:
            self.num_neighbors = self.params.num_neighbors

        print(
            f"Train Number: {len(self.train_ids)}, Test Number: {len(self.test_ids)}"
        )

    def fit(self):
        """Train for up to ``params.n_epochs`` epochs, checkpointing the best model.

        The checkpoint is refreshed whenever test accuracy matches or exceeds
        the best seen so far (ties favour the later epoch). Training stops
        early once train accuracy reaches 1. A summary of the best epoch is
        printed at the end.
        """
        max_test_acc, _train_acc, _epoch = 0, 0, 0
        # Initialised up-front so the summary below works even if no epoch runs.
        final_test_correct_num, final_test_unsure_num = 0, 0
        num_train, num_test = len(self.train_ids), len(self.test_ids)
        for epoch in range(self.params.n_epochs):
            loss = self.train()
            train_correct, train_unsure = self.evaluate(self.train_ids, 'train')
            train_acc = train_correct / num_train
            test_correct, test_unsure = self.evaluate(self.test_ids, 'test')
            test_acc = test_correct / num_test
            if max_test_acc <= test_acc:
                final_test_correct_num = test_correct
                final_test_unsure_num = test_unsure
                _train_acc = train_acc
                _epoch = epoch
                max_test_acc = test_acc
                self.save_model()
            print(
                f">>>>Epoch {epoch:04d}: Train Acc {train_acc:.4f}, Loss {loss / num_train:.4f}, Test correct {test_correct}, "
                f"Test unsure {test_unsure}, Test Acc {test_acc:.4f}")
            if train_acc == 1:
                break

        print(
            f"---{self.params.species} {self.params.tissue} Best test result:---"
        )
        # max_test_acc == final_test_correct_num / num_test whenever an epoch ran,
        # and stays 0 otherwise (avoids dividing again).
        print(
            f"Epoch {_epoch:04d}, Train Acc {_train_acc:.4f}, Test Correct Num {final_test_correct_num}, Test Total Num {num_test}, Test Unsure Num {final_test_unsure_num}, Test Acc {max_test_acc:.4f}"
        )

    def train(self):
        """Run one epoch over ``train_ids`` and return the summed training loss."""
        self.model.train()
        total_loss = 0
        sampler = NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_neighbors,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=self.train_ids)
        for nf in sampler:
            nf.copy_from_parent()  # Copy node/edge features from the parent graph.
            logits = self.model(nf)
            # Parent-graph ids of the seed nodes in the last NodeFlow layer.
            batch_nids = nf.layer_parent_nid(-1).type(torch.long).to(device=self.device)
            loss = self.loss_fn(logits, self.labels[batch_nids])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss

    def evaluate(self, ids, type='test'):
        """Count correct and "unsure" predictions over the nodes ``ids``.

        A prediction is unsure when its top softmax probability is below
        ``params.unsure_rate / num_labels``; unsure nodes are never counted as
        correct. Sampling always uses ``num_cells + num_genes`` as the
        expansion factor, independent of ``params.num_neighbors``.

        Args:
            ids: Seed node ids to evaluate.
            type: Kept for caller compatibility; not used.

        Returns:
            ``(total_correct, total_unsure)``.
        """
        self.model.eval()
        total_correct, total_unsure = 0, 0
        threshold = self.params.unsure_rate / self.num_labels
        # Copy labels to CPU once instead of once per batch.
        labels_cpu = self.labels.cpu()
        for nf in NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_cells + self.num_genes,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=ids):
            nf.copy_from_parent()  # Copy node/edge features from the parent graph.
            with torch.no_grad():
                logits = self.model(nf).cpu()
            batch_nids = nf.layer_parent_nid(-1).type(torch.long)
            probs = nn.functional.softmax(logits, dim=1).numpy()
            for pred, label in zip(probs, labels_cpu[batch_nids]):
                if pred.max().item() < threshold:
                    total_unsure += 1
                elif pred.argmax().item() == label:
                    total_correct += 1

        return total_correct, total_unsure

    def save_model(self):
        """Save model and optimizer state dicts to ``<save_path>/<species>-<tissue>.pt``."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        torch.save(
            state,
            self.save_path / f"{self.params.species}-{self.params.tissue}.pt")
Example #3
0
class Runner:
    """Load a pretrained ``GNN`` and predict (or evaluate) cell types on test graphs.

    Configuration comes from ``params`` (e.g. ``species``, ``tissue``, ``gpu``,
    ``evaluate``, ``test_dataset``, ``batch_size``, ``unsure_rate``,
    ``dense_dim``, ``hidden_dim``, ``dropout``, ``save_dir``).
    """

    def __init__(self, params):
        self.params = params
        # Single strftime call so date and time come from the same instant.
        self.postfix = time.strftime('%d_%m_%Y_%H:%M:%S')
        self.prj_path = Path(__file__).parent.resolve()
        # gpu == -1 selects the CPU; any other value is a CUDA device index.
        self.device = torch.device('cpu' if self.params.gpu == -1 else f'cuda:{params.gpu}')
        if self.params.evaluate:
            # map_dict is only available (and only needed) in evaluate mode.
            (self.total_cell, self.num_genes, self.num_classes, self.id2label,
             self.test_dict, self.map_dict, self.time) = load_data(params)
        else:
            (self.total_cell, self.num_genes, self.num_classes, self.id2label,
             self.test_dict, self.time) = load_data(params)
        # test_dict maps a field name to a per-dataset mapping indexed by dataset
        # number: 'graph', 'nid' (seed nodes), 'mask', and, as used below,
        # 'label' and 'origin_id'.
        self.model = GNN(in_feats=params.dense_dim,
                         n_hidden=params.hidden_dim,
                         n_classes=self.num_classes,
                         n_layers=1,  # matches num_hops=1 in the samplers below
                         gene_num=self.num_genes,
                         activation=F.relu,
                         dropout=params.dropout)
        self.load_model()
        self.num_neighbors = self.total_cell + self.num_genes
        self.model.to(self.device)

    def run(self):
        """Predict (or evaluate) every dataset in ``params.test_dataset`` and save results."""
        for num in self.params.test_dataset:
            tic = time.time()
            if self.params.evaluate:
                correct, total, unsure, acc, pred = self.evaluate_test(num)
                print(
                    f"{self.params.species}_{self.params.tissue} #{num} Test Acc: {acc:.4f} ({correct}/{total}), Number of Unsure Cells: {unsure}"
                )
            else:
                pred = self.inference(num)
            toc = time.time()
            # self.time is the extra time reported by load_data, added to each run's cost.
            print(
                f'{self.params.species}_{self.params.tissue} #{num} Time Consumed: {toc - tic + self.time:.2f} seconds.'
            )
            self.save_pred(num, pred)

    def load_model(self):
        """Load the ``<species>-<tissue>.pt`` checkpoint written by training."""
        model_path = self.prj_path / 'pretrained' / self.params.species / 'models' / f'{self.params.species}-{self.params.tissue}.pt'
        state = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state['model'])

    def _predict_probs(self, num):
        """Return softmax class probabilities (NumPy) for the masked nodes of test graph ``num``."""
        self.model.eval()
        graph = self.test_dict['graph'][num]
        new_logits = torch.zeros((graph.number_of_nodes(), self.num_classes))
        for nf in NeighborSampler(g=graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_neighbors,
                                  num_hops=1,
                                  neighbor_type='in',
                                  shuffle=False,
                                  num_workers=8,
                                  seed_nodes=self.test_dict['nid'][num]):
            nf.copy_from_parent()  # Copy node/edge features from the parent graph.
            with torch.no_grad():
                logits = self.model(nf).cpu()
            # Scatter batch logits back to their parent-graph node positions.
            batch_nids = nf.layer_parent_nid(-1).type(torch.long)
            new_logits[batch_nids] = logits

        new_logits = new_logits[self.test_dict['mask'][num]]
        return nn.functional.softmax(new_logits, dim=1).numpy()

    def _is_unsure(self, pred):
        """True when the top probability of ``pred`` is below ``unsure_rate / num_classes``."""
        return pred.max().item() < self.params.unsure_rate / self.num_classes

    def _label_of(self, pred):
        """Map one probability row to its label, or ``'unsure'``."""
        if self._is_unsure(pred):
            return 'unsure'
        return self.id2label[pred.argmax().item()]

    def inference(self, num):
        """Return the predicted label (or ``'unsure'``) for each masked node of dataset ``num``."""
        return [self._label_of(pred) for pred in self._predict_probs(num)]

    def evaluate_test(self, num):
        """Predict dataset ``num`` and score it against its ground-truth labels.

        A confident prediction counts as correct when it is contained in
        ``map_dict[num][true_label]``; unsure predictions are never correct.

        Returns:
            ``(correct, total, unsure_num, accuracy, predict_label)``; accuracy
            is 0.0 when there are no nodes to evaluate.
        """
        probs = self._predict_probs(num)
        total = probs.shape[0]
        unsure_num, correct = 0, 0
        predict_label = []
        for pred, t_label in zip(probs, self.test_dict['label'][num]):
            if self._is_unsure(pred):
                unsure_num += 1
                predict_label.append('unsure')
                continue
            pred_label = self.id2label[pred.argmax().item()]
            if pred_label in self.map_dict[num][t_label]:
                correct += 1
            predict_label.append(pred_label)
        acc = correct / total if total else 0.0
        return correct, total, unsure_num, acc, predict_label

    def save_pred(self, num, pred):
        """Write predictions for dataset ``num`` to ``<save_dir>/<species>_<tissue>_<num>.csv``.

        Predicted labels are translated to new cell types/subtypes using the
        species sheet of ``./map/celltype2subtype.xlsx``; labels missing from
        the map (including ``'unsure'``) are written unchanged.
        """
        label_map = pd.read_excel(
            './map/celltype2subtype.xlsx',
            sheet_name=self.params.species,
            header=0,
            names=['species', 'old_type', 'new_type', 'new_subtype'])
        label_map = label_map.fillna('N/A', inplace=False)
        # Later rows override earlier ones for duplicate old types, as before.
        oldtype2newtype = dict(zip(label_map['old_type'], label_map['new_type']))
        oldtype2newsubtype = dict(zip(label_map['old_type'], label_map['new_subtype']))

        save_path = self.prj_path / self.params.save_dir
        save_path.mkdir(parents=True, exist_ok=True)

        columns = {'index': self.test_dict['origin_id'][num]}
        if self.params.evaluate:
            columns['original label'] = self.test_dict['label'][num]
        columns['cell_type'] = [oldtype2newtype.get(p, p) for p in pred]
        columns['cell_subtype'] = [oldtype2newsubtype.get(p, p) for p in pred]
        df = pd.DataFrame(columns)

        file_name = self.params.species + f"_{self.params.tissue}_{num}.csv"
        df.to_csv(save_path / file_name, index=False)
        print(
            f"output has been stored in {self.params.species}_{self.params.tissue}_{num}.csv"
        )