Beispiel #1
0
class Trainer:
    """Train, evaluate and checkpoint a ``GNN`` node classifier.

    The graph, the train/test seed node ids and the labels all come from
    ``load_data_internal(params)``.  Mini-batches are drawn with
    ``NeighborSampler`` (presumably DGL's NodeFlow-based sampler -- confirm
    against the file's imports).

    ``params`` is an attribute-style configuration object.  The attributes read
    by this class are: ``species``, ``tissue``, ``gpu``, ``dense_dim``,
    ``hidden_dim``, ``n_layers``, ``dropout``, ``lr``, ``weight_decay``,
    ``num_neighbors``, ``n_epochs``, ``batch_size`` and ``unsure_rate``.
    """

    def __init__(self, params):
        """Load the data and build the model, optimizer and loss.

        Args:
            params: Configuration object (see the class docstring for the
                attributes that are read).  ``params.gpu == -1`` selects the
                CPU, any other value selects ``cuda:<gpu>``.
        """
        self.params = params
        self.prj_path = Path(__file__).parent.resolve()
        self.save_path = self.prj_path / 'pretrained' / f'{self.params.species}' / 'models'
        # exist_ok avoids the check-then-create race of a separate exists() test.
        self.save_path.mkdir(parents=True, exist_ok=True)
        self.device = torch.device(
            'cpu' if self.params.gpu == -1 else f'cuda:{self.params.gpu}')
        (self.num_cells, self.num_genes, self.num_labels, self.graph,
         self.train_ids, self.test_ids, self.labels) = load_data_internal(params)
        self.labels = self.labels.to(self.device)
        self.model = GNN(in_feats=self.params.dense_dim,
                         n_hidden=self.params.hidden_dim,
                         n_classes=self.num_labels,
                         n_layers=self.params.n_layers,
                         gene_num=self.num_genes,
                         activation=F.relu,
                         dropout=self.params.dropout).to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.params.lr,
            weight_decay=self.params.weight_decay)
        # Summed (not averaged) loss; fit() divides by the number of train ids
        # when reporting it.
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
        # num_neighbors == 0 means "use num_cells + num_genes" as the
        # sampler's expand factor during training.
        if self.params.num_neighbors == 0:
            self.num_neighbors = self.num_cells + self.num_genes
        else:
            self.num_neighbors = self.params.num_neighbors

        print(
            f"Train Number: {len(self.train_ids)}, Test Number: {len(self.test_ids)}"
        )

    def fit(self):
        """Run the training loop and print per-epoch and best-epoch results.

        After every epoch the model is evaluated on the train and test ids.
        Whenever the test accuracy is at least as high as the best seen so far,
        the epoch is recorded as the best one and a checkpoint is saved.
        Training stops early once the train accuracy reaches 1.

        If ``params.n_epochs`` is 0 no training happens and the summary reports
        zeros instead of failing on unbound variables.
        """
        max_test_acc, _train_acc, _epoch = 0, 0, 0
        # Bound up front so the summary below is valid even with zero epochs.
        final_test_correct_num, final_test_unsure_num = 0, 0
        for epoch in range(self.params.n_epochs):
            loss = self.train()
            train_correct, _ = self.evaluate(self.train_ids, 'train')
            train_acc = train_correct / len(self.train_ids)
            test_correct, test_unsure = self.evaluate(self.test_ids, 'test')
            test_acc = test_correct / len(self.test_ids)
            # "<=" keeps the most recent epoch among ties for best accuracy.
            if max_test_acc <= test_acc:
                final_test_correct_num = test_correct
                final_test_unsure_num = test_unsure
                _train_acc = train_acc
                _epoch = epoch
                max_test_acc = test_acc
                self.save_model()
            print(
                f">>>>Epoch {epoch:04d}: Train Acc {train_acc:.4f}, Loss {loss / len(self.train_ids):.4f}, Test correct {test_correct}, "
                f"Test unsure {test_unsure}, Test Acc {test_acc:.4f}")
            if train_acc == 1:
                break

        print(
            f"---{self.params.species} {self.params.tissue} Best test result:---"
        )
        print(
            f"Epoch {_epoch:04d}, Train Acc {_train_acc:.4f}, Test Correct Num {final_test_correct_num}, Test Total Num {len(self.test_ids)}, Test Unsure Num {final_test_unsure_num}, Test Acc {final_test_correct_num / len(self.test_ids):.4f}"
        )

    def train(self):
        """Run one optimisation pass over the training seed nodes.

        Returns:
            The sum of the per-batch losses (each batch loss is itself a sum
            over the batch because the loss uses ``reduction='sum'``).
        """
        self.model.train()
        total_loss = 0
        sampler = NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_neighbors,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=self.train_ids)
        for nf in sampler:
            nf.copy_from_parent()  # Copy node/edge features from the parent graph.
            logits = self.model(nf)
            # Parent-graph ids of the last layer's nodes, used to look up labels.
            batch_nids = nf.layer_parent_nid(-1).type(
                torch.long).to(device=self.device)
            loss = self.loss_fn(logits, self.labels[batch_nids])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss

    def evaluate(self, ids, type='test'):
        """Count correct and "unsure" predictions for the seed nodes ``ids``.

        A prediction is "unsure" when its highest softmax probability is below
        ``params.unsure_rate / num_labels``.  Unsure predictions are never
        counted as correct, even if their argmax matches the label.

        Args:
            ids: Seed node ids handed to the neighbor sampler.
            type: Unused; kept so callers that pass a split name keep working.

        Returns:
            Tuple ``(total_correct, total_unsure)`` of ints.
        """
        self.model.eval()
        unsure_threshold = self.params.unsure_rate / self.num_labels
        # Move the labels to the CPU once instead of once per batch.
        cpu_labels = self.labels.cpu()
        total_correct, total_unsure = 0, 0
        sampler = NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_cells + self.num_genes,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=ids)
        for nf in sampler:
            nf.copy_from_parent()  # Copy node/edge features from the parent graph.
            with torch.no_grad():
                logits = self.model(nf).cpu()
            batch_nids = nf.layer_parent_nid(-1).type(torch.long)
            probs = nn.functional.softmax(logits, dim=1)
            max_probs, preds = probs.max(dim=1)
            # Compare in float64 so the cut-off behaves like a comparison
            # between Python floats.
            unsure_mask = max_probs.double() < unsure_threshold
            correct_mask = ~unsure_mask & (preds == cpu_labels[batch_nids])
            total_unsure += int(unsure_mask.sum())
            total_correct += int(correct_mask.sum())

        return total_correct, total_unsure

    def save_model(self):
        """Write model and optimizer state dicts to ``<save_path>/<species>-<tissue>.pt``."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        torch.save(
            state,
            self.save_path / f"{self.params.species}-{self.params.tissue}.pt")
Beispiel #2
0
# Hold out 10% of the graphs (and their labels) as the test set
G_train, G_test, y_train, y_test = train_test_split(
    Gs,
    y,
    test_size=0.1,
)

# Sizes of the two partitions
N_train, N_test = len(G_train), len(G_test)

# Build the network on the target device, then its Adam optimizer and the
# cross-entropy classification loss
model = GNN(1, n_hidden_1, n_hidden_2, n_hidden_3, n_class, device)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_function = nn.CrossEntropyLoss()

# Trains the model
for epoch in range(epochs):
    t = time.time()
    model.train()
    
    train_loss = 0
    correct = 0
    count = 0
    for i in range(0, N_train, batch_size):
        adj_batch = list()
        idx_batch = list()
        y_batch = list()

        ############## Task 8
        
        ##################
        # your code here #
        ##################
        for j in range(i, min(N_train, i + batch_size)):