class AGEM:
    """Averaged Gradient Episodic Memory (A-GEM) trainer for the relation
    ranking task.

    The model emits a single logit per (text, candidate-relation) pair
    (trained with BCE-with-logits), and an episodic ``ReplayMemory`` stores
    seen examples. When a replay step is due, the current-batch gradient is
    projected so it does not conflict with a reference gradient computed on
    batches sampled from memory (see ``compute_grad``).
    """

    def __init__(self, device, **kwargs):
        """Build the model, episodic memory, optimizer and loss.

        Reads from kwargs: lr (default 3e-5), write_prob, replay_rate,
        replay_every, model, max_length, hebbian.
        """
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device
        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        # tuple_size=3: each memory row is (text, label, candidates).
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        logger.info('Loaded {} as the model'.format(
            self.model.__class__.__name__))
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=self.lr)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        """Serialize model weights to ``model_path``."""
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        """Restore weights previously written by ``save_model``."""
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def compute_grad(self, orig_grad, ref_grad):
        """A-GEM gradient projection.

        If the current-batch gradient ``orig_grad`` has a non-negative dot
        product with the replay reference gradient ``ref_grad``, it is kept
        unchanged. Otherwise its component along ``ref_grad`` is removed:
        ``g' = g - (g . g_ref / g_ref . g_ref) * g_ref``.

        Both arguments are lists of per-parameter tensors (same ordering);
        the return value is a list in that same order.
        """
        with torch.no_grad():
            flat_orig_grad = torch.cat([torch.flatten(x) for x in orig_grad])
            flat_ref_grad = torch.cat([torch.flatten(x) for x in ref_grad])
            dot_product = torch.dot(flat_orig_grad, flat_ref_grad)
            if dot_product >= 0:
                return orig_grad
            proj_component = dot_product / torch.dot(flat_ref_grad,
                                                     flat_ref_grad)
            modified_grad = [
                o - proj_component * r for (o, r) in zip(orig_grad, ref_grad)
            ]
            return modified_grad

    def train(self, dataloader, n_epochs, log_freq):
        """Train with A-GEM; logs running metrics every ``log_freq`` batches."""
        self.model.train()
        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            iter = 0  # NOTE: shadows the builtin; kept as-is for fidelity
            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(
                    self.device)
                loss = self.loss_fn(output, targets)
                self.optimizer.zero_grad()
                params = [
                    p for p in self.model.parameters() if p.requires_grad
                ]
                # Current-batch gradient, computed functionally (not
                # accumulated into .grad) so it can be projected first.
                orig_grad = torch.autograd.grad(loss, params)
                mini_batch_size = len(label)
                # Recomputed per batch: the last batch may be smaller.
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)
                if self.replay_rate != 0 and (iter + 1) % replay_freq == 0:
                    # Reference gradient: sum of gradients over replayed
                    # memory batches.
                    ref_grad_sum = None
                    for _ in range(replay_steps):
                        ref_text, ref_label, ref_candidates = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        replicated_ref_text, replicated_ref_relations, ref_ranking_label = datasets.utils.replicate_rel_data(
                            ref_text, ref_label, ref_candidates)
                        ref_input_dict = self.model.encode_text(
                            list(
                                zip(replicated_ref_text,
                                    replicated_ref_relations)))
                        ref_output = self.model(ref_input_dict)
                        ref_targets = torch.tensor(
                            ref_ranking_label).float().unsqueeze(1).to(
                                self.device)
                        ref_loss = self.loss_fn(ref_output, ref_targets)
                        ref_grad = torch.autograd.grad(ref_loss, params)
                        if ref_grad_sum is None:
                            ref_grad_sum = ref_grad
                        else:
                            ref_grad_sum = [
                                x + y
                                for (x, y) in zip(ref_grad, ref_grad_sum)
                            ]
                    final_grad = self.compute_grad(orig_grad, ref_grad_sum)
                else:
                    final_grad = orig_grad
                # Write the (possibly projected) gradient back and step.
                for param, grad in zip(params, final_grad):
                    param.grad = grad.data
                self.optimizer.step()
                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
                iter += 1
                # Every seen batch is offered to the episodic memory.
                self.memory.write_batch(text, label, candidates)
                if iter % log_freq == 0:
                    acc = models.utils.calculate_accuracy(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.
                        format(epoch + 1, np.mean(all_losses), acc))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        """Return ranking accuracy over ``dataloader`` (no weight updates)."""
        all_losses, all_predictions, all_labels = [], [], []
        self.model.eval()
        for text, label, candidates in dataloader:
            replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                text, label, candidates)
            with torch.no_grad():
                input_dict = self.model.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = self.model(input_dict)
            pred, true_labels = models.utils.make_rel_prediction(
                output, ranking_label)
            all_predictions.extend(pred.tolist())
            all_labels.extend(true_labels.tolist())
        acc = models.utils.calculate_accuracy(all_predictions, all_labels)
        return acc

    def training(self, train_datasets, **kwargs):
        """Entry point: concatenate all train datasets and run ``train``."""
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.rel_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        """Entry point: evaluate on ``test_dataset`` and return accuracy."""
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset,
                                          batch_size=mini_batch_size,
                                          shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
class Baseline:
    """Non-continual baseline for the relation ranking task.

    Supports two regimes: 'sequential' (fine-tune on each cluster in turn)
    and 'multi_task' (shuffle all clusters together).
    """

    def __init__(self, device, training_mode, **kwargs):
        """Set up the ranking model, its AdamW optimizer and the BCE loss."""
        self.lr = kwargs.get('lr', 3e-5)
        self.device = device
        self.training_mode = training_mode
        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=1,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(trainable, lr=self.lr)
        logger.info('Loaded {} as the model'.format(self.model.__class__.__name__))
        self.loss_fn = nn.BCEWithLogitsLoss()

    def save_model(self, model_path):
        """Write the model's state dict to ``model_path``."""
        torch.save(self.model.state_dict(), model_path)

    def load_model(self, model_path):
        """Load a state dict previously written by ``save_model``."""
        self.model.load_state_dict(torch.load(model_path))

    def train(self, dataloader, n_epochs, log_freq):
        """Plain supervised training; running metrics are logged and reset
        every ``log_freq`` batches."""
        self.model.train()
        for epoch in range(n_epochs):
            losses, predictions, golds = [], [], []
            for step, (text, label, candidates) in enumerate(dataloader, start=1):
                rep_text, rep_relations, ranking_label = \
                    datasets.utils.replicate_rel_data(text, label, candidates)
                encoded = self.model.encode_text(list(zip(rep_text, rep_relations)))
                logits = self.model(encoded)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
                batch_loss = self.loss_fn(logits, targets)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                pred, true_labels = models.utils.make_rel_prediction(logits, ranking_label)
                losses.append(batch_loss.item())
                predictions.extend(pred.tolist())
                golds.extend(true_labels.tolist())

                if step % log_freq == 0:
                    acc = models.utils.calculate_accuracy(predictions, golds)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
                            epoch + 1, np.mean(losses), acc))
                    losses, predictions, golds = [], [], []

    def evaluate(self, dataloader):
        """Return ranking accuracy over ``dataloader``."""
        self.model.eval()
        predictions, golds = [], []
        for text, label, candidates in dataloader:
            rep_text, rep_relations, ranking_label = \
                datasets.utils.replicate_rel_data(text, label, candidates)
            with torch.no_grad():
                encoded = self.model.encode_text(list(zip(rep_text, rep_relations)))
                logits = self.model(encoded)
            pred, true_labels = models.utils.make_rel_prediction(logits, ranking_label)
            predictions.extend(pred.tolist())
            golds.extend(true_labels.tolist())
        return models.utils.calculate_accuracy(predictions, golds)

    def training(self, train_datasets, **kwargs):
        """Dispatch to sequential or multi-task fine-tuning."""
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 20)
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.training_mode == 'sequential':
            for cluster_idx, cluster in enumerate(train_datasets):
                logger.info('Training on cluster {}'.format(cluster_idx + 1))
                loader = data.DataLoader(cluster,
                                         batch_size=mini_batch_size,
                                         shuffle=False,
                                         collate_fn=datasets.utils.rel_encode)
                self.train(dataloader=loader, n_epochs=n_epochs, log_freq=log_freq)
        elif self.training_mode == 'multi_task':
            merged = data.ConcatDataset(train_datasets)
            logger.info('Training multi-task model on all datasets')
            loader = data.DataLoader(merged,
                                     batch_size=mini_batch_size,
                                     shuffle=True,
                                     collate_fn=datasets.utils.rel_encode)
            self.train(dataloader=loader, n_epochs=n_epochs, log_freq=log_freq)

    def testing(self, test_dataset, **kwargs):
        """Evaluate on ``test_dataset``, log and return accuracy."""
        mini_batch_size = kwargs.get('mini_batch_size')
        loader = data.DataLoader(test_dataset,
                                 batch_size=mini_batch_size,
                                 shuffle=False,
                                 collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=loader)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc
class ANML:
    """A Neuromodulated Meta-Learning algorithm (ANML) for text classification.

    Two networks: a neuromodulator (``nm``) that produces a gating signal,
    and a prediction network (``pn``). The PN's transformer representation
    is multiplied elementwise by the NM output before the linear head.
    Inner-loop SGD adapts the PN on a support set (via ``higher``); the
    outer loop meta-updates both NM and PN on a query set.
    """

    def __init__(self, device, n_classes, **kwargs):
        """Build NM/PN, replay memory, loss, and both optimizers.

        Reads from kwargs: inner_lr, meta_lr, write_prob, replay_rate,
        replay_every, model, max_length.
        """
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device
        self.nm = TransformerNeuromodulator(model_name=kwargs.get('model'),
                                            device=device)
        self.pn = TransformerClsModel(model_name=kwargs.get('model'),
                                      n_classes=n_classes,
                                      max_length=kwargs.get('max_length'),
                                      device=device)
        # tuple_size=2: each memory row is (text, labels).
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        self.loss_fn = nn.CrossEntropyLoss()
        logger.info('Loaded {} as NM'.format(self.nm.__class__.__name__))
        logger.info('Loaded {} as PN'.format(self.pn.__class__.__name__))
        # Meta optimizer updates NM + PN; inner optimizer adapts PN only.
        meta_params = [p for p in self.nm.parameters() if p.requires_grad] + \
                      [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)
        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

    def save_model(self, model_path):
        """Serialize both networks into one checkpoint file."""
        checkpoint = {'nm': self.nm.state_dict(), 'pn': self.pn.state_dict()}
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        """Restore both networks from a checkpoint written by save_model."""
        checkpoint = torch.load(model_path)
        self.nm.load_state_dict(checkpoint['nm'])
        self.pn.load_state_dict(checkpoint['pn'])

    def evaluate(self, dataloader, updates, mini_batch_size):
        """Adapt on ``updates`` memory batches, then test on ``dataloader``.

        Returns (accuracy, precision, recall, f1) on the test data.
        """
        # Support set is drawn from episodic memory, not the test data.
        support_set = []
        for _ in range(updates):
            text, labels = self.memory.read_batch(batch_size=mini_batch_size)
            support_set.append((text, labels))
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop: adapt the functional PN (fpn) on the support set.
            task_predictions, task_labels = [], []
            support_loss = []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                repr = fpn(input_dict, out_from='transformers')
                modulation = self.nm(input_dict)
                # Neuromodulation: gate the representation elementwise.
                output = fpn(repr * modulation, out_from='linear')
                loss = self.loss_fn(output, labels)
                diffopt.step(loss)
                pred = models.utils.make_prediction(output.detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())
            acc, prec, rec, f1 = models.utils.calculate_metrics(
                task_predictions, task_labels)
            logger.info(
                'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                'recall = {:.4f}, F1 score = {:.4f}'.format(
                    np.mean(support_loss), acc, prec, rec, f1))
            # Evaluation proper: frozen adapted weights, no grad.
            all_losses, all_predictions, all_labels = [], [], []
            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                with torch.no_grad():
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))
        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        """Meta-train over a stream of episodes until the data is exhausted.

        Each episode: inner-loop adapt on ``updates`` support batches, then
        accumulate meta-gradients from a query set (fresh data, or replayed
        memory every ``replay_freq`` episodes) and step the meta optimizer.
        """
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.replay_rate != 0:
            replay_batch_freq = self.replay_every // mini_batch_size
            replay_freq = int(
                math.ceil((replay_batch_freq + 1) / (updates + 1)))
            replay_steps = int(self.replay_every * self.replay_rate /
                               mini_batch_size)
        else:
            replay_freq = 0
            replay_steps = 0
        logger.info('Replay frequency: {}'.format(replay_freq))
        logger.info('Replay steps: {}'.format(replay_steps))
        concat_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = iter(
            data.DataLoader(concat_dataset,
                            batch_size=mini_batch_size,
                            shuffle=False,
                            collate_fn=datasets.utils.batch_encode))
        episode_id = 0
        while True:
            self.inner_optimizer.zero_grad()
            support_loss, support_acc, support_prec, support_rec, support_f1 = [], [], [], [], []
            with higher.innerloop_ctx(self.pn,
                                      self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (fpn, diffopt):
                # Inner loop
                support_set = []
                task_predictions, task_labels = [], []
                for _ in range(updates):
                    try:
                        text, labels = next(train_dataloader)
                        support_set.append((text, labels))
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return
                for text, labels in support_set:
                    labels = torch.tensor(labels).to(self.device)
                    input_dict = self.pn.encode_text(text)
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                    diffopt.step(loss)
                    pred = models.utils.make_prediction(output.detach())
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(labels.tolist())
                    self.memory.write_batch(text, labels)
                acc, prec, rec, f1 = models.utils.calculate_metrics(
                    task_predictions, task_labels)
                logger.info(
                    'Episode {} support set: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                    'recall = {:.4f}, F1 score = {:.4f}'.format(
                        episode_id + 1, np.mean(support_loss), acc, prec,
                        rec, f1))
                # Outer loop
                query_loss, query_acc, query_prec, query_rec, query_f1 = [], [], [], [], []
                query_set = []
                # Periodically the query set is replayed memory instead of
                # new data (sparse experience replay).
                if self.replay_rate != 0 and (episode_id + 1) % replay_freq == 0:
                    for _ in range(replay_steps):
                        text, labels = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        query_set.append((text, labels))
                else:
                    try:
                        text, labels = next(train_dataloader)
                        query_set.append((text, labels))
                        self.memory.write_batch(text, labels)
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return
                for text, labels in query_set:
                    labels = torch.tensor(labels).to(self.device)
                    input_dict = self.pn.encode_text(text)
                    repr = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(repr * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                    query_loss.append(loss.item())
                    pred = models.utils.make_prediction(output.detach())
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        pred.tolist(), labels.tolist())
                    query_acc.append(acc)
                    query_prec.append(prec)
                    query_rec.append(rec)
                    query_f1.append(f1)
                    # NM meta gradients (retain_graph so the same graph can
                    # also produce the PN meta gradients below).
                    nm_params = [
                        p for p in self.nm.parameters() if p.requires_grad
                    ]
                    meta_nm_grads = torch.autograd.grad(loss,
                                                        nm_params,
                                                        retain_graph=True)
                    for param, meta_grad in zip(nm_params, meta_nm_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()
                    # PN meta gradients: taken w.r.t. the adapted fast
                    # weights (fpn) but accumulated into the slow weights
                    # (self.pn) — first-order since track_higher_grads=False.
                    # Assumes fpn.parameters() and self.pn.parameters()
                    # iterate in the same order.
                    pn_params = [
                        p for p in fpn.parameters() if p.requires_grad
                    ]
                    meta_pn_grads = torch.autograd.grad(loss, pn_params)
                    pn_params = [
                        p for p in self.pn.parameters() if p.requires_grad
                    ]
                    for param, meta_grad in zip(pn_params, meta_pn_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()
                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
                logger.info(
                    'Episode {} query set: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                    'recall = {:.4f}, F1 score = {:.4f}'.format(
                        episode_id + 1, np.mean(query_loss),
                        np.mean(query_acc), np.mean(query_prec),
                        np.mean(query_rec), np.mean(query_f1)))
            episode_id += 1

    def testing(self, test_datasets, **kwargs):
        """Evaluate on every test dataset; returns the list of accuracies."""
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader,
                                               updates=updates,
                                               mini_batch_size=mini_batch_size)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)
        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions),
                                       np.mean(recalls), np.mean(f1s)))
        return accuracies
class Replay:
    """Sparse experience-replay trainer for text classification.

    A standard supervised update is applied to each incoming batch; every
    ``replay_every // batch_size`` batches, an extra optimizer step is taken
    on batches drawn from an episodic ``ReplayMemory`` (gradients from all
    replay batches are accumulated, clipped, then applied at once).

    Fix vs. original: ``torch.nn.utils.clip_grad_norm`` (deprecated and
    removed in modern PyTorch) replaced by the in-place
    ``torch.nn.utils.clip_grad_norm_``; same clipping behavior.
    """

    def __init__(self, device, n_classes, **kwargs):
        """Build model, replay memory, loss and optimizer.

        Reads from kwargs: lr (default 3e-5), write_prob, replay_rate,
        replay_every, model, max_length, hebbian.
        """
        self.lr = kwargs.get('lr', 3e-5)
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device
        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device,
                                         hebbian=kwargs.get('hebbian'))
        # tuple_size=2: each memory row is (text, labels).
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=2)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.lr)

    def save_model(self, model_path):
        """Serialize model weights to ``model_path``."""
        checkpoint = self.model.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        """Restore weights previously written by ``save_model``."""
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint)

    def train(self, dataloader, n_epochs, log_freq):
        """Train with sparse replay; logs running metrics every ``log_freq``
        batches."""
        self.model.train()
        for epoch in range(n_epochs):
            all_losses, all_predictions, all_labels = [], [], []
            batch_idx = 0
            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.model.encode_text(text)
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                mini_batch_size = len(labels)
                # Recomputed per batch since the final batch may be smaller.
                replay_freq = self.replay_every // mini_batch_size
                replay_steps = int(self.replay_every * self.replay_rate /
                                   mini_batch_size)
                if self.replay_rate != 0 and (batch_idx + 1) % replay_freq == 0:
                    # Accumulate gradients over all replay batches, then take
                    # a single clipped optimizer step.
                    self.optimizer.zero_grad()
                    for _ in range(replay_steps):
                        ref_text, ref_labels = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        ref_labels = torch.tensor(ref_labels).to(self.device)
                        ref_input_dict = self.model.encode_text(ref_text)
                        ref_output = self.model(ref_input_dict)
                        ref_loss = self.loss_fn(ref_output, ref_labels)
                        ref_loss.backward()
                    params = [
                        p for p in self.model.parameters() if p.requires_grad
                    ]
                    # In-place clipping; clip_grad_norm (no underscore) is
                    # deprecated/removed in current PyTorch.
                    torch.nn.utils.clip_grad_norm_(params, 25)
                    self.optimizer.step()

                loss = loss.item()
                pred = models.utils.make_prediction(output.detach())
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())
                batch_idx += 1
                # Every seen batch is offered to the episodic memory.
                self.memory.write_batch(text, labels)
                if batch_idx % log_freq == 0:
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        all_predictions, all_labels)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                        'F1 score = {:.4f}'.format(epoch + 1,
                                                   np.mean(all_losses), acc,
                                                   prec, rec, f1))
                    all_losses, all_predictions, all_labels = [], [], []

    def evaluate(self, dataloader):
        """Return (accuracy, precision, recall, f1) over ``dataloader``."""
        all_losses, all_predictions, all_labels = [], [], []
        self.model.eval()
        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            input_dict = self.model.encode_text(text)
            with torch.no_grad():
                output = self.model(input_dict)
                loss = self.loss_fn(output, labels)
            loss = loss.item()
            pred = models.utils.make_prediction(output.detach())
            all_losses.append(loss)
            all_predictions.extend(pred.tolist())
            all_labels.extend(labels.tolist())
        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))
        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        """Entry point: concatenate all train datasets and run ``train``."""
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 50)
        mini_batch_size = kwargs.get('mini_batch_size')
        train_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=mini_batch_size,
            shuffle=False,
            collate_fn=datasets.utils.batch_encode)
        self.train(dataloader=train_dataloader,
                   n_epochs=n_epochs,
                   log_freq=log_freq)

    def testing(self, test_datasets, **kwargs):
        """Evaluate on every test dataset and log the averaged metrics."""
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            test_dataloader = data.DataLoader(
                test_dataset,
                batch_size=mini_batch_size,
                shuffle=False,
                collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=test_dataloader)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)
        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions),
                                       np.mean(recalls), np.mean(f1s)))
class Baseline:
    """Non-continual baseline for text classification.

    Supports 'sequential' (fine-tune on each dataset in order) and
    'multi_task' (shuffle all datasets together); any other mode raises.
    """

    def __init__(self, device, n_classes, training_mode, **kwargs):
        """Set up the classifier, cross-entropy loss and AdamW optimizer."""
        self.lr = kwargs.get('lr', 3e-5)
        self.device = device
        self.training_mode = training_mode
        self.model = TransformerClsModel(model_name=kwargs.get('model'),
                                         n_classes=n_classes,
                                         max_length=kwargs.get('max_length'),
                                         device=device)
        logger.info('Loaded {} as model'.format(self.model.__class__.__name__))
        self.loss_fn = nn.CrossEntropyLoss()
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(trainable, lr=self.lr)

    def save_model(self, model_path):
        """Write the model's state dict to ``model_path``."""
        torch.save(self.model.state_dict(), model_path)

    def load_model(self, model_path):
        """Load a state dict previously written by ``save_model``."""
        self.model.load_state_dict(torch.load(model_path))

    def train(self, dataloader, n_epochs, log_freq):
        """Plain supervised training; running metrics are logged and reset
        every ``log_freq`` batches."""
        self.model.train()
        for epoch in range(n_epochs):
            losses, predictions, golds = [], [], []
            for step, (text, labels) in enumerate(dataloader, start=1):
                labels = torch.tensor(labels).to(self.device)
                encoded = self.model.encode_text(text)
                logits = self.model(encoded)
                batch_loss = self.loss_fn(logits, labels)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                pred = models.utils.make_prediction(logits.detach())
                losses.append(batch_loss.item())
                predictions.extend(pred.tolist())
                golds.extend(labels.tolist())

                if step % log_freq == 0:
                    acc, prec, rec, f1 = models.utils.calculate_metrics(
                        predictions, golds)
                    logger.info(
                        'Epoch {} metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                        'F1 score = {:.4f}'.format(epoch + 1, np.mean(losses),
                                                   acc, prec, rec, f1))
                    losses, predictions, golds = [], [], []

    def evaluate(self, dataloader):
        """Return (accuracy, precision, recall, f1) over ``dataloader``."""
        self.model.eval()
        losses, predictions, golds = [], [], []
        for text, labels in dataloader:
            labels = torch.tensor(labels).to(self.device)
            encoded = self.model.encode_text(text)
            with torch.no_grad():
                logits = self.model(encoded)
                batch_loss = self.loss_fn(logits, labels)
            pred = models.utils.make_prediction(logits.detach())
            losses.append(batch_loss.item())
            predictions.extend(pred.tolist())
            golds.extend(labels.tolist())
        acc, prec, rec, f1 = models.utils.calculate_metrics(predictions, golds)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(losses), acc, prec, rec, f1))
        return acc, prec, rec, f1

    def training(self, train_datasets, **kwargs):
        """Dispatch to sequential or multi-task fine-tuning."""
        n_epochs = kwargs.get('n_epochs', 1)
        log_freq = kwargs.get('log_freq', 500)
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.training_mode == 'sequential':
            for train_dataset in train_datasets:
                logger.info('Training on {}'.format(
                    train_dataset.__class__.__name__))
                loader = data.DataLoader(
                    train_dataset,
                    batch_size=mini_batch_size,
                    shuffle=False,
                    collate_fn=datasets.utils.batch_encode)
                self.train(dataloader=loader, n_epochs=n_epochs,
                           log_freq=log_freq)
        elif self.training_mode == 'multi_task':
            merged = data.ConcatDataset(train_datasets)
            logger.info('Training multi-task model on all datasets')
            loader = data.DataLoader(merged,
                                     batch_size=mini_batch_size,
                                     shuffle=True,
                                     collate_fn=datasets.utils.batch_encode)
            self.train(dataloader=loader, n_epochs=n_epochs,
                       log_freq=log_freq)
        else:
            raise ValueError('Invalid training mode')

    def testing(self, test_datasets, **kwargs):
        """Evaluate on every test dataset and log the averaged metrics."""
        mini_batch_size = kwargs.get('mini_batch_size')
        accuracies, precisions, recalls, f1s = [], [], [], []
        for test_dataset in test_datasets:
            logger.info('Testing on {}'.format(
                test_dataset.__class__.__name__))
            loader = data.DataLoader(test_dataset,
                                     batch_size=mini_batch_size,
                                     shuffle=False,
                                     collate_fn=datasets.utils.batch_encode)
            acc, prec, rec, f1 = self.evaluate(dataloader=loader)
            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(f1)
        logger.info(
            'Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies),
                                       np.mean(precisions),
                                       np.mean(recalls), np.mean(f1s)))
class MAML:
    """First-order MAML trainer (OML-style) for the relation ranking task.

    A single prediction network (``pn``) is adapted with inner-loop SGD on
    support batches (via ``higher``), then meta-updated with AdamW from
    query-set gradients taken w.r.t. the adapted fast weights
    (first-order, since ``track_higher_grads=False``).
    """

    def __init__(self, device, **kwargs):
        """Build PN, replay memory, loss, and inner/meta optimizers.

        Reads from kwargs: inner_lr, meta_lr, write_prob, replay_rate,
        replay_every, model, max_length, hebbian.
        """
        self.inner_lr = kwargs.get('inner_lr')
        self.meta_lr = kwargs.get('meta_lr')
        self.write_prob = kwargs.get('write_prob')
        self.replay_rate = kwargs.get('replay_rate')
        self.replay_every = kwargs.get('replay_every')
        self.device = device
        self.pn = TransformerClsModel(model_name=kwargs.get('model'),
                                      n_classes=1,
                                      max_length=kwargs.get('max_length'),
                                      device=device,
                                      hebbian=kwargs.get('hebbian'))
        logger.info('Loaded {} as PN'.format(self.pn.__class__.__name__))
        meta_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.meta_optimizer = AdamW(meta_params, lr=self.meta_lr)
        # tuple_size=3: each memory row is (text, label, candidates).
        self.memory = ReplayMemory(write_prob=self.write_prob, tuple_size=3)
        self.loss_fn = nn.BCEWithLogitsLoss()
        inner_params = [p for p in self.pn.parameters() if p.requires_grad]
        self.inner_optimizer = optim.SGD(inner_params, lr=self.inner_lr)

    def save_model(self, model_path):
        """Serialize PN weights to ``model_path``."""
        checkpoint = self.pn.state_dict()
        torch.save(checkpoint, model_path)

    def load_model(self, model_path):
        """Restore PN weights previously written by ``save_model``."""
        checkpoint = torch.load(model_path)
        self.pn.load_state_dict(checkpoint)

    def evaluate(self, dataloader, updates, mini_batch_size):
        """Adapt on ``updates`` memory batches, then test on ``dataloader``.

        Returns ranking accuracy on the test data.
        """
        # train() mode is required for the inner-loop adaptation steps.
        self.pn.train()
        support_set = []
        for _ in range(updates):
            text, label, candidates = self.memory.read_batch(
                batch_size=mini_batch_size)
            support_set.append((text, label, candidates))
        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):
            # Inner loop: adapt the functional PN (fpn) on the support set.
            task_predictions, task_labels = [], []
            support_loss = []
            for text, label, candidates in support_set:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                input_dict = self.pn.encode_text(
                    list(zip(replicated_text, replicated_relations)))
                output = fpn(input_dict)
                targets = torch.tensor(ranking_label).float().unsqueeze(1).to(
                    self.device)
                loss = self.loss_fn(output, targets)
                diffopt.step(loss)
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(true_labels.tolist())
            acc = models.utils.calculate_accuracy(task_predictions,
                                                  task_labels)
            logger.info(
                'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
                    np.mean(support_loss), acc))
            # Evaluation proper: frozen adapted weights, no grad.
            all_losses, all_predictions, all_labels = [], [], []
            for text, label, candidates in dataloader:
                replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                with torch.no_grad():
                    input_dict = self.pn.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    output = fpn(input_dict)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)
                loss = loss.item()
                pred, true_labels = models.utils.make_rel_prediction(
                    output, ranking_label)
                all_losses.append(loss)
                all_predictions.extend(pred.tolist())
                all_labels.extend(true_labels.tolist())
        acc = models.utils.calculate_accuracy(all_predictions, all_labels)
        logger.info('Test metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
            np.mean(all_losses), acc))
        return acc

    def training(self, train_datasets, **kwargs):
        """Meta-train over a stream of episodes until the data is exhausted.

        Each episode: inner-loop adapt on ``updates`` support batches, then
        accumulate meta-gradients from a query set (fresh data, or replayed
        memory every ``replay_freq`` episodes) and step the meta optimizer.
        """
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        if self.replay_rate != 0:
            replay_batch_freq = self.replay_every // mini_batch_size
            replay_freq = int(
                math.ceil((replay_batch_freq + 1) / (updates + 1)))
            replay_steps = int(self.replay_every * self.replay_rate /
                               mini_batch_size)
        else:
            replay_freq = 0
            replay_steps = 0
        logger.info('Replay frequency: {}'.format(replay_freq))
        logger.info('Replay steps: {}'.format(replay_steps))
        concat_dataset = data.ConcatDataset(train_datasets)
        train_dataloader = iter(
            data.DataLoader(concat_dataset,
                            batch_size=mini_batch_size,
                            shuffle=False,
                            collate_fn=datasets.utils.rel_encode))
        episode_id = 0
        while True:
            self.inner_optimizer.zero_grad()
            support_loss, support_acc = [], []
            with higher.innerloop_ctx(self.pn,
                                      self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (fpn, diffopt):
                # Inner loop
                support_set = []
                task_predictions, task_labels = [], []
                for _ in range(updates):
                    try:
                        text, label, candidates = next(train_dataloader)
                        support_set.append((text, label, candidates))
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return
                for text, label, candidates in support_set:
                    replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                        text, label, candidates)
                    input_dict = self.pn.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    output = fpn(input_dict)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)
                    diffopt.step(loss)
                    pred, true_labels = models.utils.make_rel_prediction(
                        output, ranking_label)
                    support_loss.append(loss.item())
                    task_predictions.extend(pred.tolist())
                    task_labels.extend(true_labels.tolist())
                    self.memory.write_batch(text, label, candidates)
                acc = models.utils.calculate_accuracy(task_predictions,
                                                      task_labels)
                logger.info(
                    'Episode {} support set: Loss = {:.4f}, accuracy = {:.4f}'.
                    format(episode_id + 1, np.mean(support_loss), acc))
                # Outer loop
                query_loss, query_acc = [], []
                query_set = []
                # Periodically the query set is replayed memory instead of
                # new data (sparse experience replay).
                if self.replay_rate != 0 and (episode_id + 1) % replay_freq == 0:
                    for _ in range(replay_steps):
                        text, label, candidates = self.memory.read_batch(
                            batch_size=mini_batch_size)
                        query_set.append((text, label, candidates))
                else:
                    try:
                        text, label, candidates = next(train_dataloader)
                        query_set.append((text, label, candidates))
                        self.memory.write_batch(text, label, candidates)
                    except StopIteration:
                        logger.info(
                            'Terminating training as all the data is seen')
                        return
                for text, label, candidates in query_set:
                    replicated_text, replicated_relations, ranking_label = datasets.utils.replicate_rel_data(
                        text, label, candidates)
                    input_dict = self.pn.encode_text(
                        list(zip(replicated_text, replicated_relations)))
                    output = fpn(input_dict)
                    targets = torch.tensor(ranking_label).float().unsqueeze(
                        1).to(self.device)
                    loss = self.loss_fn(output, targets)
                    query_loss.append(loss.item())
                    pred, true_labels = models.utils.make_rel_prediction(
                        output, ranking_label)
                    acc = models.utils.calculate_accuracy(
                        pred.tolist(), true_labels.tolist())
                    query_acc.append(acc)
                    # PN meta gradients: taken w.r.t. the adapted fast
                    # weights (fpn) but accumulated into the slow weights
                    # (self.pn). Assumes both parameter iterators yield
                    # params in the same order.
                    pn_params = [
                        p for p in fpn.parameters() if p.requires_grad
                    ]
                    meta_pn_grads = torch.autograd.grad(loss, pn_params)
                    pn_params = [
                        p for p in self.pn.parameters() if p.requires_grad
                    ]
                    for param, meta_grad in zip(pn_params, meta_pn_grads):
                        if param.grad is not None:
                            param.grad += meta_grad.detach()
                        else:
                            param.grad = meta_grad.detach()
                # Meta optimizer step
                self.meta_optimizer.step()
                self.meta_optimizer.zero_grad()
                logger.info(
                    'Episode {} query set: Loss = {:.4f}, accuracy = {:.4f}'.
                    format(episode_id + 1, np.mean(query_loss),
                           np.mean(query_acc)))
            episode_id += 1

    def testing(self, test_dataset, **kwargs):
        """Entry point: adapt-then-evaluate on ``test_dataset``; returns
        accuracy."""
        updates = kwargs.get('updates')
        mini_batch_size = kwargs.get('mini_batch_size')
        test_dataloader = data.DataLoader(test_dataset,
                                          batch_size=mini_batch_size,
                                          shuffle=False,
                                          collate_fn=datasets.utils.rel_encode)
        acc = self.evaluate(dataloader=test_dataloader,
                            updates=updates,
                            mini_batch_size=mini_batch_size)
        logger.info('Overall test metrics: Accuracy = {:.4f}'.format(acc))
        return acc