Exemplo n.º 1
0
# Shuffle the n node indices and cut the permutation into 60% / 20% / 20%
# train / validation / test parts.
idx = np.random.permutation(n)
idx_train, idx_val, idx_test = np.split(idx, [int(0.6 * n), int(0.8 * n)])

# Wrap the numpy data in torch tensors: float features and adjacency,
# integer class ids (row-wise arg-max of class_labels) and integer indices.
features = torch.FloatTensor(features)
adj = torch.FloatTensor(adj)
y = torch.LongTensor(np.argmax(class_labels, axis=1))
idx_train, idx_val, idx_test = (
    torch.LongTensor(part) for part in (idx_train, idx_val, idx_test)
)

# Build the network and the Adam optimizer that updates its parameters
model = GNN(features.shape[1], n_hidden_1, n_hidden_2, n_class, dropout_rate)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def train(epoch):
    """Run one full-batch optimisation step, then a forward pass in eval mode.

    Uses the module-level ``model``, ``optimizer``, ``features``, ``adj``,
    ``y`` and ``idx_train``.

    NOTE(review): only the beginning of this function is visible in this
    excerpt; how ``epoch``, ``t``, ``acc_train`` and the eval-mode output are
    used afterwards could not be confirmed.
    """
    t = time.time()  # wall-clock timestamp taken before the step
    model.train()  # training mode
    optimizer.zero_grad()  # clear gradients left over from the previous step
    # The model returns a pair; only the first element is used here.
    output, _ = model(features, adj)
    # Loss and accuracy are computed on the training nodes only.
    # nll_loss expects log-probabilities -- presumably the model ends in a
    # log-softmax; TODO confirm against the GNN definition.
    loss_train = F.nll_loss(output[idx_train], y[idx_train])
    acc_train = accuracy(output[idx_train], y[idx_train])
    loss_train.backward()
    optimizer.step()

    # Second forward pass with the updated weights, in evaluation mode.
    model.eval()
    output, _ = model(features, adj)
Exemplo n.º 2
0
class Trainer:
    """Train, evaluate and checkpoint a ``GNN`` using sampled neighbourhoods.

    The graph, the train/test node ids and the labels come from
    ``load_data_internal(params)``. Mini-batches are drawn with
    ``NeighborSampler``.

    ``params`` is expected to expose the attributes read in this class:
    ``species``, ``tissue``, ``gpu`` (``-1`` selects the CPU), ``dense_dim``,
    ``hidden_dim``, ``n_layers``, ``dropout``, ``lr``, ``weight_decay``,
    ``num_neighbors`` (``0`` means "num_cells + num_genes"), ``n_epochs``,
    ``batch_size`` and ``unsure_rate``.
    """

    def __init__(self, params):
        """Load the data and build the model, optimizer and loss.

        Args:
            params: configuration object (see the class docstring).
        """
        self.params = params
        self.prj_path = Path(__file__).parent.resolve()
        self.save_path = (self.prj_path / 'pretrained' /
                          f'{self.params.species}' / 'models')
        # exist_ok=True replaces the racy "check, then create" sequence.
        self.save_path.mkdir(parents=True, exist_ok=True)
        self.device = torch.device(
            'cpu' if self.params.gpu == -1 else f'cuda:{self.params.gpu}')
        (self.num_cells, self.num_genes, self.num_labels, self.graph,
         self.train_ids, self.test_ids,
         self.labels) = load_data_internal(params)
        self.labels = self.labels.to(self.device)
        self.model = GNN(in_feats=self.params.dense_dim,
                         n_hidden=self.params.hidden_dim,
                         n_classes=self.num_labels,
                         n_layers=self.params.n_layers,
                         gene_num=self.num_genes,
                         activation=F.relu,
                         dropout=self.params.dropout).to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.params.lr,
            weight_decay=self.params.weight_decay)
        # Summed (not averaged) loss; fit() divides by the train-set size.
        self.loss_fn = nn.CrossEntropyLoss(reduction='sum')
        if self.params.num_neighbors == 0:
            self.num_neighbors = self.num_cells + self.num_genes
        else:
            self.num_neighbors = self.params.num_neighbors

        print(
            f"Train Number: {len(self.train_ids)}, Test Number: {len(self.test_ids)}"
        )

    def fit(self):
        """Run the training loop and print the best test result.

        After every epoch the model is evaluated on the train and test ids.
        Whenever the test accuracy is at least as good as the best one so far,
        the statistics are recorded and the model is saved. Training stops
        early once the train accuracy reaches 1.
        """
        max_test_acc, _train_acc, _epoch = 0, 0, 0
        # Initialised up front so the summary below still works when the loop
        # body never runs (n_epochs == 0).
        final_test_correct_num, final_test_unsure_num = 0, 0
        for epoch in range(self.params.n_epochs):
            loss = self.train()
            train_correct, train_unsure = self.evaluate(
                self.train_ids, 'train')
            train_acc = train_correct / len(self.train_ids)
            test_correct, test_unsure = self.evaluate(self.test_ids, 'test')
            test_acc = test_correct / len(self.test_ids)
            if max_test_acc <= test_acc:
                final_test_correct_num = test_correct
                final_test_unsure_num = test_unsure
                _train_acc = train_acc
                _epoch = epoch
                max_test_acc = test_acc
                self.save_model()
            print(
                f">>>>Epoch {epoch:04d}: Train Acc {train_acc:.4f}, Loss {loss / len(self.train_ids):.4f}, Test correct {test_correct}, "
                f"Test unsure {test_unsure}, Test Acc {test_acc:.4f}")
            if train_acc == 1:
                break

        print(
            f"---{self.params.species} {self.params.tissue} Best test result:---"
        )
        print(
            f"Epoch {_epoch:04d}, Train Acc {_train_acc:.4f}, Test Correct Num {final_test_correct_num}, Test Total Num {len(self.test_ids)}, Test Unsure Num {final_test_unsure_num}, Test Acc {final_test_correct_num / len(self.test_ids):.4f}"
        )

    def train(self):
        """Run one epoch over the training ids.

        Returns:
            The sum of the per-batch losses (each already summed over the
            batch, see ``self.loss_fn``).
        """
        self.model.train()
        total_loss = 0
        sampler = NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_neighbors,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=self.train_ids)
        for nf in sampler:
            # Copy node/edge features from the parent graph.
            nf.copy_from_parent()
            logits = self.model(nf)
            # Parent-graph ids of the last layer's nodes, used to look up
            # the labels of this batch.
            batch_nids = nf.layer_parent_nid(-1).type(
                torch.long).to(device=self.device)
            loss = self.loss_fn(logits, self.labels[batch_nids])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss

    def evaluate(self, ids, type='test'):
        """Count correct and "unsure" predictions for the given node ids.

        A prediction is "unsure" when its largest softmax probability is below
        ``params.unsure_rate / num_labels``; unsure predictions are never
        counted as correct.

        Args:
            ids: seed node ids to evaluate.
            type: label of the split; accepted for callers but not used.

        Returns:
            ``(total_correct, total_unsure)``.
        """
        self.model.eval()
        total_correct, total_unsure = 0, 0
        # Loop invariants: one CPU copy of the labels and one threshold.
        labels_cpu = self.labels.cpu()
        unsure_threshold = self.params.unsure_rate / self.num_labels
        for nf in NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_cells +
                                  self.num_genes,
                                  # was the bare name `params`, which only
                                  # resolved if a global of that name existed
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=ids):
            # Copy node/edge features from the parent graph.
            nf.copy_from_parent()
            with torch.no_grad():
                logits = self.model(nf).cpu()
            batch_nids = nf.layer_parent_nid(-1).type(torch.long)
            probs = nn.functional.softmax(logits, dim=1).numpy()
            label_list = labels_cpu[batch_nids]
            for pred, label in zip(probs, label_list):
                max_prob = pred.max().item()
                if max_prob < unsure_threshold:
                    total_unsure += 1
                elif pred.argmax().item() == label:
                    total_correct += 1

        return total_correct, total_unsure

    def save_model(self):
        """Save the model and optimizer state dicts to
        ``<save_path>/<species>-<tissue>.pt``."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        torch.save(
            state,
            self.save_path / f"{self.params.species}-{self.params.tissue}.pt")
Exemplo n.º 3
0
def main(_run, _config, _log):
    """Train and test a GNN once per data split, then log summary statistics.

    For every split a fresh model and optimizer are created, ``train`` is
    run, the last and best models are tested, and the prediction dicts,
    curves, best state and gradients are pickled into a per-split folder.

    Args:
        _run: Sacred run object; passed to ``save_source`` and used for the
            directory of its first observer, which is the root of all outputs.
        _config: configuration dictionary provided by Sacred (the variables
            set in the cfg function).
        _log: logger provided by Sacred; not used here, a logger created by
            ``init_logger`` is used instead.
    """
    # Sacred does not allow changing _config, but keys are added below,
    # so work on a copy.
    config = dcopy(_config)
    torch.cuda.set_device(config['gpu_id'])

    save_source(_run)  # saves the source code for this run
    init_seed(config['seed'])
    run_dir = _run.observers[0].dir
    logger = init_logger(log_root=run_dir, file_name='log.txt')

    output_folder_path = opjoin(run_dir, config['path']['output_folder_name'])
    os.makedirs(output_folder_path, exist_ok=True)

    # Per-split results collected for the final report.
    test_accs_best, test_accs_last = [], []
    train_accs_best, train_accs_last = [], []
    best_epochs = []

    data = load_data(config=config)
    # Without random splitting there is exactly one pass.
    if config['data']['random_split']['use']:
        split_iterator = range(config['data']['random_split']['num_splits'])
    else:
        split_iterator = range(1)

    config['adj'] = data[0]

    for split_id in split_iterator:
        split_dir = opjoin(output_folder_path, str(split_id))
        os.makedirs(split_dir, exist_ok=True)

        if config['data']['random_split']['use']:
            data_cfg = config['data']
            data = resplit(
                dataset=data_cfg['dataset'],
                data=data,
                full_sup=data_cfg['full_sup'],
                num_classes=torch.unique(data[2]).shape[0],
                num_nodes=data[1].shape[0],
                num_per_class=data_cfg['label_per_class'],
            )
            print(torch.sum(data[3]))

        model = GNN(config=config)
        if split_id == 0:
            # The architecture is logged only once.
            logger.info(model)

        if config['use_gpu']:
            model.cuda()
            # Move everything that can be moved; leave other items untouched.
            data = [
                item.cuda() if hasattr(item, 'cuda') else item
                for item in data
            ]

        optim_cfg = config['optim']
        optimizer = init_optimizer(
            params=model.parameters(),
            optim_type=optim_cfg['type'],
            lr=optim_cfg['lr'],
            weight_decay=optim_cfg['weight_decay'],
            momentum=optim_cfg['momemtum'])
        criterion = nn.NLLLoss()

        best_model_path = opjoin(split_dir, 'best_model.pth')
        last_model_path = opjoin(split_dir, 'last_model.pth')

        (best_pred_dict, last_pred_dict, train_losses, train_accs,
         val_losses, val_accs, best_state, grads, model_state) = train(
             best_model_path,
             last_model_path,
             config,
             criterion,
             data,
             logger,
             model,
             optimizer,
         )
        last_model_state, best_model_state = model_state

        logger.info(f'split_seed: {split_id: 04d}')
        logger.info('Test set results on the last model:')
        last_pred_dict = test(criterion, data, last_model_path,
                              last_pred_dict, logger, model,
                              last_model_state)

        logger.info('Test set results on the best model:')
        if not config['fastmode']:
            best_pred_dict = test(criterion, data, best_model_path,
                                  best_pred_dict, logger, model,
                                  best_model_state)
        else:
            # In fast mode the last model's results stand in for the best.
            best_pred_dict = last_pred_dict

        logger.info('\n')

        # Every artefact is checked and then pickled into the split folder.
        artefacts = (
            ('best_pred_dict.pkl', best_pred_dict),
            ('last_pred_dict.pkl', last_pred_dict),
            ('losses.pkl', {'train': train_losses, 'val': val_losses}),
            ('accs.pkl', {'train': train_accs, 'val': val_accs}),
            ('best_state.pkl', best_state),
            ('grads.pkl', grads),
        )
        for file_name, payload in artefacts:
            check_before_pkl(payload)
            with open(opjoin(split_dir, file_name), 'wb') as f:
                pkl.dump(payload, f)

        test_accs_best.append(best_pred_dict['test acc'].item())
        test_accs_last.append(last_pred_dict['test acc'].item())
        train_accs_best.append(best_state['train acc'].item())
        train_accs_last.append(train_accs[-1].item())
        best_epochs.append(best_state['epoch'])

    logger.info('********************* STATISTICS *********************')
    np.set_printoptions(precision=4, suppress=True)
    report = (f"\n"
              f"Best test acc: {test_accs_best}\n"
              f"Mean: {np.mean(test_accs_best)}\t"
              f"Std: {np.std(test_accs_best)}\n"
              f"Last test acc: {test_accs_last}\n"
              f"Mean: {np.mean(test_accs_last)}\t"
              f"Std: {np.std(test_accs_last)}\n")
    logger.info(report)

    logger.info(f"best epoch: {best_epochs}")