예제 #1
0
    def train(self):
        """Run one training epoch and return the accumulated loss.

        Seed nodes ``self.train_ids`` are batched by a ``NeighborSampler``
        (incoming edges, ``self.params.n_layers`` hops, expand factor
        ``self.num_neighbors``, shuffled, 8 workers). For each sampled
        NodeFlow the model's logits are scored with ``self.loss_fn`` against
        the labels of the last-layer (seed) nodes and one optimizer step is
        taken.

        Returns:
            Sum of ``loss.item()`` over all mini-batches (``0`` if the
            sampler yields no batch).
        """
        self.model.train()
        running_loss = 0
        sampler = NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_neighbors,
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=self.train_ids)
        for nodeflow in sampler:
            # Pull node/edge features from the parent graph into the NodeFlow.
            nodeflow.copy_from_parent()
            logits = self.model(nodeflow)
            # Parent-graph ids of the seed nodes, used to index the labels.
            seed_ids = nodeflow.layer_parent_nid(-1).type(torch.long)
            seed_labels = self.labels[seed_ids.to(device=self.device)]
            loss = self.loss_fn(logits, seed_labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
        return running_loss
예제 #2
0
    def evaluate(self, ids, type='test'):
        """Count correct and "unsure" predictions for the nodes ``ids``.

        A prediction is "unsure" when its highest softmax probability is below
        ``self.params.unsure_rate / self.num_labels``; otherwise it is counted
        as correct when its argmax equals the node's label.

        Args:
            ids: seed node ids (in ``self.graph``) to evaluate.
            type: unused; kept only so existing callers keep working.

        Returns:
            ``(total_correct, total_unsure)`` as integers.
        """
        self.model.eval()
        total_correct, total_unsure = 0, 0
        # Loop-invariant: compute the unsure threshold and move labels to the
        # CPU once instead of once per mini-batch.
        unsure_threshold = self.params.unsure_rate / self.num_labels
        labels_cpu = self.labels.cpu()
        for nf in NeighborSampler(g=self.graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.num_cells +
                                  self.num_genes,
                                  # Use the instance config like the other
                                  # arguments; the bare global ``params`` is
                                  # undefined when this module is imported.
                                  num_hops=self.params.n_layers,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_workers=8,
                                  seed_nodes=ids):
            nf.copy_from_parent(
            )  # Copy node/edge features from the parent graph.
            with torch.no_grad():
                logits = self.model(nf).cpu()
            # Parent-graph ids of the seed (last-layer) nodes of this batch.
            batch_nids = nf.layer_parent_nid(-1).type(torch.long)
            probs = nn.functional.softmax(logits, dim=1).numpy()
            label_list = labels_cpu[batch_nids]
            for pred, label in zip(probs, label_list):
                if pred.max().item() < unsure_threshold:
                    total_unsure += 1
                elif pred.argmax().item() == label:
                    total_correct += 1

        return total_correct, total_unsure
예제 #3
0
    def inference(self, num):
        """Predict a label for every masked node of the ``num``-th test graph.

        Logits are computed batch-by-batch with a 1-hop ``NeighborSampler``
        (no shuffling) around ``self.test_dict['nid'][num]`` and scattered into
        a dense ``(number_of_nodes, num_classes)`` tensor. Rows selected by
        ``self.test_dict['mask'][num]`` are softmax-normalised and decoded.

        Returns:
            A list with one entry per masked node: ``self.id2label`` of the
            argmax class, or ``'unsure'`` when the top probability is below
            ``self.params.unsure_rate / self.num_classes``.
        """
        self.model.eval()
        graph = self.test_dict['graph'][num]
        all_logits = torch.zeros((graph.number_of_nodes(), self.num_classes))
        sampler = NeighborSampler(g=graph,
                                  batch_size=self.params.batch_size,
                                  expand_factor=self.total_cell +
                                  self.num_genes,
                                  num_hops=1,
                                  neighbor_type='in',
                                  shuffle=False,
                                  num_workers=8,
                                  seed_nodes=self.test_dict['nid'][num])
        for nodeflow in sampler:
            # Pull node/edge features from the parent graph into the NodeFlow.
            nodeflow.copy_from_parent()
            with torch.no_grad():
                batch_logits = self.model(nodeflow).cpu()
            seed_ids = nodeflow.layer_parent_nid(-1).type(torch.long)
            all_logits[seed_ids] = batch_logits

        masked = all_logits[self.test_dict['mask'][num]]
        probabilities = nn.functional.softmax(masked, dim=1).numpy()
        predictions = []
        for row in probabilities:
            best_label = self.id2label[row.argmax().item()]
            is_unsure = row.max().item() < self.params.unsure_rate / self.num_classes
            predictions.append('unsure' if is_unsure else best_label)
        return predictions
예제 #4
0
    def __init__(self,
                 graph,
                 labels,
                 batch_size,
                 num_hops,
                 seed_nodes,
                 sample_type='neighbor',
                 num_neighbors=8,
                 num_worker=32):
        """Store sampling configuration and build the underlying sampler.

        The batch is divided evenly across the visible CUDA devices (or a
        single CPU "device" when none are visible), and a prefetching
        ``NeighborSampler`` over incoming edges is created with that
        per-device batch size.

        Args:
            graph: graph to sample from.
            labels: node labels, stored on the instance.
            batch_size: total batch size before splitting across devices.
            num_hops: number of sampling hops.
            seed_nodes: nodes that seed each sampled batch.
            sample_type: only ``'neighbor'`` is supported.
            num_neighbors: expand factor passed to the sampler.
            num_worker: number of sampler worker processes.

        Raises:
            RuntimeError: if ``sample_type`` is not ``'neighbor'``.
        """
        self.graph = graph
        self.labels = labels
        self.batch_size = batch_size
        self.type = sample_type
        self.num_hops = num_hops
        self.seed_nodes = seed_nodes
        self.num_neighbors = num_neighbors
        self.num_worker = num_worker

        # One share of the batch per CUDA device; 0 devices means CPU only.
        self.device_num = torch.cuda.device_count() or 1
        # Integer division, clamped to 1 so a batch smaller than the device
        # count never hands the sampler a zero-sized batch.
        per_worker_batch = max(1, int(self.batch_size // self.device_num))
        if self.type != "neighbor":
            self.sampler = None
            raise RuntimeError("Currently only support Neighbor Sampling")

        self.sampler = NeighborSampler(self.graph,
                                       per_worker_batch,
                                       self.num_neighbors,
                                       neighbor_type='in',
                                       shuffle=True,
                                       num_workers=self.num_worker,
                                       num_hops=self.num_hops,
                                       seed_nodes=self.seed_nodes,
                                       prefetch=True)
        # Created lazily elsewhere; None until iteration starts.
        self.sampler_iter = None
예제 #5
0
    def train(self):
        """Train ``self.model`` with early stopping, then plot learning curves.

        Each epoch runs one pass of mini-batch training (``_run_train_epoch``)
        and one validation pass with ``self.model_infer``
        (``_run_validation_epoch``). It records the average loss and accuracy
        of both passes and feeds the validation loss to
        ``self.early_stopping``. After the loop, the weights in
        ``<self.model_dir>/checkpoint.pt`` (presumably written by
        ``self.early_stopping``) are loaded back into ``self.model``.
        Finally, ``self.plot`` is called with the recorded per-epoch curves.
        """
        dur = []  # per-epoch training time; epochs 0 and 1 are warm-up, not timed
        train_losses = []  # per-epoch average training loss
        train_accuracies = []
        val_losses = []
        val_accuracies = []
        # Output of the last validation batch, for the optional embedding
        # projector. Stays None when no validation batch ran (e.g. epochs == 0).
        embeddings, batch_labels = None, None

        for epoch in range(self.epochs):
            if use_tensorboardx:
                for name, param in self.model.named_parameters():
                    self.writer.add_histogram(name, param, epoch)

            t0 = time.time()
            train_average_loss, train_accuracy = self._run_train_epoch()
            if epoch >= 2:
                dur.append(time.time() - t0)
            train_losses.append(train_average_loss)
            train_accuracies.append(train_accuracy)

            (val_average_loss, val_accuracy,
             embeddings, batch_labels) = self._run_validation_epoch()
            val_losses.append(val_average_loss)
            val_accuracies.append(val_accuracy)

            self.early_stopping(val_average_loss, self.model_infer)
            if self.early_stopping.early_stop:
                logging.info("Early stopping")
                break

            # np.mean([]) warns and returns nan; log nan without the warning
            # until timed epochs exist.
            mean_dur = np.mean(dur) if dur else float('nan')
            logging.info("Epoch {:05d} | Time(s) {:.4f} | TrainLoss {:.4f} | TrainAcc {:.4f} |"
                         " ValLoss {:.4f} | ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
                         format(epoch, mean_dur, train_average_loss, train_accuracy,
                                val_average_loss, val_accuracy, self.n_edges / mean_dur / 1000))

        # Embedding visualization of the last validation batch (epoch is
        # always bound when embeddings is not None).
        if use_tensorboardx and embeddings is not None:
            self.writer.add_embedding(embeddings, global_step=epoch, metadata=batch_labels)

        # Load the checkpoint holding the best model.
        self.model.load_state_dict(torch.load(os.path.join(self.model_dir, 'checkpoint.pt')))

        self.plot(train_losses, val_losses, train_accuracies, val_accuracies)

    def _run_train_epoch(self):
        """Run one pass of mini-batch training over ``self.train_id``.

        Before each step, ``agg_history_{i}`` is refreshed on the parent graph
        for the nodes of NodeFlow layer ``i + 1``. The refresh copies
        neighbours' ``history_{i}`` as messages and sums them. After the
        optimizer step, the NodeFlow's ``history_{i}`` fields are copied
        back to the parent graph.

        Returns:
            ``(average_loss, accuracy)``, each a batch-size-weighted sum
            divided by ``len(self.train_id)``.
        """
        self.model.train()  # set to train mode
        num_correct = 0  # number of correct predictions in the training set
        total_loss = 0  # total loss, weighted by batch size

        # Field names are loop-invariant, so build them once per epoch.
        # Layer 0: raw features + history_0. Middle layers: own history, the
        # previous layer's aggregated history, and norms. Last layer: the
        # final aggregated history and norms.
        in_node_names = [['features', 'history_0']]
        for i in range(1, self.n_layers):
            in_node_names.append(['history_{}'.format(i), 'agg_history_{}'.format(i - 1),
                                  'subg_norm', 'norm'])
        in_node_names.append(['agg_history_{}'.format(self.n_layers - 1), 'subg_norm', 'norm'])
        edge_names = [['edge_features']]
        # Updated history of each layer goes back to the parent; last layer writes nothing.
        out_node_names = [['history_{}'.format(i)] for i in range(self.n_layers)]
        out_node_names.append([])

        for nf in NeighborSampler(self.g,
                                  batch_size=self.batch_size,
                                  expand_factor=self.num_neighbors,
                                  neighbor_type='in',
                                  shuffle=True,
                                  num_hops=self.n_layers,
                                  add_self_loop=False,
                                  seed_nodes=self.train_id):
            # Update the aggregated history of all nodes in each layer.
            for i in range(self.n_layers):
                self.g.pull(nf.layer_parent_nid(i + 1),
                            fn.copy_src(src='history_{}'.format(i), out='m'),
                            fn.sum(msg='m', out='agg_history_{}'.format(i)))

            # Copy the features from the original graph to the NodeFlow graph.
            nf.copy_from_parent(node_embed_names=in_node_names,
                                edge_embed_names=edge_names)

            # Forward pass, loss and accuracy.
            logits = self.model(nf)
            batch_node_ids = nf.layer_parent_nid(-1)
            n = len(batch_node_ids)
            batch_labels = self.labels[batch_node_ids]
            num_correct += accuracy(logits, batch_labels) * n
            loss = self.loss_fn(logits, batch_labels)
            total_loss += loss.item() * n

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Copy the updated history from the NodeFlow back to the original graph.
            nf.copy_to_parent(node_embed_names=out_node_names)

        return total_loss / len(self.train_id), num_correct / len(self.train_id)

    def _run_validation_epoch(self):
        """Evaluate ``self.model_infer`` on ``self.val_id`` with current weights.

        ``self.model``'s weights are copied into ``self.model_infer`` first.
        History fields are read but never updated here.

        Returns:
            ``(average_loss, accuracy, embeddings, batch_labels)``. The last
            two come from the final validation batch (``None`` if there was
            none).
        """
        # Weights do not change during validation, so sync them once instead
        # of on every mini-batch.
        self.model_infer.load_state_dict(self.model.state_dict())
        num_correct = 0  # number of correct predictions in the validation set
        total_loss = 0  # total loss, weighted by batch size
        embeddings, batch_labels = None, None

        node_names = [['features']] + [['norm', 'subg_norm'] for _ in range(self.n_layers)]
        edge_names = [['edge_features']]
        # batch_size == len(val_id): every validation node goes in one batch.
        # expand_factor equals the node count, presumably to keep all
        # in-neighbours.
        for nf in NeighborSampler(self.g,
                                  batch_size=len(self.val_id),
                                  expand_factor=self.g.number_of_nodes(),
                                  neighbor_type='in',
                                  num_hops=self.n_layers,
                                  seed_nodes=self.val_id,
                                  add_self_loop=False,
                                  num_workers=self.num_cpu):
            nf.copy_from_parent(node_embed_names=node_names,
                                edge_embed_names=edge_names)
            batch_node_ids = nf.layer_parent_nid(-1)
            n = len(batch_node_ids)
            batch_labels = self.labels[batch_node_ids]
            # No backward pass here, so don't build an autograd graph.
            with torch.no_grad():
                logits, embeddings = self.model_infer(nf)
                loss = self.loss_fn(logits, batch_labels)
            num_correct += accuracy(logits, batch_labels) * n
            total_loss += loss.item() * n

        return (total_loss / len(self.val_id), num_correct / len(self.val_id),
                embeddings, batch_labels)