        # Per-batch training metrics (inside the batch loop)
        train_acc.update(accuracy(output.data, y_train[i].data), output.size(0))
        train_f1.update(def_f1_score(output.data, y_train[i].data), output.size(0))
        train_precision.update(def_precision(output.data, y_train[i].data), output.size(0))
        train_recall.update(def_recall(output.data, y_train[i].data), output.size(0))

    # End of epoch: step the LR scheduler and record epoch-level averages
    scheduler.step()
    list_of_train_losses.append(train_loss.avg)
    list_of_train_acc.append(train_acc.avg)
    list_of_train_f1.append(train_f1.avg)

    # Evaluate on validation set
    model.eval()
    val_loss = AverageMeter()
    val_acc = AverageMeter()
    val_f1 = AverageMeter()
    val_recall = AverageMeter()
    val_precision = AverageMeter()
    for i in range(n_val_batches):
        output, loss = test(adj_val[i], features_val[i], batch_n_graphs_val[i], y_val[i])
        val_loss.update(loss.item(), output.size(0))
        val_acc.update(accuracy(output.data, y_val[i].data), output.size(0))
        val_f1.update(def_f1_score(output.data, y_val[i].data), output.size(0))
        val_precision.update(def_precision(output.data, y_val[i].data), output.size(0))
        val_recall.update(def_recall(output.data, y_val[i].data), output.size(0))
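# accuracy, def_f1_score, def_precision and def_recall are not defined in this
# excerpt. A minimal sketch of plausible implementations, assuming `output`
# holds class logits and `y` holds integer labels; sklearn here is a stand-in
# for whatever metric implementation the project actually uses:

from sklearn.metrics import f1_score, precision_score, recall_score

def accuracy(output, y):
    # Fraction of samples whose arg-max prediction matches the label
    preds = output.max(1)[1].type_as(y)
    return preds.eq(y).double().mean().item()

def def_f1_score(output, y):
    preds = output.max(1)[1].cpu().numpy()
    return f1_score(y.cpu().numpy(), preds, average="macro")

def def_precision(output, y):
    preds = output.max(1)[1].cpu().numpy()
    return precision_score(y.cpu().numpy(), preds, average="macro", zero_division=0)

def def_recall(output, y):
    preds = output.max(1)[1].cpu().numpy()
    return recall_score(y.cpu().numpy(), preds, average="macro", zero_division=0)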
for epoch in range(args.epochs):
    scheduler.step()
    start = time.time()
    model.train()
    train_loss = AverageMeter()
    train_acc = AverageMeter()

    # Train for one epoch
    for i in range(n_train_batches):
        output, loss = train(epoch, adj_train[i], features_train[i], batch_n_graphs_train[i], y_train[i])
        train_loss.update(loss.item(), output.size(0))
        train_acc.update(accuracy(output.data, y_train[i].data), output.size(0))

    # Evaluate on validation set
    model.eval()
    val_loss = AverageMeter()
    val_acc = AverageMeter()
    for i in range(n_val_batches):
        output, loss = test(adj_val[i], features_val[i], batch_n_graphs_val[i], y_val[i])
        val_loss.update(loss.item(), output.size(0))
        val_acc.update(accuracy(output.data, y_val[i].data), output.size(0))

    # Print results
    print("Cross-val iter:", '%02d' % it,
          "epoch:", '%03d' % (epoch + 1),
          "train_loss=", "{:.5f}".format(train_loss.avg),
          "train_acc=", "{:.5f}".format(train_acc.avg),
          "val_loss=", "{:.5f}".format(val_loss.avg),
          "val_acc=", "{:.5f}".format(val_acc.avg),
          "time=", "{:.5f}".format(time.time() - start))

    # Remember best accuracy and save checkpoint
    is_best = val_acc.avg >= best_acc
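# AverageMeter is not defined in this excerpt. The loops above only call
# .update(value, n) and read .avg, which matches the running-average helper
# from the official PyTorch ImageNet example; a minimal version under that
# assumption:

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # `val` is the batch-level metric, `n` the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count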
import os

import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import classification_report
from tqdm import tqdm

# MPAD and GRAPH_PREPROCESS_ARGS are assumed to come from the project's own modules.


class Learner:
    def __init__(self, experiment_name, device, multi_label):
        self.experiment_name = experiment_name
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.device = device
        self.writer = None
        self.train_step = 0
        self.multi_label = multi_label
        self.best_score = 0
        self.graph_preprocess_args = None
        self.epoch = -1
        self.model_save_dir = os.path.join(experiment_name, "models")
        self.model_type = None
        self.model_args = None
        self.log_dir = None
        os.makedirs(self.model_save_dir, exist_ok=True)
        self.best_model_path = os.path.join(self.model_save_dir, "model_best.pt")

    def set_graph_preprocessing_args(self, args):
        assert set(args.keys()) == set(GRAPH_PREPROCESS_ARGS), \
            "Error, trying to set graph preprocessing arguments, got keys: {},\nexpected: {}".format(
                list(args.keys()), GRAPH_PREPROCESS_ARGS)
        self.graph_preprocess_args = args

    def get_graph_preprocessing_args(self):
        return self.graph_preprocess_args

    def init_model(self, model_type="mpad", lr=0.1, **kwargs):
        # Store model type
        self.model_type = model_type.lower()
        # Initiate model
        if model_type.lower() == "mpad":
            self.model = MPAD(**kwargs)
        else:
            raise AssertionError("Currently only MPAD is supported as model")
        self.model_args = kwargs
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=50, gamma=0.5
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_epoch(self, dataloader, eval_every):
        self.epoch += 1
        self.model.train()
        total_iters = -1
        with tqdm(initial=0, total=eval_every) as pbar_train:
            for batch_ix, batch in enumerate(dataloader):
                total_iters += 1
                batch = (t.to(self.device) for t in batch)
                A, nodes, y, n_graphs = batch
                preds = self.model(nodes, A, n_graphs)
                loss = self.criterion(preds, y)
                self.optimizer.zero_grad()
                loss.backward()
                # grad norm clipping?
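                # If gradients explode, clipping could go here, e.g. (an
                # assumption, not part of the original code):
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)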
                self.optimizer.step()
                self.scheduler.step()

                pbar_train.update(1)
                pbar_train.set_description(
                    "Training step {} -> loss: {}".format(total_iters + 1, loss.item())
                )
                if (total_iters + 1) % eval_every == 0:
                    # Stop training
                    break

    def compute_metrics(self, y_pred, y_true):
        if self.multi_label:
            raise NotImplementedError()
        else:
            # Compute weighted average of F1 score
            y_pred = np.argmax(y_pred, axis=1)
            class_report = classification_report(y_true, y_pred, output_dict=True)
            return class_report["weighted avg"]["f1-score"]

    def save_model(self, is_best):
        to_save = {
            "experiment_name": self.experiment_name,
            "model_type": self.model_type,
            "graph_preprocess_args": self.graph_preprocess_args,
            "epoch": self.epoch,
            "model_args": self.model_args,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        # Save model indexed by epoch nr
        save_path = os.path.join(
            self.model_save_dir, self.experiment_name + "_{}.pt".format(self.epoch)
        )
        torch.save(to_save, save_path)
        if is_best:
            # Save best model separately
            torch.save(to_save, self.best_model_path)

    def load_model(self, path, lr=0.1):
        to_load = torch.load(path)
        self.epoch = to_load["epoch"]
        # Set up architecture of model
        self.init_model(
            model_type=to_load["model_type"],
            lr=lr,
            **to_load["model_args"]  # pass as kwargs
        )
        # Store kwargs for graph preprocessing
        self.set_graph_preprocessing_args(to_load["graph_preprocess_args"])
        self.model.load_state_dict(to_load["state_dict"])
        self.optimizer.load_state_dict(to_load["optimizer"])

    def load_best_model(self):
        # Load the best model of the current experiment
        self.load_model(self.best_model_path)

    def evaluate(self, dataloader, save_model=True):
        self.model.eval()
        y_pred = []
        y_true = []
        running_loss = 0

        ######################################
        # Infer the model on the dataset
        ######################################
        with tqdm(initial=0, total=len(dataloader)) as pbar_eval:
            with torch.no_grad():
                for batch_idx, batch in enumerate(dataloader):
                    batch = (t.to(self.device) for t in batch)
                    A, nodes, y, n_graphs = batch
                    preds = self.model(nodes, A, n_graphs)
                    loss = self.criterion(preds, y)
                    running_loss += loss.item()

                    # Store predictions and targets
                    y_pred.extend(list(preds.cpu().detach().numpy()))
                    y_true.extend(list(np.round(y.cpu().detach().numpy())))

                    pbar_eval.update(1)
                    pbar_eval.set_description(
                        "Eval step {} -> loss: {}".format(batch_idx + 1, loss.item())
                    )

        ######################################
        # Compute metrics
        ######################################
        f1 = self.compute_metrics(y_pred, y_true)
        if f1 > self.best_score and save_model:
            print("Saving new best model with F1 score {:.03f}".format(f1))
            self.best_score = f1
            self.save_model(is_best=True)
        else:
            print(
                "Current F1-score: {:.03f}, previous best: {:.03f}".format(
                    f1, self.best_score
                )
            )
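# A minimal sketch of how Learner might be driven, assuming DataLoaders that
# yield (A, nodes, y, n_graphs) batches; `train_loader`, `val_loader`,
# `mpad_kwargs` and `n_epochs` are hypothetical placeholders:
#
#     learner = Learner("my_experiment", device=torch.device("cuda"), multi_label=False)
#     learner.init_model(model_type="mpad", lr=0.01, **mpad_kwargs)
#     for _ in range(n_epochs):
#         learner.train_epoch(train_loader, eval_every=100)
#         learner.evaluate(val_loader, save_model=True)  # checkpoints whenever F1 improves
#     learner.load_best_model()  # restore the best checkpoint afterwards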