def _evaluate_nodes(self, ee, lp, create_api_call_loader, loader, neg_sampling_factor=1):
    """Evaluate a link-prediction objective over every batch in `loader`.

    :param ee: element embedder passed through to `_logits_nodes`
    :param lp: link predictor passed through to `_logits_nodes`
    :param create_api_call_loader: loader factory passed through to `_logits_nodes`
    :param loader: DGL-style dataloader yielding (input_nodes, seeds, blocks)
    :param neg_sampling_factor: number of negative samples per positive
    :return: (mean loss, mean accuracy) averaged over batches
    :raises ValueError: if `loader` yields no batches
    """
    total_loss = 0
    total_acc = 0
    count = 0
    for input_nodes, seeds, blocks in loader:
        blocks = [blk.to(self.device) for blk in blocks]
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_nodes(
            src_embs, ee, lp, create_api_call_loader, seeds, neg_sampling_factor
        )
        logp = nn.functional.log_softmax(logits, 1)
        # BUG FIX: the original passed log-probabilities to cross_entropy,
        # which applies log_softmax a second time internally (double
        # softmax) and skews the reported loss. nll_loss is the correct
        # counterpart of an explicit log_softmax.
        loss = nn.functional.nll_loss(logp, labels)
        acc = compute_accuracy(logp.argmax(dim=1), labels)
        total_loss += loss.item()
        total_acc += acc
        count += 1
    if count == 0:
        # Fail loudly instead of raising an opaque ZeroDivisionError.
        raise ValueError("loader yielded no batches to evaluate")
    return total_loss / count, total_acc / count
def compute_acc_loss(self, node_embs_, element_embs_, labels):
    """Compute (accuracy, loss) for a batch of node/element embedding pairs.

    The loss depends on `self.link_predictor_type`:
    - "nn": NLL loss over log-softmaxed predictor logits
    - "inner_prod": cosine loss from `self.cosine_loss`
    - "l2": L2 loss from `self.l2_loss`

    :param node_embs_: source-node embeddings
    :param element_embs_: target-element embeddings
    :param labels: class labels; WARNING — for "inner_prod"/"l2" this tensor
        is mutated in place (negative labels are remapped to class 0 so the
        accuracy computation has valid class indices)
    :return: (accuracy, loss)
    :raises NotImplementedError: for an unknown `self.link_predictor_type`
    """
    logits = self.link_predictor(node_embs_, element_embs_)
    if self.link_predictor_type == "nn":
        logp = nn.functional.log_softmax(logits, dim=1)
        loss = nn.functional.nll_loss(logp, labels)
    elif self.link_predictor_type == "inner_prod":
        loss = self.cosine_loss(node_embs_, element_embs_, labels)
        # Cosine-style labels use -1 for negatives; fold them to class 0
        # before measuring accuracy. NOTE: in-place mutation is visible to
        # the caller.
        labels[labels < 0] = 0
    elif self.link_predictor_type == "l2":
        loss = self.l2_loss(node_embs_, element_embs_, labels)
        labels[labels < 0] = 0
    else:
        raise NotImplementedError()
    # (A commented-out TripletMarginLoss experiment was removed here.)
    acc = compute_accuracy(logits.argmax(dim=1), labels)
    return acc, loss
def compute_acc_loss(self, node_embs_, element_embs_, labels):
    """Compute (accuracy, loss) with cross-entropy that skips ignored labels.

    :param node_embs_: source-node embeddings
    :param element_embs_: target-element embeddings
    :param labels: class labels; positions equal to -100 are ignored
    :return: (accuracy, loss)
    """
    logits = self.link_predictor(node_embs_, element_embs_)
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    # BUG FIX: the loss skips positions labeled -100, but the original
    # accuracy counted every such ignored position as an error. Mask them
    # out so accuracy is measured only where a real label exists.
    valid = labels != -100
    acc = compute_accuracy(logits.argmax(dim=1)[valid], labels[valid])
    return acc, loss
def compute_acc_loss(self, node_embs_, element_embs_, labels):
    """Return (accuracy, loss) for a triplet-style batch.

    The batch appears to be laid out so that the first half of
    `element_embs_` holds positive examples and the second half negatives,
    with `node_embs_` supplying the anchors (TODO confirm against the
    batching code). Accuracy compares the predictor's similarity output
    against the sign of the labels.
    """
    half = len(labels) // 2
    anchors = node_embs_[:half, :]
    positives, negatives = element_embs_[:half, :], element_embs_[half:, :]
    loss, sim = self.link_predictor(anchors, positives, negatives, labels[:half])
    return compute_accuracy(sim, labels >= 0), loss
def train_all(self):
    """
    Training procedure for the model with node classifier.

    Trains three objectives jointly (node name, variable use, API call) by
    iterating the three loaders in lockstep, summing their logits/labels
    into one NLL loss per batch, then evaluating on val/test loaders once
    per epoch, tracking the best scores, checkpointing, and stepping the
    LR scheduler.
    :return:
    """
    for epoch in range(self.epoch, self.epochs):
        self.epoch = epoch
        start = time()

        # zip truncates to the shortest loader, so an epoch covers only as
        # many batches as the smallest of the three objectives provides.
        for i, ((input_nodes_node_name, seeds_node_name, blocks_node_name),
                (input_nodes_var_use, seeds_var_use, blocks_var_use),
                (input_nodes_api_call, seeds_api_call, blocks_api_call)) in \
                enumerate(zip(
                    self.loader_node_name, self.loader_var_use, self.loader_api_call)):

            # Move all message-passing blocks to the training device.
            blocks_node_name = [blk.to(self.device) for blk in blocks_node_name]
            blocks_var_use = [blk.to(self.device) for blk in blocks_var_use]
            blocks_api_call = [blk.to(self.device) for blk in blocks_api_call]

            # Per-objective logits and labels.
            logits_node_name, labels_node_name = self._logits_node_name(
                input_nodes_node_name, seeds_node_name, blocks_node_name)
            logits_var_use, labels_var_use = self._logits_var_use(
                input_nodes_var_use, seeds_var_use, blocks_var_use)
            logits_api_call, labels_api_call = self._logits_api_call(
                input_nodes_api_call, seeds_api_call, blocks_api_call)

            # Per-objective training accuracy (this batch only).
            train_acc_node_name = compute_accuracy(
                logits_node_name.argmax(dim=1), labels_node_name)
            train_acc_var_use = compute_accuracy(
                logits_var_use.argmax(dim=1), labels_var_use)
            train_acc_api_call = compute_accuracy(
                logits_api_call.argmax(dim=1), labels_api_call)

            # One combined loss over the concatenation of all three objectives.
            train_logits = torch.cat(
                [logits_node_name, logits_var_use, logits_api_call], 0)
            train_labels = torch.cat(
                [labels_node_name, labels_var_use, labels_api_call], 0)

            logp = nn.functional.log_softmax(train_logits, 1)
            loss = nn.functional.nll_loss(logp, train_labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.write_summary(
                {
                    "Loss": loss,
                    "Accuracy/train/node_name_vs_batch": train_acc_node_name,
                    "Accuracy/train/var_use_vs_batch": train_acc_var_use,
                    "Accuracy/train/api_call_vs_batch": train_acc_api_call
                }, self.batch)

            self.batch += 1

        # Epoch-end evaluation in eval mode with gradients disabled.
        # NOTE: `loss` and the train_acc_* values below refer to the LAST
        # batch of the epoch, not an epoch average.
        self.eval()
        with torch.set_grad_enabled(False):
            _, val_acc_node_name, val_acc_var_use, val_acc_api_call = self._evaluate_objectives(
                self.val_loader_node_name, self.val_loader_var_use,
                self.val_loader_api_call, self.neg_sampling_factor)

            _, test_acc_node_name, test_acc_var_use, test_acc_api_call = self._evaluate_objectives(
                self.test_loader_node_name, self.test_loader_var_use,
                self.test_loader_api_call, self.neg_sampling_factor)
        self.train()

        end = time()

        # Record the best epoch so far across all tracked metrics.
        self.best_score.track_best(epoch=epoch, loss=loss.item(),
                                   train_acc_node_name=train_acc_node_name,
                                   val_acc_node_name=val_acc_node_name,
                                   test_acc_node_name=test_acc_node_name,
                                   train_acc_var_use=train_acc_var_use,
                                   val_acc_var_use=val_acc_var_use,
                                   test_acc_var_use=test_acc_var_use,
                                   train_acc_api_call=train_acc_api_call,
                                   val_acc_api_call=val_acc_api_call,
                                   test_acc_api_call=test_acc_api_call,
                                   time=end - start)

        if self.do_save:
            self.save_checkpoint(self.model_base_path)

        # Batch-indexed summaries for val/test metrics.
        self.write_summary(
            {
                "Accuracy/test/node_name_vs_batch": test_acc_node_name,
                "Accuracy/test/var_use_vs_batch": test_acc_var_use,
                "Accuracy/test/api_call_vs_batch": test_acc_api_call,
                "Accuracy/val/node_name_vs_batch": val_acc_node_name,
                "Accuracy/val/var_use_vs_batch": val_acc_var_use,
                "Accuracy/val/api_call_vs_batch": val_acc_api_call
            }, self.batch)

        # Epoch-indexed hyperparameter/metric log.
        self.write_hyperparams(
            {
                "Loss/train_vs_epoch": loss,
                "Accuracy/train/node_name_vs_epoch": train_acc_node_name,
                "Accuracy/train/var_use_vs_epoch": train_acc_var_use,
                "Accuracy/train/api_call_vs_epoch": train_acc_api_call,
                "Accuracy/test/node_name_vs_epoch": test_acc_node_name,
                "Accuracy/test/var_use_vs_epoch": test_acc_var_use,
                "Accuracy/test/api_call_vs_epoch": test_acc_api_call,
                "Accuracy/val/node_name_vs_epoch": val_acc_node_name,
                "Accuracy/val/var_use_vs_epoch": val_acc_var_use,
                "Accuracy/val/api_call_vs_epoch": val_acc_api_call
            }, self.epoch)

        # One scheduler step per epoch.
        self.lr_scheduler.step()