Example 1
    def __init__(self, device, n_classes, **kwargs):
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.nm = TransformerNeuromodulator(model_name=kwargs.get('model'),
                                            device=device)
        self.pn = TransformerClsModel(model_name=kwargs.get('model'),
                                      n_classes=n_classes,
                                      max_length=kwargs.get('max_length'),
                                      device=device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        self.loss_fn = nn.CrossEntropyLoss()

        logger.info('Loaded {} as NM'.format(self.nm.__class__.__name__))
        logger.info('Loaded {} as PN'.format(self.pn.__class__.__name__))

        meta_params = [p for p in self.nm.parameters() if p.requires_grad] + \
                      [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
Example 2
    def __init__(self, device, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        logger.info('Loaded {} as the model'.format(self.model.__class__.__name__))

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)
        self.loss_fn = nn.BCEWithLogitsLoss()
Example 3
    def __init__(self, device, n_classes, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)
Example 4
    def __init__(self, device, **kwargs):
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.rln = TransformerRLN(model_name=kwargs.get('model'),
                                  max_length=kwargs.get('max_length'),
                                  device=device)
        self.pln = LinearPLN(in_dim=768, out_dim=1, device=device)
        meta_params = [p for p in self.rln.parameters() if p.requires_grad] + \
                      [p for p in self.pln.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        self.loss_fn = nn.BCEWithLogitsLoss()

        logger.info('Loaded {} as RLN'.format(self.rln.__class__.__name__))
        logger.info('Loaded {} as PLN'.format(self.pln.__class__.__name__))

        inner_params = [p for p in self.pln.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
Example 5
    def __init__(self, device, **kwargs):
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.pn = TransformerClsModel(model_name=kwargs.get('model'),
                                      n_classes=1,
                                      max_length=kwargs.get('max_length'),
                                      device=device,
                                      hebbian=kwargs.get('hebbian'))

        logger.info('Loaded {} as PN'.format(self.pn.__class__.__name__))

        meta_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        self.loss_fn = nn.BCEWithLogitsLoss()

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)
Example 6
class AGEM:
    def __init__(self, device, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
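        # Each memory entry is a (text, label, candidates) triple, hence tuple_size=3.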
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        logger.info('Loaded {} as the model'.format(
            self.model.__class__.__name__))

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def compute_grad(self, orig_grad, ref_grad):
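        # A-GEM gradient projection: if the task gradient conflicts with the replay
        # (reference) gradient, remove the conflicting component so the update does
        # not increase the loss on memory samples.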
        with torch.no_grad():
            flat_orig_grad = torch.cat([torch.flatten(x) for x in orig_grad])
            flat_ref_grad = torch.cat([torch.flatten(x) for x in ref_grad])
            dot_product = torch.dot(flat_orig_grad, flat_ref_grad)
            if dot_product >= 0:
                return orig_grad
            proj_component = dot_product / torch.dot(flat_ref_grad,
                                                     flat_ref_grad)
            modified_grad = [
                o - proj_component * r for (o, r) in zip(orig_grad, ref_grad)
            ]
            return modified_grad

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)

                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(
                    self.device)
                loss = self.loss_fn(output, targets)
                self.optimizer.zero_grad()

                params = [
                    p for p in self.model.parameters() if p.requires_grad
                ]
                orig_grad = torch.autograd.grad(loss, params)

                mini_batch_size = len(label)
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)

                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
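                    # Every replay_freq batches, accumulate a reference gradient over
                    # replay_steps mini-batches sampled from episodic memory; compute_grad()
                    # then projects the current task gradient against it.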
                    ref_grad_sum = None
                    for _ in range(replay_steps):
                        ref_text, ref_label, ref_candidates = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        replicated_ref_text, replicated_ref_relations, ref_ranking_label = datasets.utils.replicate_rel_data(
                            ref_text, ref_label, ref_candidates)
                        ref_input_dict = self.model.encode_text(
                            list(
                                zip(replicated_ref_text,
                                    replicated_ref_relations)))
                        ref_output = self.model(ref_input_dict)
                        ref_targets = torch.tensor(
                            ref_ranking_label).float().unsqueeze(1).to(
                                self.device)
                        ref_loss = self.loss_fn(ref_output, ref_targets)
                        ref_grad = torch.autograd.grad(ref_loss, params)
                        if ref_grad_sum is None:
                            ref_grad_sum = ref_grad
                        else:
                            ref_grad_sum = [
                                x + y
                                for (x, y) in zip(ref_grad, ref_grad_sum)
                            ]
                    final_grad = self.compute_grad(orig_grad, ref_grad_sum)
                else:
                    final_grad = orig_grad

                for param, grad in zip(params, final_grad):
                    param.grad = grad.data

                self.optimizer.step()

                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
                iter += 1
                self.memory.write_batch(text, label, candidates)

                if iter % log_freq == 0:
                    acc = models.utils.calculate_accuracy(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.
                        format(epoch + 1, np.mean(all_losses), acc))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                text, label, candidates)

            with torch.no_grad():
                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)

            pred, true_labels = models.utils.make_rel_prediction(
                output, ranking_label)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(all_predictions, all_labels)

        return acc

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.rel_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset,
                                          batch_size=mini_batch_size,
                                          shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
Example 7
class ANML:
    def __init__(self, device, n_classes, **kwargs):
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.nm = TransformerNeuromodulator(model_name=kwargs.get('model'),
                                            device=device)
        self.pn = TransformerClsModel(model_name=kwargs.get('model'),
                                      n_classes=n_classes,
                                      max_length=kwargs.get('max_length'),
                                      device=device)
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        self.loss_fn = nn.CrossEntropyLoss()

        logger.info('Loaded {} as NM'.format(self.nm.__class__.__name__))
        logger.info('Loaded {} as PN'.format(self.pn.__class__.__name__))

        meta_params = [p for p in self.nm.parameters() if p.requires_grad] + \
                      [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

    def save_model(self, model_path):
        checkpoint = {'nm': self.nm.state_dict(), 'pn': self.pn.state_dict()}
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.nm.load_state_dict(checkpoint['nm'])
        self.pn.load_state_dict(checkpoint['pn'])

    def evaluate(self, dataloader, updates, mini_batch_size):
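        # Build a support set by sampling from episodic memory; the PN is adapted on it
        # in the inner loop before metrics are computed on the evaluation dataloader.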

        support_set = []
        for _ in range(updates):
            text, labels = self.memory.read_batch(batch_size=mini_batch_size)
            support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):

            # Inner loop
            task_predictions, task_labels = [], []
            support_loss = []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                repr = fpn(input_dict, out_from='transformers')
                modulation = self.nm(input_dict)
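                # The neuromodulator output gates the transformer representation
                # element-wise before it reaches the linear classification head.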
                output = fpn(repr * modulation, out_from='linear')
                loss = self.loss_fn(output, labels)
                diffopt.step(loss)
                pred = models.utils.make_prediction(output.detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())

            acc, prec, rec, f1 = models.utils.calculate_metrics(
                task_predictions, task_labels)

            logger.info(
                'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                'recall = {:.4f}, F1 score = {:.4f}'.format(
                    np.mean(support_loss), acc, prec, rec, f1))

            all_losses, all_predictions, all_labels = [], [], []

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                with torch.no_grad():
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')

        if self.replay_rate != 0:
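            # Each episode consumes (updates + 1) mini-batches, so replay_freq is the
            # number of episodes between replay events and replay_steps the number of
            # memory mini-batches replayed each time.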
            replay_batch_freq = self.replay_every // mini_batch_size
            replay_freq = int(
                math.ceil((replay_batch_freq + 1) / (updates + 1)))
            replay_steps = int(self.replay_every * self.replay_rate /
                               mini_batch_size)
        else:
            replay_freq = 0
            replay_steps = 0
        logger.info('Replay frequency: {}'.format(replay_freq))
        logger.info('Replay steps: {}'.format(replay_steps))

        concat_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = iter(
            data.DataLoader(concat_dataset,
                            batch_size=mini_batch_size,
                            shuffle=False,
                            collate_fn=datasets.utils.batch_encode))

        episode_id = 0
        while True:

            self.inner_optimizer.zero_grad()
            support_loss, support_acc, support_prec, support_rec, support_f1 = [], [], [], [], []

            with higher.innerloop_ctx(self.pn,
                                      self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (fpn,
                                                                    diffopt):

                # Inner loop
                support_set = []
                task_predictions, task_labels = [], []
                for _ in range(updates):
                    try:
                        text, labels = next(train_dataloader)
                        support_set.append((text, labels))
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return

                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    input_dict = self.pn.encode_text(text)
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                    diffopt.step(loss)
                    pred = models.utils.make_prediction(output.detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                    self.memory.write_batch(text, labels)

                acc, prec, rec, f1 = models.utils.calculate_metrics(
                    task_predictions, task_labels)

                logger.info(
                    'Episode {} support set: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                    'recall = {:.4f}, F1 score = {:.4f}'.format(
                        episode_id + 1, np.mean(support_loss), acc, prec, rec,
                        f1))

                # Outer loop
                query_loss, query_acc, query_prec, query_rec, query_f1 = [], [], [], [], []
                query_set = []

                if self.replay_rate != 0 and (episode_id +
                                              1) % replay_freq == 0:
                    for _ in range(replay_steps):
                        text, labels = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        query_set.append((text, labels))
                else:
                    try:
                        text, labels = next(train_dataloader)
                        query_set.append((text, labels))
                        self.memory.write_batch(text, labels)
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return

                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    input_dict = self.pn.encode_text(text)
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                    query_loss.append(loss.item())
                    pred = models.utils.make_prediction(output.detach())

                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        pred.tolist(), labels.tolist())
                    query_acc.append(acc)
                    query_prec.append(prec)
                    query_rec.append(rec)
                    query_f1.append(f1)

                    # NM meta gradients
                    nm_params = [
                        p for p in self.nm.parameters() if p.requires_grad
                    ]
                    meta_nm_grads = torch.autograd.grad(loss,
                                                        nm_params,
                                                        retain_graph=True)
                    for param, meta_grad in zip(nm_params, meta_nm_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()

                    # PN meta gradients
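                    # Gradients are taken w.r.t. the fast weights (fpn) and accumulated
                    # onto the slow weights of self.pn: a first-order meta-update,
                    # consistent with track_higher_grads=False above.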
                    pn_params = [
                        p for p in fpn.parameters() if p.requires_grad
                    ]
                    meta_pn_grads = torch.autograd.grad(loss, pn_params)
                    pn_params = [
                        p for p in self.pn.parameters() if p.requires_grad
                    ]
                    for param, meta_grad in zip(pn_params, meta_pn_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()

                logger.info(
                    'Episode {} query set: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                    'recall = {:.4f}, F1 score = {:.4f}'.format(
                        episode_id + 1, np.mean(query_loss),
                        np.mean(query_acc), np.mean(query_prec),
                        np.mean(query_rec), np.mean(query_f1)))

                episode_id += 1

    def testing(self, test_datasets, **kwargs):
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader,
                                               updates=updates,
                                               mini_batch_size=mini_batch_size)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)

        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions), np.mean(recalls),
                                       np.mean(f1s)))

        return accuracies
Example 8
class Replay:
    def __init__(self, device, n_classes, **kwargs):
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)

    def save_model(self, model_path):
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):

        self.model.train()

        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0

            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                mini_batch_size = len(labels)
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)

                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
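                    # Sparse experience replay: accumulate gradients over replay_steps
                    # mini-batches sampled from episodic memory, clip them, and take one
                    # extra optimizer step.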
                    self.optimizer.zero_grad()
                    for _ in range(replay_steps):
                        ref_text, ref_labels = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        ref_labels = torch.tensor(ref_labels).to(self.device)
                        ref_input_dict = self.model.encode_text(ref_text)
                        ref_output = self.model(ref_input_dict)
                        ref_loss = self.loss_fn(ref_output, ref_labels)
                        ref_loss.backward()

                    params = [
                        p for p in self.model.parameters() if p.requires_grad
                    ]
                    torch.nn.utils.clip_grad_norm_(params, 25)
                    self.optimizer.step()

                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                iter += 1
                self.memory.write_batch(text, labels)

                if iter % log_freq == 0:
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                        'F1 score = {:.4f}'.format(epoch + 1,
                                                   np.mean(all_losses), acc,
                                                   prec, rec, f1))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        all_losses, all_predictions, all_labels = [], [], []

        self.model.eval()

        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = models.utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 50)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.batch_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_datasets, **kwargs):
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)

        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions), np.mean(recalls),
                                       np.mean(f1s)))
Example 9
class OML:
    def __init__(self, device, **kwargs):
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device

        self.rln = TransformerRLN(model_name=kwargs.get('model'),
                                  max_length=kwargs.get('max_length'),
                                  device=device)
        self.pln = LinearPLN(in_dim=768, out_dim=1, device=device)
        meta_params = [p for p in self.rln.parameters() if p.requires_grad] + \
                      [p for p in self.pln.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)

        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        self.loss_fn = nn.BCEWithLogitsLoss()

        logger.info('Loaded {} as RLN'.format(self.rln.__class__.__name__))
        logger.info('Loaded {} as PLN'.format(self.pln.__class__.__name__))

        inner_params = [p for p in self.pln.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

    def save_model(self, model_path):
        checkpoint = {
            'rln': self.rln.state_dict(),
            'pln': self.pln.state_dict()
        }
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        checkpoint = torch.load(model_path)
        self.rln.load_state_dict(checkpoint['rln'])
        self.pln.load_state_dict(checkpoint['pln'])

    def evaluate(self, dataloader, updates, mini_batch_size):
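        # Only the prediction network (PLN) is adapted at test time; the representation
        # network (RLN) stays in eval mode and receives no inner-loop updates.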

        self.rln.eval()
        self.pln.train()

        support_set = []
        for _ in range(updates):
            text, label, candidates = self.memory.read_batch(
                batch_size=mini_batch_size)
            support_set.append((text, label, candidates))

        with higher.innerloop_ctx(self.pln,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpln, diffopt):

            # Inner loop
            task_predictions, task_labels = [], []
            support_loss = []
            for text, label, candidates in support_set:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)

                input_dict = self.rln.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                repr = self.rln(input_dict)
                output = fpln(repr)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(
                    self.device)
                loss = self.loss_fn(output, targets)

                diffopt.step(loss)
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(true_labels.tolist())

            acc = models.utils.calculate_accuracy(task_predictions,
                                                  task_labels)

            logger.info(
                'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
                    np.mean(support_loss), acc))

            all_losses, all_predictions, all_labels = [], [], []

            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                with torch.no_grad():

                    input_dict = self.rln.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    repr = self.rln(input_dict)
                    output = fpln(repr)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)

                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(all_predictions, all_labels)
        logger.info('Test metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
            np.mean(all_losses), acc))

        return acc

    def training(self, train_datasets, **kwargs):
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')

        if self.replay_rate != 0:
            replay_batch_freq = self.replay_every // mini_batch_size
            replay_freq = int(
                math.ceil((replay_batch_freq + 1) / (updates + 1)))
            replay_steps = int(self.replay_every * self.replay_rate /
                               mini_batch_size)
        else:
            replay_freq = 0
            replay_steps = 0
        logger.info('Replay frequency: {}'.format(replay_freq))
        logger.info('Replay steps: {}'.format(replay_steps))

        concat_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = iter(
            data.DataLoader(concat_dataset,
                            batch_size=mini_batch_size,
                            shuffle=False,
                            collate_fn=datasets.utils.rel_encode))

        episode_id = 0
        while True:

            self.inner_optimizer.zero_grad()
            support_loss, support_acc = [], []

            with higher.innerloop_ctx(self.pln,
                                      self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (fpln,
                                                                    diffopt):

                # Inner loop
                support_set = []
                task_predictions, task_labels = [], []
                for _ in range(updates):
                    try:
                        text, label, candidates = next(train_dataloader)
                        support_set.append((text, label, candidates))
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return

                for text, label, candidates in support_set:
                    replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                        text, label, candidates)

                    input_dict = self.rln.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    repr = self.rln(input_dict)
                    output = fpln(repr)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)

                    diffopt.step(loss)
                    pred, true_labels = models.utils.make_rel_prediction(
                        output, ranking_label)
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(true_labels.tolist())
                    self.memory.write_batch(text, label, candidates)

                acc = models.utils.calculate_accuracy(task_predictions,
                                                      task_labels)

                logger.info(
                    'Episode {} support set: Loss = {:.4f}, accuracy = {:.4f}'.
                    format(episode_id + 1, np.mean(support_loss), acc))

                # Outer loop
                query_loss, query_acc = [], []
                query_set = []

                if self.replay_rate != 0 and (episode_id +
                                              1) % replay_freq == 0:
                    for _ in range(replay_steps):
                        text, label, candidates = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        query_set.append((text, label, candidates))
                else:
                    try:
                        text, label, candidates = next(train_dataloader)
                        query_set.append((text, label, candidates))
                        self.memory.write_batch(text, label, candidates)
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return

                for text, label, candidates in query_set:
                    replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                        text, label, candidates)

                    input_dict = self.rln.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    repr = self.rln(input_dict)
                    output = fpln(repr)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)

                    query_loss.append(loss.item())
                    pred, true_labels = models.utils.make_rel_prediction(
                        output, ranking_label)

                    acc = models.utils.calculate_accuracy(
                        pred.tolist(), true_labels.tolist())
                    query_acc.append(acc)

                    # RLN meta gradients
                    rln_params = [
                        p for p in self.rln.parameters() if p.requires_grad
                    ]
                    meta_rln_grads = torch.autograd.grad(loss,
                                                         rln_params,
                                                         retain_graph=True)
                    for param, meta_grad in zip(rln_params, meta_rln_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()

                    # PLN meta gradients
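                    # Gradients are taken w.r.t. the fast weights (fpln) and accumulated
                    # onto the slow weights of self.pln (first-order meta-update).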
                    pln_params = [
                        p for p in fpln.parameters() if p.requires_grad
                    ]
                    meta_pln_grads = torch.autograd.grad(loss, pln_params)
                    pln_params = [
                        p for p in self.pln.parameters() if p.requires_grad
                    ]
                    for param, meta_grad in zip(pln_params, meta_pln_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()

                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()

                logger.info(
                    'Episode {} query set: Loss = {:.4f}, accuracy = {:.4f}'.
                    format(episode_id + 1, np.mean(query_loss),
                           np.mean(query_acc)))

                episode_id += 1

    def testing(self, test_dataset, **kwargs):
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset,
                                          batch_size=mini_batch_size,
                                          shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader,
                            updates=updates,
                            mini_batch_size=mini_batch_size)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc