def _evaluate_nodes(self,
                        ee,
                        lp,
                        create_api_call_loader,
                        loader,
                        neg_sampling_factor=1):
        """Evaluate one objective over all batches of `loader`.

        :param ee: element embedder passed through to `_logits_nodes`
        :param lp: link predictor passed through to `_logits_nodes`
        :param create_api_call_loader: loader factory passed through to
            `_logits_nodes`
        :param loader: iterable of (input_nodes, seeds, blocks) batches
            (DGL-style neighbor-sampling loader — assumed; confirm)
        :param neg_sampling_factor: number of negative samples per positive
        :return: (mean loss, mean accuracy) averaged over batches.
            Raises ZeroDivisionError if `loader` yields no batches.
        """
        total_loss = 0
        total_acc = 0
        count = 0

        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(self.device) for blk in blocks]

            src_embs = self._logits_batch(input_nodes, blocks)
            logits, labels = self._logits_nodes(src_embs, ee, lp,
                                                create_api_call_loader, seeds,
                                                neg_sampling_factor)

            logp = nn.functional.log_softmax(logits, 1)
            # BUG FIX: the original called cross_entropy on log-probabilities;
            # cross_entropy applies log_softmax internally, so softmax was
            # applied twice and the loss was distorted. nll_loss is the correct
            # companion to log_softmax, matching train_all and compute_acc_loss.
            loss = nn.functional.nll_loss(logp, labels)
            acc = compute_accuracy(logp.argmax(dim=1), labels)

            total_loss += loss.item()
            total_acc += acc
            count += 1
        return total_loss / count, total_acc / count
    def compute_acc_loss(self, node_embs_, element_embs_, labels):
        """Compute (accuracy, loss) for link prediction between node and
        element embeddings, dispatching on `self.link_predictor_type`.

        NOTE(review): this class defines ``compute_acc_loss`` three times;
        only the last definition is bound when the class body executes, so
        this variant is shadowed dead code. Kept as-is pending a decision
        on which variant should survive.

        :param node_embs_: source-node embeddings
        :param element_embs_: target-element embeddings
        :param labels: target labels; for the similarity-based predictors,
            negatives are encoded as values < 0 and remapped to 0 below
        :return: (accuracy, loss)
        :raises NotImplementedError: for an unrecognized predictor type
        """
        logits = self.link_predictor(node_embs_, element_embs_)

        if self.link_predictor_type == "nn":
            logp = nn.functional.log_softmax(logits, dim=1)
            loss = nn.functional.nll_loss(logp, labels)
        elif self.link_predictor_type == "inner_prod":
            loss = self.cosine_loss(node_embs_, element_embs_, labels)
            # WARNING: mutates the caller's `labels` tensor in place so that
            # negative-class markers become class index 0 for the accuracy
            # computation below.
            labels[labels < 0] = 0
        elif self.link_predictor_type == "l2":
            loss = self.l2_loss(node_embs_, element_embs_, labels)
            labels[labels < 0] = 0
        else:
            raise NotImplementedError()

        acc = compute_accuracy(logits.argmax(dim=1), labels)

        return acc, loss
    def compute_acc_loss(self, node_embs_, element_embs_, labels):
        """Compute (accuracy, loss) using a cross-entropy criterion over the
        link predictor's scores; label value -100 is ignored by the loss.

        NOTE(review): one of three ``compute_acc_loss`` definitions in this
        class — only the last one is actually bound.
        """
        scores = self.link_predictor(node_embs_, element_embs_)

        # Flatten to (N, num_classes) / (N,) so the criterion accepts any
        # leading batch shape.
        flat_scores = scores.reshape(-1, scores.size(-1))
        flat_labels = labels.reshape(-1)
        criterion = CrossEntropyLoss(ignore_index=-100)
        loss = criterion(flat_scores, flat_labels)

        acc = compute_accuracy(scores.argmax(dim=1), labels)

        return acc, loss
    def compute_acc_loss(self, node_embs_, element_embs_, labels):
        """Compute (accuracy, loss) with a triplet-style link predictor.

        Assumes the batch is laid out as [positive pairs | negative pairs],
        split evenly — TODO confirm against the loader that builds batches.
        The predictor returns (loss, similarity scores); accuracy compares
        the scores against the positive mask ``labels >= 0``.
        """
        half = len(labels) // 2
        anchors = node_embs_[:half, :]
        positives = element_embs_[:half, :]
        negatives = element_embs_[half:, :]
        positive_labels = labels[:half]

        loss, sim = self.link_predictor(anchors, positives, negatives,
                                        positive_labels)
        acc = compute_accuracy(sim, labels >= 0)

        return acc, loss
    def train_all(self):
        """
        Training procedure for the model with node classifier.

        Trains jointly on three objectives (node name, variable use, API
        call): per batch, logits from all three are concatenated and a
        single NLL loss is backpropagated. After each epoch the model is
        evaluated on validation and test loaders, the best score is
        tracked, summaries are written, and the LR scheduler steps.
        :return:
        """

        # Resume from self.epoch so a restored checkpoint continues training.
        for epoch in range(self.epoch, self.epochs):
            self.epoch = epoch

            start = time()

            # zip stops at the shortest of the three loaders, so one epoch
            # covers min(len(loader_*)) batches of each objective.
            for i, ((input_nodes_node_name, seeds_node_name, blocks_node_name),
                    (input_nodes_var_use, seeds_var_use, blocks_var_use),
                    (input_nodes_api_call, seeds_api_call, blocks_api_call)) in \
                    enumerate(zip(
                        self.loader_node_name,
                        self.loader_var_use,
                        self.loader_api_call)):

                # Move all sampled message-flow blocks to the training device.
                blocks_node_name = [
                    blk.to(self.device) for blk in blocks_node_name
                ]
                blocks_var_use = [
                    blk.to(self.device) for blk in blocks_var_use
                ]
                blocks_api_call = [
                    blk.to(self.device) for blk in blocks_api_call
                ]

                logits_node_name, labels_node_name = self._logits_node_name(
                    input_nodes_node_name, seeds_node_name, blocks_node_name)

                logits_var_use, labels_var_use = self._logits_var_use(
                    input_nodes_var_use, seeds_var_use, blocks_var_use)

                logits_api_call, labels_api_call = self._logits_api_call(
                    input_nodes_api_call, seeds_api_call, blocks_api_call)

                # Per-objective training accuracy for logging only; the loss
                # below is computed over all three objectives at once.
                train_acc_node_name = compute_accuracy(
                    logits_node_name.argmax(dim=1), labels_node_name)
                train_acc_var_use = compute_accuracy(
                    logits_var_use.argmax(dim=1), labels_var_use)
                train_acc_api_call = compute_accuracy(
                    logits_api_call.argmax(dim=1), labels_api_call)

                # Joint loss: concatenate logits/labels from the three
                # objectives and optimize a single NLL loss.
                train_logits = torch.cat(
                    [logits_node_name, logits_var_use, logits_api_call], 0)
                train_labels = torch.cat(
                    [labels_node_name, labels_var_use, labels_api_call], 0)

                logp = nn.functional.log_softmax(train_logits, 1)
                loss = nn.functional.nll_loss(logp, train_labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.write_summary(
                    {
                        "Loss": loss,
                        "Accuracy/train/node_name_vs_batch":
                        train_acc_node_name,
                        "Accuracy/train/var_use_vs_batch": train_acc_var_use,
                        "Accuracy/train/api_call_vs_batch": train_acc_api_call
                    }, self.batch)
                self.batch += 1

            # Switch to eval mode (disables dropout/batch-norm updates) for
            # validation and test evaluation, then restore training mode.
            self.eval()

            with torch.set_grad_enabled(False):

                _, val_acc_node_name, val_acc_var_use, val_acc_api_call = self._evaluate_objectives(
                    self.val_loader_node_name, self.val_loader_var_use,
                    self.val_loader_api_call, self.neg_sampling_factor)

                _, test_acc_node_name, test_acc_var_use, test_acc_api_call = self._evaluate_objectives(
                    self.test_loader_node_name, self.test_loader_var_use,
                    self.test_loader_api_call, self.neg_sampling_factor)

            self.train()

            end = time()

            # NOTE(review): `loss` and the train accuracies refer to the LAST
            # batch of the epoch; if all loaders are empty this raises
            # NameError — confirm loaders are always non-empty.
            self.best_score.track_best(epoch=epoch,
                                       loss=loss.item(),
                                       train_acc_node_name=train_acc_node_name,
                                       val_acc_node_name=val_acc_node_name,
                                       test_acc_node_name=test_acc_node_name,
                                       train_acc_var_use=train_acc_var_use,
                                       val_acc_var_use=val_acc_var_use,
                                       test_acc_var_use=test_acc_var_use,
                                       train_acc_api_call=train_acc_api_call,
                                       val_acc_api_call=val_acc_api_call,
                                       test_acc_api_call=test_acc_api_call,
                                       time=end - start)

            if self.do_save:
                self.save_checkpoint(self.model_base_path)

            self.write_summary(
                {
                    "Accuracy/test/node_name_vs_batch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_batch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_batch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_batch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_batch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_batch": val_acc_api_call
                }, self.batch)

            self.write_hyperparams(
                {
                    "Loss/train_vs_epoch": loss,
                    "Accuracy/train/node_name_vs_epoch": train_acc_node_name,
                    "Accuracy/train/var_use_vs_epoch": train_acc_var_use,
                    "Accuracy/train/api_call_vs_epoch": train_acc_api_call,
                    "Accuracy/test/node_name_vs_epoch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_epoch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_epoch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_epoch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_epoch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_epoch": val_acc_api_call
                }, self.epoch)

            # One scheduler step per epoch.
            self.lr_scheduler.step()