class AGEM:
    def __init__(self, device, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        logger.info('Loaded {} as the model'.format(
            self.model.__class__.__name__))

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def compute_grad(self, orig_grad, ref_grad):
        with torch.no_grad():
            flat_orig_grad = torch.cat([torch.flatten(x) for x in orig_grad])
            flat_ref_grad = torch.cat([torch.flatten(x) for x in ref_grad])
            dot_product = torch.dot(flat_orig_grad, flat_ref_grad)
            if dot_product >= 0:
                return orig_grad
            proj_component = dot_product / torch.dot(flat_ref_grad,
                                                     flat_ref_grad)
            modified_grad = [
                o - proj_component * r for (o, r) in zip(orig_grad, ref_grad)
            ]
            return modified_grad

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)

                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(
                    self.device)
                loss = self.loss_fn(output, targets)
                self.optimizer.zero_grad()

                params = [
                    p for p in self.model.parameters() if p.requires_grad
                ]
                orig_grad = torch.autograd.grad(loss, params)

                mini_batch_size = len(label)
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)

                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
                    ref_grad_sum = None
                    for _ in range(replay_steps):
                        ref_text, ref_label, ref_candidates = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        replicated_ref_text, replicated_ref_relations, ref_ranking_label = datasets.utils.replicate_rel_data(
                            ref_text, ref_label, ref_candidates)
                        ref_input_dict = self.model.encode_text(
                            list(
                                zip(replicated_ref_text,
                                    replicated_ref_relations)))
                        ref_output = self.model(ref_input_dict)
                        ref_targets = torch.tensor(
                            ref_ranking_label).float().unsqueeze(1).to(
                                self.device)
                        ref_loss = self.loss_fn(ref_output, ref_targets)
                        ref_grad = torch.autograd.grad(ref_loss, params)
                        if ref_grad_sum is None:
                            ref_grad_sum = ref_grad
                        else:
                            ref_grad_sum = [
                                x + y
                                for (x, y) in zip(ref_grad, ref_grad_sum)
                            ]
                    final_grad = self.compute_grad(orig_grad, ref_grad_sum)
                else:
                    final_grad = orig_grad

                for param, grad in zip(params, final_grad):
                    param.grad = grad.data

                self.optimizer.step()

                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
                iter += 1
                self.memory.write_batch(text, label, candidates)

                if iter % log_freq == 0:
                    acc = models.utils.calculate_accuracy(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.
                        format(epoch + 1, np.mean(all_losses), acc))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                text, label, candidates)

            with torch.no_grad():
                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)

            pred, true_labels = models.utils.make_rel_prediction(
                output, ranking_label)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(all_predictions, all_labels)

        return acc

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.rel_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset,
                                          batch_size=mini_batch_size,
                                          shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
Exemple #2
0
class Replay:
    def __init__(self, device, n_classes, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                mini_batch_size = len(labels)
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)

                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
                    self.optimizer.zero_grad()
                    for _ in range(replay_steps):
                        ref_text, ref_labels = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        ref_labels = torch.tensor(ref_labels).to(self.device)
                        ref_input_dict = self.model.encode_text(ref_text)
                        ref_output = self.model(ref_input_dict)
                        ref_loss = self.loss_fn(ref_output, ref_labels)
                        ref_loss.backward()

                    params = [
                        p for p in self.model.parameters() if p.requires_grad
                    ]
                    torch.nn.utils.clip_grad_norm(params, 25)
                    self.optimizer.step()

                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                iter += 1
                self.memory.write_batch(text, labels)

                if iter % log_freq == 0:
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                        'F1 score = {:.4f}'.format(epoch + 1,
                                                   np.mean(all_losses), acc,
                                                   prec, rec, f1))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = models.utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 50)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.batch_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_datasets, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)

        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions), np.mean(recalls),
                                       np.mean(f1s)))
Exemple #3
0
class Baseline:

    def __init__(self, device, training_mode, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.device = device
        self.training_mode = training_mode

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)

        logger.info('Loaded {} as the model'.format(self.model.__class__.__name__))

        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label,
                                                                                                         candidates)

                input_dict = self.model.encode_text(list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
                loss = self.loss_fn(output, targets)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
                iter += 1

                if iter % log_freq == 0:
                    acc = models.utils.calculate_accuracy(all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(epoch + 1, np.mean(all_losses), acc))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(text, label,
                                                                                                     candidates)

            with torch.no_grad():
                input_dict = self.model.encode_text(list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)

            pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(all_predictions, all_labels)

        return acc

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.training_mode == 'sequential':
            for cluster_idx, train_dataset in enumerate(train_datasets):
                logger.info('Training on cluster {}'.format(cluster_idx + 1))
                train_dataloader = data.DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=False,
                                                   collate_fn=datasets.utils.rel_encode)
                self.train(dataloader=train_dataloader, n_epochs=n_epochs, log_freq=log_freq)
        elif self.training_mode == 'multi_task':
            train_dataset = data.ConcatDataset(train_datasets)
            logger.info('Training multi-task model on all datasets')
            train_dataloader = data.DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True,
                                               collate_fn=datasets.utils.rel_encode)
            self.train(dataloader=train_dataloader, n_epochs=n_epochs, log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset, batch_size=mini_batch_size, shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
Exemple #4
0
class Baseline:
    def __init__(self, device, n_classes, training_mode, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.device = device
        self.training_mode = training_mode
        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                iter += 1

                if iter % log_freq == 0:
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                        'F1 score = {:.4f}'.format(epoch + 1,
                                                   np.mean(all_losses), acc,
                                                   prec, rec, f1))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = models.utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 500)
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.training_mode == 'sequential':
            for train_dataset in train_datasets:
                logger.info('Training on {}'.format(
                    train_dataset.__class__.__name__))
                train_dataloader = data.DataLoader(
                    train_dataset,
                    batch_size=mini_batch_size,
                    shuffle=False,
                    collate_fn=datasets.utils.batch_encode)
                self.train(dataloader=train_dataloader,
                           n_epochs=n_epochs,
                           log_freq=log_freq)
        elif self.training_mode == 'multi_task':
            train_dataset = data.ConcatDataset(train_datasets)
            logger.info('Training multi-task model on all datasets')
            train_dataloader = data.DataLoader(
                train_dataset,
                batch_size=mini_batch_size,
                shuffle=True,
                collate_fn=datasets.utils.batch_encode)
            self.train(dataloader=train_dataloader,
                       n_epochs=n_epochs,
                       log_freq=log_freq)
        else:
            raise ValueError('Invalid training mode')

    def testing(self, test_datasets, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)

        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions), np.mean(recalls),
                                       np.mean(f1s)))