# Multi-task fine-tuning: build one dataset/metric pair per task and attach
# dataloaders, training/validation/test steps and the optimizer to the model
# instance via types.MethodType.
def build_method(model: Model, task_info: TaskInfo):
    multi_dataset, multi_metric = task_info.build_dataset(
        model,
        seg=model.hparams.seg_data_dir,
        pos=model.hparams.pos_data_dir,
        ner=model.hparams.ner_data_dir,
        dep=model.hparams.dep_data_dir,
        sdp=model.hparams.sdp_data_dir,
        srl=model.hparams.srl_data_dir)

    def train_dataloader(self):
        # One dataloader per task, mixed by MultiTaskDataloader with
        # temperature `tau`.
        multi_dataloader = {
            task: torch.utils.data.DataLoader(
                task_dataset[datasets.Split.TRAIN],
                batch_size=self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        }
        res = MultiTaskDataloader(tau=self.hparams.tau, **multi_dataloader)
        return res

    def training_step(self, batch, batch_idx):
        result = self(**batch)
        self.log("loss", result.loss.item())
        return {"loss": result.loss}

    def val_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.VALIDATION],
                batch_size=self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    def test_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.TEST],
                batch_size=self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    # AdamW + LR scheduler
    def configure_optimizers(self: Model):
        # ceil(len / batch_size) per task, summed over all tasks.
        num_epoch_steps = sum(
            (len(dataset[datasets.Split.TRAIN]) + self.hparams.batch_size - 1)
            // self.hparams.batch_size
            for dataset in multi_dataset.values())
        num_train_steps = num_epoch_steps * self.hparams.max_epochs
        optimizer, scheduler = optimization.from_argparse_args(
            self.hparams,
            model=self,
            num_train_steps=num_train_steps,
            n_transformer_layers=self.transformer.config.num_hidden_layers)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    # Bind the generated functions onto the model instance.
    model.configure_optimizers = types.MethodType(configure_optimizers, model)
    model.train_dataloader = types.MethodType(train_dataloader, model)
    model.training_step = types.MethodType(training_step, model)

    validation_step, validation_epoch_end = task_info.validation_method(
        multi_metric,
        loss_tag='val_loss',
        metric_tags={
            task_name: f"val_{task_module.task_info.metric_name}"
            for task_name, task_module in task_builder.items()
        },
        metric_tag=f"val_{task_info.metric_name}")

    model.val_dataloader = types.MethodType(val_dataloader, model)
    model.validation_step = types.MethodType(validation_step, model)
    model.validation_epoch_end = types.MethodType(validation_epoch_end, model)

    test_step, test_epoch_end = task_info.validation_method(
        multi_metric,
        loss_tag='test_loss',
        metric_tags={
            task_name: f"test_{task_module.task_info.metric_name}"
            for task_name, task_module in task_builder.items()
        },
        metric_tag=f"test_{task_info.metric_name}")

    model.test_dataloader = types.MethodType(test_dataloader, model)
    model.test_step = types.MethodType(test_step, model)
    model.test_epoch_end = types.MethodType(test_epoch_end, model)
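# ------------------------------------------------------------------
# Illustrative sketch only (not the library implementation): a minimal
# multi-task dataloader of the kind MultiTaskDataloader above could be.
# It samples a task each step with probability proportional to
# len(task_dataloader) ** tau, so tau=1 gives size-proportional sampling
# and tau=0 gives uniform sampling. The class name, the definition of an
# epoch, and the injected batch['task'] key are assumptions; the latter
# mirrors how training_step in the distillation variant below dispatches
# on batch['task'].
# ------------------------------------------------------------------
import random


class SketchMultiTaskDataloader:
    def __init__(self, tau: float = 1.0, **dataloaders):
        self.dataloaders = dataloaders
        sizes = {task: len(loader) for task, loader in dataloaders.items()}
        total = sum(size ** tau for size in sizes.values())
        # Sampling weight per task, controlled by the temperature tau.
        self.weights = {task: size ** tau / total for task, size in sizes.items()}
        # One "epoch" here is the total number of batches across all tasks.
        self.num_batches = sum(sizes.values())

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        iterators = {task: iter(loader) for task, loader in self.dataloaders.items()}
        tasks = list(self.weights.keys())
        probs = [self.weights[task] for task in tasks]
        for _ in range(self.num_batches):
            task = random.choices(tasks, weights=probs, k=1)[0]
            try:
                batch = next(iterators[task])
            except StopIteration:
                # Restart an exhausted task so sampling can continue.
                iterators[task] = iter(self.dataloaders[task])
                batch = next(iterators[task])
            batch['task'] = task  # assumes dict-style batches, as above
            yield batch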
# Multi-task knowledge distillation: the student trains on teacher logits
# stored in the distill datasets, mixed with the supervised loss.
def build_method(model: Model, task_info: TaskInfo):
    (multi_dataset, distill_datasets,
     distill_datasets_extra), multi_metric = build_dataset(
        model,
        seg=model.hparams.seg_data_dir,
        pos=model.hparams.pos_data_dir,
        ner=model.hparams.ner_data_dir,
        dep=model.hparams.dep_data_dir,
        sdp=model.hparams.sdp_data_dir,
        srl=model.hparams.srl_data_dir)

    # Tasks for which distillation is disabled fall back to the plain
    # supervised loss in training_step.
    disable_distill = {
        'seg': model.hparams.disable_seg,
        'pos': model.hparams.disable_pos,
        'ner': model.hparams.disable_ner,
        'dep': model.hparams.disable_dep,
        'sdp': model.hparams.disable_sdp,
        'srl': model.hparams.disable_srl,
    }
    disable_distill = {
        task for task, disable in disable_distill.items() if disable
    }

    temperature_scheduler = flsw_temperature_scheduler_builder(
        beta=model.hparams.distill_beta,
        gamma=model.hparams.distill_gamma,
        base_temperature=model.hparams.temperature)

    def train_dataloader(self):
        multi_dataloader = {
            task: torch.utils.data.DataLoader(
                task_dataset,
                batch_size=None,
                num_workers=self.hparams.num_workers,
                pin_memory=True,
                shuffle=True)
            for task, task_dataset in distill_datasets.items()
        }
        res = MultiTaskDataloader(tau=self.hparams.tau, **multi_dataloader)
        return res

    def training_step(self: Model, batch, batch_idx):
        task = batch['task']
        target_logits = batch.pop('logits')
        result = self(**batch)
        norm_loss = result.loss

        if task not in disable_distill:
            distill_loss = distill_loss_map[task](
                batch, result, target_logits, temperature_scheduler, model,
                extra=distill_datasets_extra[task])

            # Linearly shift the weight from the distillation loss towards
            # the supervised loss over the course of training.
            distill_loss_weight = self.global_step / self.num_train_steps
            loss = distill_loss_weight * norm_loss + (
                1 - distill_loss_weight) * distill_loss

            self.log("distill_loss", distill_loss.item())
            self.log("norm_loss", norm_loss.item())
            self.log("loss", loss.item())
            return {"loss": loss}
        else:
            self.log("loss", norm_loss.item())
            return {"loss": norm_loss}

    def val_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.VALIDATION],
                batch_size=getattr(self.hparams, f'{task}_batch_size')
                or self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    def test_dataloader(self):
        return [
            torch.utils.data.DataLoader(
                task_dataset[datasets.Split.TEST],
                batch_size=getattr(self.hparams, f'{task}_batch_size')
                or self.hparams.batch_size,
                collate_fn=collate,
                num_workers=self.hparams.num_workers,
                pin_memory=True)
            for task, task_dataset in multi_dataset.items()
        ]

    # AdamW + LR scheduler
    def configure_optimizers(self: Model):
        num_epoch_steps = sum(
            len(dataset) for dataset in distill_datasets.values())
        num_train_steps = num_epoch_steps * self.hparams.max_epochs
        # Stored on the model so training_step can compute the loss weight.
        setattr(self, 'num_train_steps', num_train_steps)

        optimizer, scheduler = optimization.from_argparse_args(
            self.hparams,
            model=self,
            num_train_steps=num_train_steps,
            n_transformer_layers=self.transformer.config.num_hidden_layers)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    # Bind the generated functions onto the model instance.
    model.configure_optimizers = types.MethodType(configure_optimizers, model)
    model.train_dataloader = types.MethodType(train_dataloader, model)
    model.training_step = types.MethodType(training_step, model)

    validation_step, validation_epoch_end = task_info.validation_method(
        multi_metric, task=task_info.task_name, preffix='val')

    model.val_dataloader = types.MethodType(val_dataloader, model)
    model.validation_step = types.MethodType(validation_step, model)
    model.validation_epoch_end = types.MethodType(validation_epoch_end, model)

    test_step, test_epoch_end = task_info.validation_method(
        multi_metric, task=task_info.task_name, preffix='test')

    model.test_dataloader = types.MethodType(test_dataloader, model)
    model.test_step = types.MethodType(test_step, model)
    model.test_epoch_end = types.MethodType(test_epoch_end, model)
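# ------------------------------------------------------------------
# Illustrative sketch only (not the library code): a generic temperature-
# scaled soft-label loss of the kind distill_loss_map[task] could dispatch
# to for a classification-style head. The helper name, shapes, and the
# fixed temperature are assumptions; in the code above the temperature
# comes from temperature_scheduler, and sequence tasks (seg/pos/ner) and
# graph tasks (dep/sdp/srl) each need their own masking. training_step
# then mixes this term with the supervised loss using the linear weight
# w = global_step / num_train_steps, so training starts teacher-driven
# and ends label-driven.
# ------------------------------------------------------------------
import torch
import torch.nn.functional as F


def soft_cross_entropy(student_logits: torch.Tensor,
                       teacher_logits: torch.Tensor,
                       temperature: float = 2.0) -> torch.Tensor:
    """Soft cross-entropy between student and teacher logits."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Scale by T^2 so the gradient magnitude stays comparable to the
    # hard-label loss when the temperature changes.
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean() * temperature ** 2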