コード例 #1
0
    def init_model(self, model_type="mpad", lr=0.1, **kwargs):
        """Create the model together with its optimizer, LR scheduler and loss.

        Args:
            model_type: architecture name, matched case-insensitively; only
                "mpad" is accepted.
            lr: initial learning rate handed to Adam.
            **kwargs: constructor arguments forwarded unchanged to ``MPAD`` and
                remembered in ``self.model_args``.

        Raises:
            AssertionError: for any model type other than "mpad" (the
                lower-cased name is still recorded in ``self.model_type``).
        """
        normalized_type = model_type.lower()
        self.model_type = normalized_type

        # Guard clause: reject unsupported architectures before building anything.
        if normalized_type != "mpad":
            raise AssertionError("Currently only MPAD is supported as model")

        self.model = MPAD(**kwargs)
        self.model_args = kwargs

        # Adam over all model parameters; StepLR halves the LR every 50 scheduler steps.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=50, gamma=0.5)

        self.criterion = torch.nn.CrossEntropyLoss()
コード例 #2
0
ファイル: main.py プロジェクト: rationalisa/mpad
    y_val = [y[i] for i in val_index]

    adj_test = [adj[i] for i in test_index]
    features_test = [features[i] for i in test_index]
    y_test = [y[i] for i in test_index]

    adj_train, features_train, batch_n_graphs_train, y_train = generate_batches(adj_train, features_train, y_train, args.batch_size, args.use_master_node)
    adj_val, features_val, batch_n_graphs_val, y_val = generate_batches(adj_val, features_val, y_val, args.batch_size, args.use_master_node)
    adj_test, features_test, batch_n_graphs_test, y_test = generate_batches(adj_test, features_test, y_test, args.batch_size, args.use_master_node)

    n_train_batches = ceil(n_train/args.batch_size)
    n_val_batches = ceil(n_val/args.batch_size)
    n_test_batches = ceil(n_test/args.batch_size)

    # Model and optimizer
    model = MPAD(embeddings.shape[1], args.message_passing_layers, args.hidden, args.penultimate, nclass, args.dropout, embeddings, args.use_master_node)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    if args.cuda:
        model.cuda()
        adj_train = [x.cuda() for x in adj_train]
        features_train = [x.cuda() for x in features_train]
        batch_n_graphs_train = [x.cuda() for x in batch_n_graphs_train]
        y_train = [x.cuda() for x in y_train]
        adj_val = [x.cuda() for x in adj_val]
        features_val = [x.cuda() for x in features_val]
        batch_n_graphs_val = [x.cuda() for x in batch_n_graphs_val]
        y_val = [x.cuda() for x in y_val]
コード例 #3
0
class Learner:
    def __init__(self, experiment_name, device, multi_label):

        self.experiment_name = experiment_name
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.device = device
        self.writer = None
        self.train_step = 0
        self.multi_label = multi_label
        self.best_score = 0

        self.graph_preprocess_args = None
        self.epoch = -1
        self.model_save_dir = os.path.join(experiment_name, "models")
        self.model_type = None
        self.model_args = None
        self.log_dir = None

        os.makedirs(self.model_save_dir, exist_ok=True)
        self.best_model_path = os.path.join(self.model_save_dir, "model_best.pt")

    def set_graph_preprocessing_args(self, args):
        """Store the graph preprocessing arguments after validating their keys.

        Args:
            args: mapping whose key set must equal ``GRAPH_PREPROCESS_ARGS``.

        Raises:
            AssertionError: if the key sets differ. It is raised explicitly
                (not via ``assert``) so the check still runs under ``python -O``;
                the exception type is unchanged for existing callers.
        """
        if set(args.keys()) != set(GRAPH_PREPROCESS_ARGS):
            raise AssertionError(
                "Error, trying to set graph preprocessing arguments, got keys: {}, \n expected: {}".format(
                    list(args.keys()), GRAPH_PREPROCESS_ARGS
                )
            )

        self.graph_preprocess_args = args

    def get_graph_preprocessing_args(self):
        return self.graph_preprocess_args

    def init_model(self, model_type="mpad", lr=0.1, **kwargs):
        # Store model type
        self.model_type = model_type.lower()
        # Initiate model
        if model_type.lower() == "mpad":
            self.model = MPAD(**kwargs)
        else:
            raise AssertionError("Currently only MPAD is supported as model")

        self.model_args = kwargs

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=50, gamma=0.5
        )

        self.criterion = torch.nn.CrossEntropyLoss()

    def train_epoch(self, dataloader, eval_every):
        """Train on at most ``eval_every`` batches from ``dataloader``.

        Each call counts as one epoch: ``self.epoch`` is incremented first, and
        the learning-rate scheduler is stepped once after the batch loop. The
        per-call step keeps ``StepLR(step_size=50)`` counting epochs rather
        than individual batches. Stepping per batch would halve the LR every
        50 batches.

        Args:
            dataloader: iterable yielding ``(A, nodes, y, n_graphs)`` batches
                whose elements are moved to ``self.device``.
            eval_every: maximum number of batches processed in this call.
        """
        self.epoch += 1
        self.model.train()
        n_steps = 0

        with tqdm(initial=0, total=eval_every) as pbar_train:
            for batch in dataloader:
                A, nodes, y, n_graphs = (t.to(self.device) for t in batch)

                preds = self.model(nodes, A, n_graphs)
                loss = self.criterion(preds, y)

                self.optimizer.zero_grad()
                loss.backward()
                # NOTE(review): no gradient-norm clipping is applied.
                self.optimizer.step()
                n_steps += 1

                pbar_train.update(1)
                pbar_train.set_description(
                    "Training step {} -> loss: {}".format(n_steps, loss.item())
                )

                if n_steps % eval_every == 0:
                    # Hand control back to the caller (e.g. for evaluation).
                    break

        # Only advance the schedule if an optimizer step actually happened.
        if n_steps > 0:
            self.scheduler.step()

    def compute_metrics(self, y_pred, y_true):

        if self.multi_label:
            raise NotImplementedError()
        else:
            # Compute weighted average of F1 score
            y_pred = np.argmax(y_pred, axis=1)
            class_report = classification_report(y_true, y_pred, output_dict=True)
            return class_report["weighted avg"]["f1-score"]

    def save_model(self, is_best):

        to_save = {
            "experiment_name": self.experiment_name,
            "model_type": self.model_type,
            "graph_preprocess_args":self.graph_preprocess_args,
            "epoch": self.epoch,
            "model_args": self.model_args,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        # Save model indexed by epoch nr
        save_path = os.path.join(
            self.model_save_dir, self.experiment_name + "_{}.pt".format(self.epoch)
        )
        torch.save(to_save, save_path)

        if is_best:
            # Save best model separately
            torch.save(to_save, self.best_model_path)

    def load_model(self, path, lr=0.1):

        to_load = torch.load(path)

        self.epoch = to_load["epoch"]
        # Set up architecture of model
        self.init_model(
            model_type=to_load["model_type"],
            lr=lr,
            **to_load["model_args"]  # pass as kwargs
        )
        # Store kwargs for
        self.set_graph_preprocessing_args(to_load["graph_preprocess_args"])

        self.model.load_state_dict(to_load["state_dict"])
        self.optimizer.load_state_dict(to_load["optimizer"])

    def load_best_model(self):
        # Load the best model of the current experiment
        self.load_model(self.best_model_path)

    def evaluate(self, dataloader, save_model=True):
        """Run the model over ``dataloader`` and report its weighted F1 score.

        Predictions are collected under ``torch.no_grad()`` and scored with
        ``compute_metrics``. If the score beats ``self.best_score`` and
        ``save_model`` is true, a message is printed, ``self.best_score`` is
        updated, and ``self.save_model(is_best=True)`` writes a checkpoint.
        Otherwise the current and best scores are printed.

        Args:
            dataloader: sized iterable (``len`` sizes the progress bar) yielding
                ``(A, nodes, y, n_graphs)`` tensor batches.
            save_model: whether an improved score may trigger a checkpoint.
                (This local name does not affect ``self.save_model``.)

        NOTE(review): ``running_loss`` is summed per batch but is not read before
        the final score report. Confirm whether it should be returned or logged.
        """

        self.model.eval()
        y_pred = []
        y_true = []
        running_loss = 0

        ######################################
        # Infer the model on the dataset
        ######################################
        with tqdm(initial=0, total=len(dataloader)) as pbar_eval:
            with torch.no_grad():
                for batch_idx, batch in enumerate(dataloader):
                    # Move every tensor of the batch to the learner's device.
                    batch = (t.to(self.device) for t in batch)
                    A, nodes, y, n_graphs = batch

                    preds = self.model(nodes, A, n_graphs)
                    loss = self.criterion(preds, y)
                    running_loss += loss.item()
                    # store predictions and targets
                    # NOTE(review): np.round presumably guards against float-typed
                    # labels; for integer class indices it leaves values unchanged.
                    y_pred.extend(list(preds.cpu().detach().numpy()))
                    y_true.extend(list(np.round(y.cpu().detach().numpy())))

                    pbar_eval.update(1)
                    pbar_eval.set_description(
                        "Eval step {} -> loss: {}".format(batch_idx + 1, loss.item())
                    )

        ######################################
        # Compute metrics
        ######################################
        f1 = self.compute_metrics(y_pred, y_true)

        # A new best is only recorded (and checkpointed) when saving is enabled.
        if f1 > self.best_score and save_model:
            print("Saving new best model with F1 score {:.03f}".format(f1))
            self.best_score = f1
            self.save_model(is_best=True)
        else:
            print(
                "Current F1-score: {:.03f}, previous best: {:.03f}".format(
                    f1, self.best_score
                )
            )