def get_current_sparsity(self, model: Model) -> float:
    """Return the fraction of zero-valued weights in the ``nn.Linear`` layers.

    Only the ``weight`` tensor of modules that are instances of ``nn.Linear``
    is inspected; biases and every other module type are ignored.

    Args:
        model: module whose ``named_modules()`` are scanned.

    Returns:
        ``(total - nonzero) / total`` over all inspected weights, or ``0.0``
        when the model contains no ``nn.Linear`` weights at all (instead of
        raising ``ZeroDivisionError``).
    """
    total_weights = 0
    nonzero_weights = 0
    # Single pass over the module tree; both counters come from the same layers.
    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            total_weights += weight.numel()
            nonzero_weights += weight.nonzero().size(0)

    if total_weights == 0:
        # Nothing to prune, so nothing is sparse.
        return 0.0
    return (total_weights - nonzero_weights) / total_weights
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a script module that accepts raw tokens.

    The returned module looks every token up in the vocabulary of
    ``tensorizers["tokens"]``, appends the PAD index to each sentence until it
    is as long as the longest sentence of the batch, calls the traced model
    with the id tensor and the unpadded lengths, and passes the logits through
    the scripted prediction layer.

    Args:
        tensorizers: mapping with a ``"tokens"`` entry exposing ``vocab``.
        traced_model: traced model invoked as ``model(token_ids, seq_lens)``.

    Returns:
        An instance of the locally defined ``jit.ScriptModule``.
    """
    scripted_output = self.output_layer.torchscript_predictions()
    token_vocab = tensorizers["tokens"].vocab

    class Model(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = Vocabulary(token_vocab, unk_idx=token_vocab.idx[UNK])
            self.model = traced_model
            self.output_layer = scripted_output
            self.pad_idx = jit.Attribute(token_vocab.idx[PAD], int)

        @jit.script_method
        def forward(self, tokens: List[List[str]]):
            token_ids = self.vocab.lookup_indices_2d(tokens)

            # Remember the real length of every sentence before padding.
            lengths = jit.annotate(List[int], [])
            for row in token_ids:
                lengths.append(len(row))

            # Pad in place so that every row matches the longest one.
            longest = list_max(lengths)
            for row in token_ids:
                for _ in range(longest - len(row)):
                    row.append(self.pad_idx)

            logits = self.model(torch.tensor(token_ids), torch.tensor(lengths))
            return self.output_layer(logits)

    return Model()
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a script module fed with word ids and bytes.

    The returned module derives three inputs from the raw tokens: padded word
    ids (vocabulary of ``tensorizers["tokens"]``), byte inputs built by
    ``make_byte_inputs`` with the settings of ``tensorizers["token_bytes"]``,
    and the sequence lengths. The traced model's logits are then passed
    through the scripted prediction layer.

    Args:
        tensorizers: mapping with ``"tokens"`` (exposing ``vocab``) and
            ``"token_bytes"`` (exposing ``max_byte_len`` and
            ``offset_for_non_padding``).
        traced_model: traced model invoked as
            ``model(word_ids, token_bytes, seq_lens)``.

    Returns:
        An instance of the locally defined ``jit.ScriptModule``.
    """
    scripted_output = self.output_layer.torchscript_predictions()
    byte_len_limit = tensorizers["token_bytes"].max_byte_len
    non_padding_offset = tensorizers["token_bytes"].offset_for_non_padding
    token_vocab = tensorizers["tokens"].vocab

    class Model(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = Vocabulary(token_vocab, unk_idx=token_vocab.idx[UNK])
            self.max_byte_len = jit.Attribute(byte_len_limit, int)
            self.byte_offset_for_non_padding = jit.Attribute(
                non_padding_offset, int
            )
            self.pad_idx = jit.Attribute(token_vocab.idx[PAD], int)
            self.model = traced_model
            self.output_layer = scripted_output

        @jit.script_method
        def forward(self, tokens: List[List[str]]):
            lengths = make_sequence_lengths(tokens)
            token_ids = self.vocab.lookup_indices_2d(tokens)
            token_ids = pad_2d(token_ids, lengths, self.pad_idx)
            byte_inputs, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            logits = self.model(
                torch.tensor(token_ids), byte_inputs, torch.tensor(lengths)
            )
            return self.output_layer(logits)

    return Model()
def train(
    self,
    train_iter: Iterator,
    eval_iter: Iterator,
    model: Model,
    metric_reporter: MetricReporter,
    optimizer: torch.optim.Optimizer,
    pytext_config: PyTextConfig,
    scheduler=None,
    *args,
    **kwargs
) -> Tuple[torch.nn.Module, Any]:
    """Prepare the model for Hogwild training, then delegate to the parent.

    When more than one worker is configured, every model parameter is moved
    to shared memory so that several processes can update it concurrently.
    Extra positional and keyword arguments are accepted but not forwarded.

    Returns:
        Whatever the parent class's ``train`` returns.
    """
    worker_count = self.num_workers
    print("Num of workers for Hogwild Training is {}".format(worker_count))

    # Share memory of tensors for concurrent updates from multiple processes.
    if worker_count > 1:
        for shared_param in model.parameters():
            shared_param.share_memory_()

    return super().train(
        train_iter,
        eval_iter,
        model,
        metric_reporter,
        optimizer,
        pytext_config,
        scheduler,
    )
def train(
    self,
    train_iter: Iterator,
    eval_iter: Iterator,
    model: Model,
    metric_reporter: MetricReporter,
    optimizers: List[torch.optim.Optimizer],
    pytext_config: PyTextConfig,
    scheduler=None,
    *args,
    **kwargs
):
    """Run Hogwild training with ``self.num_workers`` cooperating workers.

    Ranks ``1 .. num_workers - 1`` run ``self.real_trainer.train`` in child
    processes; rank 0 runs it in the calling process and is handed a
    manager-backed list in which the result is collected.

    Returns:
        The single entry of the result list (per the original comment: the
        best model and the best metric).
    """
    print("Num of workers for Hogwild Training is {}".format(self.num_workers))

    # Share memory of tensors for concurrent updates from multiple processes.
    if self.num_workers > 1:
        for param in model.parameters():
            param.share_memory_()

    workers = []
    for worker_rank in range(1, self.num_workers):
        # Initialize the batches with different random states.
        train_iter.batches.init_epoch()
        worker = mp.Process(
            target=self.real_trainer.train,
            args=(
                train_iter,
                eval_iter,
                model,
                metric_reporter,
                optimizers,
                pytext_config,
                scheduler,
                None,
                worker_rank,
            ),
        )
        workers.append(worker)
        worker.start()

    training_result: List = Manager().list()  # Actual type is ListProxy.
    self.real_trainer.train(
        train_iter,
        eval_iter,
        model,
        metric_reporter,
        optimizers,
        pytext_config,
        scheduler,
        training_result,
        rank=0,
    )

    for worker in workers:
        worker.join()

    # Only the rank 0 worker writes to training_result.
    assert len(training_result) == 1
    return training_result[0]  # Contains best model and best metric.
def test(
    self,
    test_task_iters: BatchPreparationPipeline,
    model: Model,
    metric_reporter: MetaLearnMetricReporter,
):
    """Adapt the model on each task's support set and score its target set.

    For every (support, target) pair of every meta-batch the model is put in
    train mode, contextualized and run on the support inputs with their
    responses, then switched to eval mode and run on the target inputs under
    ``torch.no_grad()``. The per-batch loss and predictions are added to
    ``metric_reporter``, and one TEST-stage report is issued at the end.
    """
    for support, target, context in test_task_iters:
        for inputs, targets, contexts in zip(support, target, context):
            s_inputs, t_inputs = inputs
            s_targets, t_targets = targets
            s_context, t_context = contexts
            task = t_context['task_id'][0]

            model.train()
            model.contextualize(s_context)
            model(*s_inputs, responses=s_targets)  # model remembers responses
            model.eval()

            with torch.no_grad():
                t_pred = model(*t_inputs)
                t_loss = model.get_loss(t_pred, t_targets, t_context).item()
                metric_reporter.add_batch_stats(
                    task,
                    t_loss,
                    s_inputs,
                    t_predictions=t_pred,
                    t_targets=t_targets,
                )

    metric_reporter.report_metric(stage=Stage.TEST, epoch=0, reset=False)
def apply_masks(self, model: Model, masks: List[torch.Tensor]):
    """Multiply every learnable parameter of ``model`` by its mask, in place.

    Masks are paired positionally with the parameters that have
    ``requires_grad`` set, in ``model.parameters()`` order.

    Args:
        model: module whose learnable parameters are masked.
        masks: one tensor per learnable parameter, of identical size.

    Raises:
        AssertionError: if the number of masks or the size of any mask does
            not match the learnable parameters.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    assert len(trainable) == len(masks)
    for mask, weight in zip(masks, trainable):
        assert mask.size() == weight.size()
        weight.data *= mask.clone()
def create_optimizer(
    model: Model, optimizer_params: OptimizerParams
) -> List[torch.optim.Optimizer]:
    """Build the optimizer selected by ``optimizer_params.type``.

    Args:
        model: supplies the parameter groups through
            ``get_param_groups_for_optimizer()``.
        optimizer_params: carries ``type`` and ``lr``, plus ``weight_decay``
            (Adam) or ``momentum`` (SGD).

    Returns:
        A one-element list holding the constructed optimizer.

    Raises:
        ValueError: if the optimizer type is neither ADAM nor SGD.
    """
    kind = optimizer_params.type

    if kind == OptimizerType.ADAM:
        adam = torch.optim.Adam(
            model.get_param_groups_for_optimizer(),
            lr=optimizer_params.lr,
            weight_decay=optimizer_params.weight_decay,
        )
        return [adam]

    if kind == OptimizerType.SGD:
        sgd = torch.optim.SGD(
            model.get_param_groups_for_optimizer(),
            lr=optimizer_params.lr,
            momentum=optimizer_params.momentum,
        )
        return [sgd]

    raise ValueError("Unknown optimizer type")
def predict(
    self,
    test_task_iters: BatchPreparationPipeline,
    model: Model,
    metric_reporter: MetaLearnMetricReporter,
):
    """Yield the model's predicted responses for every target set.

    For each (support, target) pair of each meta-batch the model is put in
    train mode, contextualized and run on the support inputs with their
    responses, then switched to eval mode and run on the target inputs with
    gradient tracking disabled.

    The ``yield`` happens after the ``torch.no_grad()`` block has exited.
    Yielding from inside it would leave gradient tracking disabled in the
    consumer's code for as long as the generator stays suspended.

    Args:
        test_task_iters: iterable of ``(support, target, context)`` batches.
        model: model exposing ``contextualize`` and the call signatures used
            below.
        metric_reporter: accepted for interface compatibility; not used here.

    Yields:
        dict with keys ``task``, ``resps``, ``resp_lens``, ``s_inputs``,
        ``s_targets``, ``s_context``, ``t_inputs``, ``t_targets`` and
        ``t_context``.
    """
    for support, target, context in test_task_iters:
        for inputs, targets, contexts in zip(support, target, context):
            s_inputs, t_inputs = inputs
            s_targets, t_targets = targets
            s_context, t_context = contexts
            task = t_context['task_id'][0]

            model.train()
            model.contextualize(s_context)
            model(*s_inputs, responses=s_targets)  # model remembers responses
            model.eval()

            with torch.no_grad():
                # Predicted responses (embedded tensor) and their lengths.
                resps, resp_lens = model(*t_inputs)

            yield dict(
                task=task,
                resps=resps,
                resp_lens=resp_lens,
                s_inputs=s_inputs,
                s_targets=s_targets,
                s_context=s_context,
                t_inputs=t_inputs,
                t_targets=t_targets,
                t_context=t_context,
            )
def apply_masks(self, model: Model, masks: List[torch.Tensor]):
    """Multiply learnable parameters (and optionally their grads) by masks.

    Masks are paired positionally with the parameters that have
    ``requires_grad`` set. A zero-dimensional mask is skipped, leaving that
    parameter untouched.

    When ``self.accumulate_mask`` is set, the gradient is masked as well so
    that a pruned weight is removed permanently. A parameter that has no
    gradient yet (``grad is None``, e.g. before the first backward pass) has
    only its weights masked instead of raising ``AttributeError``.

    Args:
        model: module whose learnable parameters are masked in place.
        masks: one tensor per learnable parameter.

    Raises:
        AssertionError: if the number of masks, or the size of a non-scalar
            mask, does not match the learnable parameters.
    """
    learnable_params = [p for p in model.parameters() if p.requires_grad]
    assert len(learnable_params) == len(masks)
    for mask, weight in zip(masks, learnable_params):
        if not len(mask.size()):
            # Zero-dimensional mask: nothing to apply for this parameter.
            continue
        assert mask.size() == weight.size()
        weight.data *= mask.clone()
        # if accumulate_mask, remove a param permanently by also removing
        # its gradient
        if self.accumulate_mask and weight.grad is not None:
            weight.grad.data *= mask.clone()
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a script module that accepts raw tokens.

    Two module classes are defined. ``Model`` takes only tokens;
    ``ModelWithDenseFeat`` also takes dense features, which it normalizes
    with the normalizer of ``tensorizers["dense"]`` and passes to the traced
    model as a float tensor. The dense variant is returned when a
    ``"dense"`` tensorizer is present.

    Args:
        tensorizers: mapping with a ``"tokens"`` entry exposing ``vocab``,
            and optionally a ``"dense"`` entry exposing ``normalizer``.
        traced_model: traced model invoked with the padded ids, the sequence
            lengths and, for the dense variant, the dense feature tensor.

    Returns:
        An instance of one of the two locally defined script modules.
    """
    scripted_output = self.output_layer.torchscript_predictions()
    token_vocab = tensorizers["tokens"].vocab

    class Model(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = Vocabulary(token_vocab, unk_idx=token_vocab.idx[UNK])
            self.model = traced_model
            self.output_layer = scripted_output
            self.pad_idx = jit.Attribute(token_vocab.idx[PAD], int)

        @jit.script_method
        def forward(self, tokens: List[List[str]]):
            lengths = make_sequence_lengths(tokens)
            token_ids = self.vocab.lookup_indices_2d(tokens)
            token_ids = pad_2d(token_ids, lengths, self.pad_idx)
            logits = self.model(torch.tensor(token_ids), torch.tensor(lengths))
            return self.output_layer(logits)

    class ModelWithDenseFeat(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = Vocabulary(token_vocab, unk_idx=token_vocab.idx[UNK])
            self.normalizer = tensorizers["dense"].normalizer
            self.model = traced_model
            self.output_layer = scripted_output
            self.pad_idx = jit.Attribute(token_vocab.idx[PAD], int)

        @jit.script_method
        def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
            lengths = make_sequence_lengths(tokens)
            token_ids = self.vocab.lookup_indices_2d(tokens)
            token_ids = pad_2d(token_ids, lengths, self.pad_idx)
            dense_feat = self.normalizer.normalize(dense_feat)
            logits = self.model(
                torch.tensor(token_ids),
                torch.tensor(lengths),
                torch.tensor(dense_feat, dtype=torch.float),
            )
            return self.output_layer(logits)

    if "dense" in tensorizers:
        return ModelWithDenseFeat()
    return Model()
def torchscriptify(self, tensorizers, traced_model):
    """Wrap a byte-level ``traced_model`` in a script module.

    The returned module requires ``tokens``. It truncates them to the
    configured maximum sequence length (``-1`` is stored when the tensorizer
    has none), builds byte inputs with ``make_byte_inputs`` and calls the
    traced model with those and the sequence lengths. The ``texts``,
    ``multi_texts`` and ``languages`` arguments are accepted but not used.

    Args:
        tensorizers: mapping with a ``"token_bytes"`` entry exposing
            ``max_seq_len``, ``max_byte_len`` and ``offset_for_non_padding``.
        traced_model: traced model invoked as ``model(token_bytes, seq_lens)``.

    Returns:
        An instance of the locally defined ``torch.jit.ScriptModule``.
    """
    scripted_output = self.output_layer.torchscript_predictions()
    seq_len_limit = tensorizers["token_bytes"].max_seq_len or -1
    byte_len_limit = tensorizers["token_bytes"].max_byte_len
    non_padding_offset = tensorizers["token_bytes"].offset_for_non_padding

    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.max_seq_len = torch.jit.Attribute(seq_len_limit, int)
            self.max_byte_len = torch.jit.Attribute(byte_len_limit, int)
            self.byte_offset_for_non_padding = torch.jit.Attribute(
                non_padding_offset, int
            )
            self.model = traced_model
            self.output_layer = scripted_output

        @torch.jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
        ):
            if tokens is None:
                raise RuntimeError("tokens is required")

            tokens = truncate_tokens(tokens, self.max_seq_len, SpecialTokens.PAD)
            lengths = make_sequence_lengths(tokens)
            byte_inputs, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            logits = self.model(byte_inputs, torch.tensor(lengths))
            return self.output_layer(logits)

    return Model()
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizer: torch.optim.Optimizer,
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified. This function
    iterates epochs specified in config, and for each epoch do:

        1. Train model using training data, aggregate and report training results
        2. Evaluate model using evaluation data
        3. Adjust learning rate if scheduler is specified
        4. Calculate metrics based on evaluation results and select best model

    Training stops early when the projected end of the next epoch would
    exceed ``self.config.target_time_limit_seconds`` (if positive), or when
    no better model has been found for ``self.config.early_stop_after``
    epochs (if positive).

    Args:
        train_iter (BatchIterator): batch iterator of training data
        eval_iter (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        optimizer (torch.optim.Optimizer): torch optimizer to be used
        scheduler (Optional[torch.optim.lr_scheduler]): learning rate scheduler,
            default is None
        rank (int): only used in distributed training, the rank of the current
            training thread; only rank 0 saves modules and reloads the best
            model state

    Returns:
        model, best_metric: the trained model together with the best metric.
        On rank 0 the best recorded state is loaded back into the model; if
        no state was ever recorded (e.g. zero epochs) the model is returned
        as it is instead of failing on an unbound variable.
    """
    timer = time_utils.StageTimer()
    world_size = 1
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        world_size = cuda_utils.DISTRIBUTED_WORLD_SIZE
        if world_size > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )
    timer.add_stage(stage="init_distributed_model")

    best_metric = None
    # Stays None until rank 0 records a best model; checked before reloading.
    best_model_state = None
    last_best_epoch = 0
    scheduler = self._prepare_scheduler(train_iter, scheduler)
    timer.add_stage(stage="pre_training")

    def training_pre_batch_callback():
        if world_size > 1:
            # replace optimizer.zero_grad() here to work with DDP
            # in cases where some parameters don't receive grads at each step
            # loss.backward will set grad for params in the computation graph
            # we can thus follow which params are left out and call .backward
            # on them manually
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.detach_()
                    p.grad = None
        else:
            optimizer.zero_grad()

    def training_backprop(loss, timer):
        loss.backward()
        if world_size > 1:
            # DDP fix when some parameters don't receive grads
            for p in model.parameters():
                if p.requires_grad and p.grad is None:
                    p.backward(torch.zeros_like(p.data))
        timer.add_stage("backward")

        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm
            )
        else:
            grad_norm = None

        optimizer.step()
        timer.add_stage("update_grads")
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    time_start = time.time()
    for epoch in range(1, self.config.epochs + 1):
        if self.config.target_time_limit_seconds > 0 and epoch > 1:
            # Project the end of the next epoch from the mean epoch duration.
            time_elapsed = time.time() - time_start
            mean_epoch_time = time_elapsed / float(epoch - 1)
            expected_next_epoch_time = time_elapsed + mean_epoch_time
            if expected_next_epoch_time > self.config.target_time_limit_seconds:
                print(
                    f"Training stopped after {epoch - 1} epochs and "
                    f"{int(time_elapsed)} seconds, due to the target max training "
                    f"time of {self.config.target_time_limit_seconds} seconds."
                )
                break

        print(f"Rank {rank} worker: Starting epoch #{epoch}")
        model.train()
        lrs = (str(lr) for lr in learning_rates(optimizer))
        print(f"Learning rate(s): {', '.join(lrs)}")

        self._run_epoch(
            Stage.TRAIN,
            epoch,
            train_iter,
            model,
            metric_reporter,
            pre_batch=training_pre_batch_callback,
            backprop=training_backprop,
            rank=rank,
        )
        timer.add_stage(stage="epoch_train")

        model.eval(Stage.EVAL)
        with torch.no_grad():
            eval_metric = self._run_epoch(
                Stage.EVAL, epoch, eval_iter, model, metric_reporter, rank=rank
            )
        timer.add_stage(stage="epoch_eval")

        # Step the learning rate scheduler(s)
        if scheduler:
            assert eval_metric is not None
            scheduler.step(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=epoch,
            )

        # choose best model.
        if metric_reporter.compare_metric(eval_metric, best_metric):
            last_best_epoch = epoch
            best_metric = eval_metric
            # Only rank = 0 trainer saves modules.
            if train_config.save_module_checkpoints and rank == 0:
                model.save_modules(
                    base_path=train_config.modules_save_dir, suffix=f"-ep{epoch}"
                )

            if rank == 0:
                print(f"Rank {rank} worker: Found a better model!")
                model_state = model.state_dict()
                # save to cpu to avoid multiple model copies in gpu memory
                if cuda_utils.CUDA_ENABLED:
                    for key, state in model_state.items():
                        model_state[key] = state.cpu()
                best_model_state = model_state
        timer.add_stage(stage="epoch_save/load_module")

        if self.config.early_stop_after > 0 and (
            epoch - last_best_epoch == self.config.early_stop_after
        ):
            print(
                f"Rank {rank} worker: Eval metric hasn't changed for "
                + f"{self.config.early_stop_after} epochs. Stopping now."
            )
            break
        sys.stdout.flush()

    if rank == 0 and best_model_state is not None:
        if cuda_utils.CUDA_ENABLED:
            for key, state in best_model_state.items():
                best_model_state[key] = state.cuda()
        model.load_state_dict(best_model_state)

    timer.report("Trainer train timer")
    return model, best_metric
def train(
    self,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified.

    While ``self.continue_training(state)`` holds, each epoch does:

        1. Train on the training data (optionally limited to
           ``self.config.num_batches_per_epoch`` batches)
        2. Unless ``self.config.do_eval`` is off, evaluate on the eval data
        3. Step the learning rate scheduler with the model-selection metric
        4. Record the best model and save a checkpoint when the model
           improved or ``train_config.save_all_checkpoints`` is set

    If ``self.optimizer.finalize()`` returns true after the loop, one more
    evaluation and best-model/checkpoint update is performed.

    Args:
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        rank (int): rank of the current worker; the best model is loaded
            back only for rank 0

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    state = TrainingState(
        model=model, optimizer=self.optimizer, scheduler=self.scheduler, rank=rank
    )
    training_data = self.set_up_training(state, training_data)

    trainable_params = sum(
        p.numel() for p in state.model.parameters() if p.requires_grad
    )
    print(f"Num trainable parameters: {trainable_params}")

    def record_eval_result(metric):
        # Shared by the per-epoch and the post-finalize evaluation.
        improved = metric_reporter.compare_metric(metric, state.best_model_metric)
        if improved:
            self.update_best_model(state, train_config, metric)
        if improved or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    while self.continue_training(state):
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        lrs = learning_rates(state.optimizer)
        print(f"\nWorker {state.rank} starting epoch {state.epoch}", flush=True)
        print(f"Learning rate(s): {', '.join(map(str, lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            print(f"start training epoch {state.epoch}", flush=True)
            batch_limit = self.config.num_batches_per_epoch
            if batch_limit:
                # We want to limit the number of batches in the epoch;
                # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                # In this case we set the training data iterator to cycle earlier
                # in the training process, so when it reaches the end it will
                # loop back to the beginning.
                epoch_data = itertools.islice(training_data, batch_limit)
            else:
                epoch_data = training_data
            self.run_epoch(state, epoch_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print(f"start evaluating epoch {state.epoch}", flush=True)
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s)
        assert eval_metric is not None
        state.scheduler.step_epoch(
            metrics=metric_reporter.get_model_select_metric(eval_metric),
            epoch=state.epoch,
        )

        # Did we train a better model?
        record_eval_result(eval_metric)

    if self.optimizer.finalize():
        state.stage = Stage.EVAL
        model.eval(Stage.EVAL)
        print("start evaluating finalized state", flush=True)
        with torch.no_grad():
            eval_metric = self.run_epoch(state, eval_data, metric_reporter)
        record_eval_result(eval_metric)

    # Only bother loading the best model for master worker
    if rank == 0 and state.best_model_state is not None:
        self.load_best_model(state)

    return state.model, state.best_model_metric
def train(
    self,
    text_embedder,
    train_task_iters: Optional[BatchPreparationPipeline],
    eval_task_iters: BatchPreparationPipeline,
    model: Model,
    metric_reporter: MetaLearnMetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """Run one meta-training epoch, evaluate it, and save the model state.

    For up to 100 training meta-batches, every (support, target) pair is
    processed the same way: the model is put in train mode, contextualized
    and run on the support set with its responses, then switched to eval mode
    and scored on the target set. The same adapt-then-score procedure is
    repeated on ``eval_task_iters`` under ``torch.no_grad()``. Metrics are
    reported once per stage.

    Earlier debugging leftovers were removed. An unused ``temp =
    next(train_task_iters)`` silently discarded the first training
    meta-batch (and raised ``StopIteration`` on an empty pipeline), and an
    alias to ``text_embedder.decode_ids_as_text`` was never read.

    Args:
        text_embedder: accepted for interface compatibility; not used here.
        train_task_iters: training meta-batches; when falsy, meta-training
            and evaluation are skipped with a warning.
        eval_task_iters: evaluation meta-batches.
        model: model exposing ``contextualize`` and ``get_loss``.
        metric_reporter: collects per-batch stats and reports per stage.
        train_config: supplies ``modules_save_dir`` for the saved state dict.
        rank: accepted for interface compatibility; not used here.

    Returns:
        ``(model, None)``. The model's state dict is written to
        ``<modules_save_dir>/model.pt`` before returning.
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
    best_model_path = None

    # Start outer loop (meta learner "epochs") #############################################
    if not train_task_iters:
        LOG.warning("Model does not need meta-training")
    else:
        for epoch in range(1, 2):  # single epoch
            # zip with range(100) caps the epoch at 100 meta-batches.
            for _, (support, target, context) in zip(range(100), train_task_iters):
                for (s_inputs, t_inputs), (s_targets, t_targets), (
                    s_context,
                    t_context,
                ) in zip(support, target, context):
                    # Shapes noted by the original author on one batch
                    # (illustrative, not verified here):
                    # support : (2)
                    # s_inputs : (6)
                    # s_inputs[0].shape : (128, 3, 38)
                    #   3 means 3 consecutive sentence, e.g. 'denver',
                    #   'no , the thunderstorm has drifted north .',
                    #   'that makes me mad ! why is that ?'
                    # s_inputs[1].shape : (128, 3, 38, 768)  # I guess BertEmbedding
                    # s_inputs[2].shape : (128, 2, 37)
                    #   2 means the next consecutive sentence of s_inputs[0], e.g.
                    #   'no , the thunderstorm has drifted north .',
                    #   'that makes me mad ! why is that ?'
                    # s_inputs[3].shape : (128)  # [3, 3, 3, 3, 3....]
                    # s_inputs[4].shape : (128, 3)  # each length of sentences in s_inputs[0]
                    # s_inputs[5].shape : (128, 2)  # each length of sentences in s_inputs[2]
                    # s_targets : (2)
                    # s_targets[0].shape : (128, 2, 34), e.g.
                    #   'no, the thunderstorm has drifted north .',
                    #   'you would like the storm ?'
                    # s_targets[1].shape : (128, 2)  # each length of sentences in s_targets[0]
                    # type(s_context) : dict
                    # keys : {'target_seq_lens', 'orig_text', 'dlg_len', 'dlg_id',
                    #         'domain_id', 'task_id', 'index'}
                    # s_context['target_seq_lens'].shape : (128, 2)
                    #   each length "+1" of sentences in s_targets[0]
                    # s_context['orig_text'].__len__() : 128
                    # s_context['orig_text'][0]'s original text == "turns": [
                    #   "Hello how may I help you?",
                    #   "Is there still supposed to be a thunderstorm today as there
                    #    was originally?",
                    #   "what location?", "Denver",
                    #   "No, the thunderstorm has drifted north.",
                    #   "That makes me mad! Why is that?",
                    #   "You would like the storm?",
                    #   "Yes! It really upsets me that there isn't goin g to be one now.",
                    #   "I'm sorry, I will contact mother nature immediately!",
                    #   "Why is there not going to be one?", "The radar say so."]
                    # s_context['dlg_len'] = 4
                    # s_context['dlg_id'] : (128)  # '2d1d4ed2', '20debe73', ...  ## "id"
                    # s_context['domain_id'] : (128)  # 'WEATHER_CHECK', ...  ## "domain"
                    # s_context['task_id'] : (128)  # 'd941f2bb', '5f2bb1b2', ...  ## "task_id"
                    # s_context['index'] : (128)
                    #   25650, 25414, 25454, 25445, 25465, 25370, 25333, 25411,
                    #   25203, 25108, 25631, 25532, 25155, 25472, 25365, 25356,
                    #   25258, 25282, 25242, 25518, 25150, 25237, 25372
                    # t_inputs : (6)
                    # text_embedder.decode_ids_as_text(s_inputs[0][0][0].cpu().numpy())
                    #   = 'what is your order number ?'
                    task = t_context['task_id'][0]

                    # Adapt the model using the support set
                    model.train()
                    for _ in range(1):
                        model.contextualize(s_context)
                        model(*s_inputs, responses=s_targets)  # model remembers responses

                    # Evaluate the model using the target set
                    model.eval()  # model now retrieves from examples seen so far
                    model.contextualize(t_context)
                    t_pred = model(*t_inputs)
                    t_loss = model.get_loss(t_pred, t_targets, t_context).item()
                    metric_reporter.add_batch_stats(
                        task,
                        t_loss,
                        s_inputs,
                        t_predictions=t_pred,
                        t_targets=t_targets,
                    )

            metric_reporter.report_metric(stage=Stage.TRAIN, epoch=epoch, reset=False)

            logging.info("Evaluating model on eval tasks")
            with torch.no_grad():
                for support, target, context in eval_task_iters:
                    for (s_inputs, t_inputs), (s_targets, t_targets), (
                        s_context,
                        t_context,
                    ) in zip(support, target, context):
                        task = t_context["task_id"][0]
                        model.train()
                        model.contextualize(s_context)
                        model(*s_inputs, responses=s_targets)  # model remembers responses
                        model.eval()
                        t_pred = model(*t_inputs)
                        t_loss = model.get_loss(t_pred, t_targets, t_context).item()
                        metric_reporter.add_batch_stats(
                            task,
                            t_loss,
                            s_inputs,
                            t_predictions=t_pred,
                            t_targets=t_targets,
                        )

            metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch, reset=False)

    best_model_path = os.path.join(train_config.modules_save_dir, "model.pt")
    torch.save(model.state_dict(), best_model_path)

    return model, None
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizers: List[torch.optim.Optimizer],
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """Train and evaluate ``model`` with an optional dialogue-length schedule.

    Epochs run from ``self.config.start_epoch`` to ``self.config.epochs``.
    Before each epoch, every entry of ``self.config.length_schedule_per_epoch``
    (pairs of ``(epoch, max_n_turns)``) whose epoch has been reached is applied
    to both iterators. Each epoch trains, evaluates, steps the scheduler and,
    when the eval metric improved, saves a ``ModelState`` checkpoint (rank 0
    only, and only if ``train_config.save_module_checkpoints`` is set).
    Training stops early after ``self.config.early_stop_after`` epochs
    without improvement (if positive).

    Args:
        train_iter: training batches; must support ``len``, ``next`` and
            ``close`` and accept a ``max_n_turns`` attribute.
        eval_iter: evaluation batches, same requirements.
        model: model to be trained.
        metric_reporter: computes and compares metrics.
        train_config: supplies ``save_module_checkpoints`` and
            ``modules_save_dir``.
        optimizers: list expected to hold exactly one optimizer.
        scheduler: optional learning-rate scheduler.
        rank: worker rank; only rank 0 writes checkpoints.

    Returns:
        ``(model, best_metric)``. If a best checkpoint was written, its
        parameters are loaded back into the model first; otherwise the model
        is returned in its final state instead of failing on
        ``torch.load(None)``.
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )

    best_metric = None
    last_best_epoch = 0
    best_model_path = None
    scheduler = self._prepare_scheduler(train_iter, scheduler)

    def training_pre_batch_callback():
        optimizer_zero_grad(optimizers)

    def training_backprop(loss):
        loss.backward()
        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm
            )
        else:
            grad_norm = None

        optimizer_step(optimizers)
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    len_sched_ix = 0

    # Used since we need the infinite iterator (only created and called once)
    def batch_generator_for_epoch(it):
        n = len(it)
        while n > 0:
            yield next(it)
            n -= 1

    for epoch in range(self.config.start_epoch, self.config.epochs + 1):
        # Set the dialogue length in the fields, to be used by the postprocessor
        while (
            self.config.length_schedule_per_epoch
            and len_sched_ix < len(self.config.length_schedule_per_epoch)
            and epoch >= self.config.length_schedule_per_epoch[len_sched_ix][0]
        ):
            max_n_turns = self.config.length_schedule_per_epoch[len_sched_ix][1]
            train_iter.max_n_turns = max_n_turns
            eval_iter.max_n_turns = max_n_turns
            len_sched_ix += 1

        LOG.info(f"\nRank {rank} worker: Starting epoch #{epoch}")
        model.train()
        lrs = (str(lr) for lr in learning_rates(optimizers))
        LOG.info(f"Learning rate(s): {', '.join(lrs)}")

        self._run_epoch(
            Stage.TRAIN,
            epoch,
            batch_generator_for_epoch(train_iter),
            model,
            metric_reporter,
            pre_batch=training_pre_batch_callback,
            backprop=training_backprop,
            rank=rank,
        )

        model.eval(Stage.EVAL)
        with torch.no_grad():
            eval_metric = self._run_epoch(
                Stage.EVAL,
                epoch,
                batch_generator_for_epoch(eval_iter),
                model,
                metric_reporter,
                rank=rank,
            )

        # Step the learning rate scheduler(s)
        if scheduler:
            assert eval_metric is not None
            scheduler.step(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=epoch,
            )

        # choose best model.
        if metric_reporter.compare_metric(eval_metric, best_metric):
            LOG.info(
                f"Rank {rank} worker: Found a better model! Saving the model state for epoch #{epoch}."
            )
            last_best_epoch = epoch
            best_metric = eval_metric
            # Only rank = 0 trainer saves modules.
            if train_config.save_module_checkpoints and rank == 0:
                best_model_path = os.path.join(
                    train_config.modules_save_dir, "best_model"
                )
                # PyText only ever returns a single optimizer in this list
                optimizer, = optimizers
                torch.save(
                    ModelState(
                        epoch=epoch,
                        parameters=model.state_dict(),
                        optimizer=optimizer.state_dict(),
                    ),
                    best_model_path,
                )

        if self.config.early_stop_after > 0 and (
            epoch - last_best_epoch == self.config.early_stop_after
        ):
            LOG.info(
                f"Rank {rank} worker: Eval metric hasn't changed for "
                f"{self.config.early_stop_after} epochs. Stopping now."
            )
            break
        sys.stdout.flush()

    train_iter.close()
    eval_iter.close()
    # No checkpoint exists on non-zero ranks or when checkpointing is disabled.
    if best_model_path is not None:
        model.load_state_dict(torch.load(best_model_path).parameters)
    return model, best_metric
def train(
    self,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified.

    While ``self.continue_training(state)`` holds, each epoch does:

        1. Train on the training data
        2. Unless ``self.config.do_eval`` is off, evaluate on the eval data
        3. Step the learning rate scheduler with the model-selection metric
        4. When the eval metric improved, reset the no-improvement counter,
           remember the metric and save a checkpoint

    Args:
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        rank (int): rank of the current worker; the best model is loaded
            back only for rank 0

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    state = TrainingState(
        model=model, optimizer=self.optimizer, scheduler=self.scheduler, rank=rank
    )
    self.set_up_training(state, training_data)

    while self.continue_training(state):
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        print(f"Worker {state.rank} starting epoch {state.epoch}", flush=True)
        current_lrs = learning_rates(state.optimizer)
        print(f"Learning rate(s): {', '.join(map(str, current_lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            self.run_epoch(state, training_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s)
        assert eval_metric is not None
        selection_metric = metric_reporter.get_model_select_metric(eval_metric)
        state.scheduler.step_epoch(metrics=selection_metric, epoch=state.epoch)

        # Did we train a better model?
        improved = metric_reporter.compare_metric(
            eval_metric, state.best_model_metric
        )
        if improved:
            state.epochs_since_last_improvement = 0
            state.best_model_metric = eval_metric
            self.save_checkpoint(state, train_config)

    # Only bother loading the best model for master worker
    if rank == 0 and state.best_model_state is not None:
        self.load_best_model(state)

    return state.model, state.best_model_metric
def torchscriptify(self, tensorizers, traced_model):
    """Wrap a word+byte ``traced_model`` in a script module.

    Two module classes are defined. Both require ``tokens``, truncate them to
    the maximum sequence length of ``tensorizers["tokens"]`` (``-1`` is stored
    when it has none), and pass padded word ids, byte inputs and sequence
    lengths to the traced model. ``ModelWithDenseFeat`` also requires
    ``dense_feat``, normalizes it with the normalizer of
    ``tensorizers["dense"]`` and passes it as a float tensor. The dense
    variant is returned when a ``"dense"`` tensorizer is present. The
    ``texts``, ``multi_texts`` and ``languages`` arguments are accepted but
    not used.

    Args:
        tensorizers: mapping with ``"tokens"`` and ``"token_bytes"`` entries
            and optionally a ``"dense"`` entry.
        traced_model: traced model invoked with the tensors described above.

    Returns:
        An instance of one of the two locally defined script modules.
    """
    scripted_output = self.output_layer.torchscript_predictions()
    seq_len_limit = tensorizers["tokens"].max_seq_len or -1
    byte_len_limit = tensorizers["token_bytes"].max_byte_len
    non_padding_offset = tensorizers["token_bytes"].offset_for_non_padding
    token_vocab = tensorizers["tokens"].vocab

    class Model(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(
                token_vocab,
                token_vocab.get_unk_index(),
                token_vocab.get_pad_index(),
            )
            self.max_seq_len = jit.Attribute(seq_len_limit, int)
            self.max_byte_len = jit.Attribute(byte_len_limit, int)
            self.byte_offset_for_non_padding = jit.Attribute(
                non_padding_offset, int
            )
            self.pad_idx = jit.Attribute(token_vocab.get_pad_index(), int)
            self.model = traced_model
            self.output_layer = scripted_output

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
        ):
            if tokens is None:
                raise RuntimeError("tokens is required")

            tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
            lengths = make_sequence_lengths(tokens)
            token_ids = self.vocab.lookup_indices_2d(tokens)
            token_ids = pad_2d(token_ids, lengths, self.pad_idx)
            byte_inputs, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            logits = self.model(
                torch.tensor(token_ids), byte_inputs, torch.tensor(lengths)
            )
            return self.output_layer(logits)

    class ModelWithDenseFeat(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(
                token_vocab,
                token_vocab.get_unk_index(),
                token_vocab.get_pad_index(),
            )
            self.normalizer = tensorizers["dense"].normalizer
            self.max_seq_len = jit.Attribute(seq_len_limit, int)
            self.max_byte_len = jit.Attribute(byte_len_limit, int)
            self.byte_offset_for_non_padding = jit.Attribute(
                non_padding_offset, int
            )
            self.pad_idx = jit.Attribute(token_vocab.get_pad_index(), int)
            self.model = traced_model
            self.output_layer = scripted_output

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
            dense_feat: Optional[List[List[float]]] = None,
        ):
            if tokens is None:
                raise RuntimeError("tokens is required")
            if dense_feat is None:
                raise RuntimeError("dense_feat is required")

            tokens = truncate_tokens(tokens, self.max_seq_len, self.vocab.pad_token)
            lengths = make_sequence_lengths(tokens)
            token_ids = self.vocab.lookup_indices_2d(tokens)
            token_ids = pad_2d(token_ids, lengths, self.pad_idx)
            byte_inputs, _ = make_byte_inputs(
                tokens, self.max_byte_len, self.byte_offset_for_non_padding
            )
            dense_feat = self.normalizer.normalize(dense_feat)
            logits = self.model(
                torch.tensor(token_ids),
                byte_inputs,
                torch.tensor(lengths),
                torch.tensor(dense_feat, dtype=torch.float),
            )
            return self.output_layer(logits)

    if "dense" in tensorizers:
        return ModelWithDenseFeat()
    return Model()
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a TorchScript module taking texts or tokens.

    If the ``"tokens"`` tensorizer's tokenizer can be scripted (and is not a
    ``DoNothingTokenizer``), the exported module tokenizes ``texts`` itself;
    otherwise callers must pass ``tokens``.

    Args:
        tensorizers: mapping that must contain a ``"tokens"`` entry.
        traced_model: the traced model invoked from ``forward``.

    Returns:
        An instance of the scripted ``Model`` defined below.
    """
    output_layer = self.output_layer.torchscript_predictions()
    token_tensorizer = tensorizers["tokens"]
    input_vocab = token_tensorizer.vocab
    # A falsy max_seq_len (None or 0) is exported as -1.
    max_seq_len = token_tensorizer.max_seq_len or -1

    scripted_tokenizer: Optional[torch.jit.ScriptModule] = None
    try:
        scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
    except NotImplementedError:
        # Tokenizer cannot be scripted; the module will require `tokens`.
        pass
    if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
        scripted_tokenizer = None

    # The input tensor packing memory is allocated/cached for different
    # shapes, and max sequence length will help to reduce the number of
    # different tensor shapes. We noticed that the TorchScript model could
    # use 25G for offline inference on CPU without using max_seq_len.

    unk_index = input_vocab.get_unk_index()
    pad_index = input_vocab.get_pad_index()

    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(input_vocab, unk_index, pad_index)
            self.model = traced_model
            self.output_layer = output_layer
            self.pad_idx = torch.jit.Attribute(pad_index, int)
            self.max_seq_len = torch.jit.Attribute(max_seq_len, int)
            self.tokenizer = scripted_tokenizer

        @torch.jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
        ):
            # PyTorch breaks with 2 'not None' checks right now, hence the
            # nested `if` statements instead of a combined condition.
            if texts is not None:
                if tokens is not None:
                    raise RuntimeError("Can't set both tokens and texts")
                if self.tokenizer is not None:
                    # Keep only the first field of each tokenizer output.
                    tokens = [
                        [t[0] for t in self.tokenizer.tokenize(text)]
                        for text in texts
                    ]

            if tokens is None:
                raise RuntimeError("tokens is required")

            clipped = truncate_tokens(
                tokens, self.max_seq_len, self.vocab.pad_token
            )
            lengths = make_sequence_lengths(clipped)
            padded_ids = pad_2d(
                self.vocab.lookup_indices_2d(clipped), lengths, self.pad_idx
            )
            scores = self.model(torch.tensor(padded_ids), torch.tensor(lengths))
            return self.output_layer(scores)

    return Model()
def get_current_sparsity(self, model: Model) -> float:
    """Return the fraction of trainable weights that are exactly zero.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: model whose ``parameters()`` are inspected.

    Returns:
        ``(trainable - nonzero) / trainable``, or ``0.0`` when the model has
        no trainable parameters (previously a ``ZeroDivisionError``).
    """
    trainable_params = 0
    nonzero_params = 0
    # Single pass over the parameters instead of iterating them twice.
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_params += param.numel()
        nonzero_params += param.nonzero().size(0)

    if trainable_params == 0:
        # Nothing to prune, so report a dense (0% sparse) model.
        return 0.0
    return (trainable_params - nonzero_params) / trainable_params
def torchscriptify(self, tensorizers, traced_model):
    """Wrap ``traced_model`` in a TorchScript module that accepts raw tokens.

    The returned module truncates the token lists, looks up and pads word
    ids and passes them with the sequence lengths through ``traced_model``
    and the scripted output layer.

    Args:
        tensorizers: mapping that must contain ``"tokens"``; a ``"dense"``
            entry switches to the variant that also consumes dense features.
        traced_model: the traced model invoked from ``forward``.

    Returns:
        ``ModelWithDenseFeat()`` when ``"dense"`` is in ``tensorizers``,
        otherwise ``Model()``.
    """
    output_layer = self.output_layer.torchscript_predictions()
    token_tensorizer = tensorizers["tokens"]
    input_vocab = token_tensorizer.vocab
    # A falsy max_seq_len (None or 0) is exported as -1.
    max_seq_len = token_tensorizer.max_seq_len or -1

    # The input tensor packing memory is allocated/cached for different
    # shapes, and max sequence length will help to reduce the number of
    # different tensor shapes. We noticed that the TorchScript model could
    # use 25G for offline inference on CPU without using max_seq_len.

    unk_index = input_vocab.get_unk_index()
    pad_index = input_vocab.get_pad_index()

    class Model(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(input_vocab, unk_index, pad_index)
            self.model = traced_model
            self.output_layer = output_layer
            self.pad_idx = jit.Attribute(pad_index, int)
            self.max_seq_len = jit.Attribute(max_seq_len, int)

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
        ):
            # Only `tokens` is consumed; the other arguments are accepted
            # but not read by this module.
            if tokens is None:
                raise RuntimeError("tokens is required")

            clipped = truncate_tokens(
                tokens, self.max_seq_len, self.vocab.pad_token
            )
            lengths = make_sequence_lengths(clipped)
            padded_ids = pad_2d(
                self.vocab.lookup_indices_2d(clipped), lengths, self.pad_idx
            )
            scores = self.model(torch.tensor(padded_ids), torch.tensor(lengths))
            return self.output_layer(scores)

    class ModelWithDenseFeat(jit.ScriptModule):
        def __init__(self):
            super().__init__()
            self.vocab = ScriptVocabulary(input_vocab, unk_index, pad_index)
            # Looked up lazily: only this variant requires a "dense" entry.
            self.normalizer = tensorizers["dense"].normalizer
            self.model = traced_model
            self.output_layer = output_layer
            self.pad_idx = jit.Attribute(pad_index, int)
            self.max_seq_len = jit.Attribute(max_seq_len, int)

        @jit.script_method
        def forward(
            self,
            texts: Optional[List[str]] = None,
            multi_texts: Optional[List[List[str]]] = None,
            tokens: Optional[List[List[str]]] = None,
            languages: Optional[List[str]] = None,
            dense_feat: Optional[List[List[float]]] = None,
        ):
            if tokens is None:
                raise RuntimeError("tokens is required")
            if dense_feat is None:
                raise RuntimeError("dense_feat is required")

            clipped = truncate_tokens(
                tokens, self.max_seq_len, self.vocab.pad_token
            )
            lengths = make_sequence_lengths(clipped)
            padded_ids = pad_2d(
                self.vocab.lookup_indices_2d(clipped), lengths, self.pad_idx
            )
            normalized = self.normalizer.normalize(dense_feat)
            scores = self.model(
                torch.tensor(padded_ids),
                torch.tensor(lengths),
                torch.tensor(normalized, dtype=torch.float),
            )
            return self.output_layer(scores)

    if "dense" in tensorizers:
        return ModelWithDenseFeat()
    return Model()
def get_masks(self, model: Model, pre_masks: List[torch.Tensor] = None) -> List[torch.Tensor]:
    """Compute magnitude-pruning masks without touching the weights.

    A fraction ``self.sparsity`` of the weights that are still retained
    (mask value 1) is pruned, either per parameter tensor
    (``self.layerwise_pruning``) or across all tensors at once.

    Args:
        model: model whose trainable parameters are masked.
        pre_masks: optional list of float tensors, one per trainable
            parameter, where 1 means "retained" and 0 means "pruned".
            When given (and non-empty) it replaces ``self._masks``.

    Returns:
        One mask per trainable parameter: the intersection of the newly
        selected weights and the previous masks.  A 0-dim uninitialised
        placeholder is returned for a tensor that has nothing to select.

    Side effects:
        ``self._masks`` is set to ``pre_masks`` or to all-ones when it was
        unset, and overwritten with the result when
        ``self.accumulate_mask`` is true.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]

    if pre_masks:
        self._masks = pre_masks
    if self._masks is None:
        # Without any previous masks every weight counts as retained.
        self._masks = [torch.ones_like(p) for p in trainable]

    assert len(trainable) == len(self._masks)
    for prev_mask, weight in zip(self._masks, trainable):
        # 0-dim masks are the placeholders described above; skip them.
        assert prev_mask.dim() == 0 or prev_mask.size() == weight.size()

    keep_ratio = 1 - self.sparsity

    def kth_largest(abs_values):
        # ceil (rather than floor/int) so that at least one element is
        # selected whenever the keep ratio is positive.
        keep = math.ceil(abs_values.numel() * keep_ratio)
        return torch.topk(abs_values, keep).values.min().item()

    masks = []
    if self.layerwise_pruning:
        for prev_mask, param in zip(self._masks, trainable):
            magnitudes = param.data.abs()
            # Magnitudes of the weights still retained by the previous mask.
            retained = magnitudes[prev_mask.bool()].flatten()
            if retained.numel() > 0:
                cutoff = kth_largest(retained)
                # 1 only where the weight passes the cutoff AND was retained.
                new_mask = (magnitudes >= cutoff).float() * prev_mask
            else:
                new_mask = param.new_empty(())
            masks.append(new_mask)
    else:
        # One flat vector of every weight retained by the previous masks.
        retained_chunks = [
            param[prev_mask.bool()].flatten()
            for prev_mask, param in zip(self._masks, trainable)
        ]
        retained = torch.cat(retained_chunks)
        # Global cutoff shared by all parameter tensors.
        cutoff = kth_largest(retained.abs())
        for prev_mask, param in zip(self._masks, trainable):
            if param.numel() > 0:
                new_mask = (param.data.abs() >= cutoff).float() * prev_mask
            else:
                new_mask = param.new_empty(())
            masks.append(new_mask)

    if self.accumulate_mask:
        self._masks = masks
    return masks
def train(
    self,
    train_task_iters: Optional[BatchPreparationPipeline],
    eval_task_iters: BatchPreparationPipeline,
    model: Model,
    metric_reporter: MetaLearnMetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """Run one pass over the train tasks, then score the eval tasks.

    For every (support, target) episode the model is put in train mode and
    shown the support set together with its responses; it is then switched
    to eval mode, run on the target inputs, and the target loss is handed
    to ``metric_reporter``.  Finally the model's ``state_dict`` is written
    to ``<train_config.modules_save_dir>/model.pt``.

    NOTE(review): the block nesting was reconstructed from a flattened
    source; confirm that evaluation belongs to the meta-training branch
    and that saving happens unconditionally.

    Args:
        train_task_iters: training task batches; when falsy nothing is
            trained or evaluated and only a warning is logged.
        eval_task_iters: evaluation task batches.
        model: model to adapt; moved to CUDA when CUDA is enabled.
        metric_reporter: receives per-episode stats and stage reports.
        train_config: supplies ``modules_save_dir``.
        rank: not used by this method.

    Returns:
        ``(model, None)``.
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    def run_episode(inputs, targets, contexts, recontextualize):
        # Adapt on the support half, then score the target half.
        s_inputs, t_inputs = inputs
        s_targets, t_targets = targets
        s_context, t_context = contexts
        task = t_context["task_id"][0]

        model.train()
        model.contextualize(s_context)
        model(*s_inputs, responses=s_targets)  # model remembers responses

        model.eval()  # model now retrieves from examples seen so far
        if recontextualize:
            model.contextualize(t_context)
        t_pred = model(*t_inputs)
        t_loss = model.get_loss(t_pred, t_targets, t_context).item()
        metric_reporter.add_batch_stats(
            task, t_loss, s_inputs, t_predictions=t_pred, t_targets=t_targets
        )

    # Start outer loop (meta learner "epochs")
    if train_task_iters:
        for epoch in range(1, 2):  # single epoch
            # zip with range(100) caps the number of task batches at 100.
            for _, (support, target, context) in zip(range(100), train_task_iters):
                for inputs, targets, contexts in zip(support, target, context):
                    run_episode(inputs, targets, contexts, True)
            metric_reporter.report_metric(
                stage=Stage.TRAIN, epoch=epoch, reset=False
            )

        logging.info("Evaluating model on eval tasks")
        with torch.no_grad():
            for support, target, context in eval_task_iters:
                for inputs, targets, contexts in zip(support, target, context):
                    # The eval pass does not re-contextualize on the target.
                    run_episode(inputs, targets, contexts, False)
        metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch, reset=False)
    else:
        LOG.warning("Model does not need meta-training")

    best_model_path = os.path.join(train_config.modules_save_dir, "model.pt")
    torch.save(model.state_dict(), best_model_path)
    return model, None
def train(
    self,
    train_iter: BatchIterator,
    eval_iter: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    optimizers: List[torch.optim.Optimizer],
    scheduler=None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model, the model states will be modified. This function
    iterates epochs specified in config, and for each epoch do:

        1. Train model using training data, aggregate and report training results
        2. Evaluate model using evaluation data
        3. Adjust learning rate if scheduler is specified
        4. Calculate metrics based on evaluation results and select best model

    Args:
        train_iter (BatchIterator): batch iterator of training data
        eval_iter (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file.. etc
        train_config (PyTextConfig): training config
        optimizers (List[torch.optim.Optimizer]): a list of torch optimizers, in
            most of the case only contains one optimizer
        scheduler (Optional[torch.optim.lr_scheduler]): learning rate scheduler,
            default is None
        rank (int): only used in distributed training, the rank of the current
            training thread; module checkpoints are only saved by rank 0

    Returns:
        model, best_metric: the trained model together with the best metric.
        The model carries the weights of the best epoch; when no epoch was
        selected as best (e.g. zero epochs configured) the model is returned
        with its current weights and ``best_metric`` is None.
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()
        if cuda_utils.DISTRIBUTED_WORLD_SIZE > 1:
            device_id = torch.cuda.current_device()
            model = DistributedModel(
                module=model,
                device_ids=[device_id],
                output_device=device_id,
                broadcast_buffers=False,
            )

    best_metric = None
    last_best_epoch = 0
    best_model_state = None
    scheduler = self._prepare_scheduler(train_iter, scheduler)

    def training_pre_batch_callback():
        optimizer_zero_grad(optimizers)

    def training_backprop(loss):
        loss.backward()
        if scheduler:
            scheduler.step_batch()

        if self.config.max_clip_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), self.config.max_clip_norm
            )
        else:
            grad_norm = None

        optimizer_step(optimizers)
        # grad_norm could be used to check grads sync in distributed training
        return grad_norm

    for epoch in range(1, self.config.epochs + 1):
        print(f"Rank {rank} worker: Starting epoch #{epoch}")
        model.train()
        lrs = (str(lr) for lr in learning_rates(optimizers))
        print(f"Learning rate(s): {', '.join(lrs)}")

        self._run_epoch(
            Stage.TRAIN,
            epoch,
            train_iter,
            model,
            metric_reporter,
            pre_batch=training_pre_batch_callback,
            backprop=training_backprop,
            rank=rank,
        )

        model.eval(Stage.EVAL)
        eval_metric = self._run_epoch(
            Stage.EVAL, epoch, eval_iter, model, metric_reporter, rank=rank
        )

        # Step the learning rate scheduler(s)
        if scheduler:
            assert eval_metric is not None
            scheduler.step(
                metrics=metric_reporter.get_model_select_metric(eval_metric),
                epoch=epoch,
            )

        # choose best model.
        if metric_reporter.compare_metric(eval_metric, best_metric):
            print(
                f"Rank {rank} worker: Found a better model! Saving the model state."
            )
            last_best_epoch = epoch
            best_metric = eval_metric
            # Only rank = 0 trainer saves modules.
            if train_config.save_module_checkpoints and rank == 0:
                model.save_modules(
                    base_path=train_config.modules_save_dir, suffix=f"-ep{epoch}"
                )
            # Deep copy: later epochs keep mutating the live state dict.
            best_model_state = copy.deepcopy(model.state_dict())

        if self.config.early_stop_after > 0 and (
            epoch - last_best_epoch == self.config.early_stop_after
        ):
            print(
                f"Rank {rank} worker: Eval metric hasn't changed for "
                + f"{self.config.early_stop_after} epochs. Stopping now."
            )
            break
        sys.stdout.flush()

    # No snapshot exists when the loop never ran or no epoch was ever
    # selected as best; loading None would crash, so keep current weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_metric
def train(
    self,
    text_embedder,
    train_task_iters: Optional[BatchPreparationPipeline],
    eval_task_iters: BatchPreparationPipeline,
    model: Model,
    metric_reporter: MetaLearnMetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """Meta-train a GPT-style model on the train tasks, then score eval tasks.

    Work in progress: for every support instance the inner-loop loss and the
    corresponding "fast weights" are computed, but the fast weights are not
    applied yet and the meta optimizer is never stepped.

    Changes over the previous revision: the unparsable ``losses_q``
    placeholder and the ``ipdb`` breakpoint were removed, and unused
    aliases were dropped.

    NOTE(review): the block nesting was reconstructed from a flattened
    source; confirm that evaluation belongs to the meta-training branch
    and that saving happens unconditionally.

    Args:
        text_embedder: currently unused; kept for interface compatibility
            (its ``decode_ids_as_text`` is handy when debugging inputs).
        train_task_iters: training task batches; when falsy nothing is
            trained or evaluated and only a warning is logged.
        eval_task_iters: evaluation task batches.
        model: model to train; moved to CUDA when CUDA is enabled.
        metric_reporter: receives eval stats and per-stage reports.
        train_config: supplies ``modules_save_dir``.
        rank: not used by this method.

    Returns:
        ``(model, None)``; the model's ``state_dict`` is also written to
        ``<train_config.modules_save_dir>/model.pt``.
    """
    if cuda_utils.CUDA_ENABLED:
        model = model.cuda()

    meta_lr = 0.001    # outer-loop (meta) learning rate
    update_lr = 0.01   # inner-loop (fast weight) learning rate

    from pytorch_transformers import AdamW

    is_gpt2 = model.representation.gptmode == 'gpt2'
    # TODO: meta_optim is created but not stepped anywhere yet.
    if is_gpt2:
        meta_optim = AdamW(model.parameters(), lr=meta_lr)
    else:
        meta_optim = OpenAIAdam(model.parameters(), lr=meta_lr)

    # Start outer loop (meta learner "epochs")
    if not train_task_iters:
        LOG.warning("Model does not need meta-training")
    else:
        logging.info("Training model on train tasks")
        for epoch in range(1, 2):  # single epoch
            # zip with range(100) caps training at 100 task batches.
            for bidx, (support, target, context) in zip(range(100), train_task_iters):
                # len(support) is the number of tasks in this meta-batch.
                print("support.__len__() ", len(support))

                # Notes recorded while debugging the retrieval data layout
                # (support batch size 128) -- verify before relying on them:
                #   s_inputs (6 entries):
                #     [0] (128, 3, 38)  token ids of 3 consecutive sentences
                #     [1] previously (128, 3, 38, 768) embeddings, now None
                #     [2] (128, 2, 37)  the following consecutive sentences
                #     [3] (128,)        number of sentences, e.g. [3, 3, ...]
                #     [4] (128, 3)      sentence lengths for entry [0]
                #     [5] (128, 2)      sentence lengths for entry [2]
                #   s_targets (2 entries):
                #     [0] (128, 2, 34)  target sentences
                #     [1] (128, 2)      target sentence lengths
                #   s_context is a dict with keys 'target_seq_lens',
                #     'orig_text', 'dlg_len', 'dlg_id', 'domain_id',
                #     'task_id' and 'index'; e.g. 'domain_id' holds values
                #     such as 'WEATHER_CHECK' and 'target_seq_lens' is the
                #     target length + 1.
                #   The input tuple is assembled by _train_input_from_batch
                #   in mldc/data/data_handler.py.
                # TODO: the model call below expects each support instance
                # to be (input_ids, mc_token_ids, lm_labels, mc_labels,
                # token_type_ids).
                for enum_i, (
                    (s_inputs, t_inputs),
                    (s_targets, t_targets),
                    (s_context, t_context),
                ) in enumerate(zip(support, target, context)):
                    s_domain = s_context['domain_id'][0]
                    print("b_idx", bidx, "enum_i", enum_i, "s_domain :", s_domain)

                    # Adapt the model using the support set
                    model.train()
                    for support_ins in zip(*s_inputs):
                        if is_gpt2:
                            lm_loss, mc_loss, _, _, _ = model(*support_ins)
                        else:
                            lm_loss, mc_loss = model(*support_ins)
                        loss = lm_loss * 2 + mc_loss * 1

                        grad = torch.autograd.grad(loss, model.parameters())
                        # TODO: fast_weights are not fed back into the model
                        # yet; the target-set (query) step is still missing.
                        fast_weights = [
                            p - update_lr * g
                            for g, p in zip(grad, model.parameters())
                        ]

            metric_reporter.report_metric(
                stage=Stage.TRAIN, epoch=epoch, reset=False
            )

        logging.info("Evaluating model on eval tasks")
        with torch.no_grad():
            for support, target, context in eval_task_iters:
                for (
                    (s_inputs, t_inputs),
                    (s_targets, t_targets),
                    (s_context, t_context),
                ) in zip(support, target, context):
                    task = t_context["task_id"][0]
                    model.train()
                    model.contextualize(s_context)
                    model(*s_inputs, responses=s_targets)  # model remembers responses
                    model.eval()
                    t_pred = model(*t_inputs)
                    t_loss = model.get_loss(t_pred, t_targets, t_context).item()
                    metric_reporter.add_batch_stats(
                        task, t_loss, s_inputs,
                        t_predictions=t_pred, t_targets=t_targets,
                    )
        metric_reporter.report_metric(stage=Stage.EVAL, epoch=epoch, reset=False)

    best_model_path = os.path.join(train_config.modules_save_dir, "model.pt")
    torch.save(model.state_dict(), best_model_path)
    return model, None
def get_sparsifiable_params(self, model: Model):
    """Return the model's trainable parameters as a list.

    A parameter is included when its ``requires_grad`` flag is set; the
    order follows ``model.parameters()``.
    """
    selected = []
    for param in model.parameters():
        if param.requires_grad:
            selected.append(param)
    return selected