def run(self, block_name: str = None) -> bool: """Run the evaluation. Returns ------ bool Whether the component should continue running. """ self.model.to(self.device) self.model.eval() with torch.no_grad(): preds, targets = [], [] for batch in self._eval_iterator: pred, target = self.model(*[t.to(self.device) for t in batch]) preds.append(pred.cpu()) targets.append(target.cpu()) preds = torch.cat(preds, dim=0) # type: ignore targets = torch.cat(targets, dim=0) # type: ignore self.eval_metric = self.metric_fn(preds, targets).item() tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" log( f'{tb_prefix}Eval {self.metric_fn}', # type: ignore self.eval_metric, global_step=0) # type: ignore continue_ = False # Single step so don't continue return continue_
def run(self, block_name: str = None) -> bool: """Run the evaluation. Returns ------ bool Whether the component should continue running. """ self.model.to(self.device) self.model.eval() with torch.no_grad(): metric_state: Dict = {} for batch in self._eval_iterator: pred, target = self.model(*[t.to(self.device) for t in batch]) self.metric_fn.aggregate(metric_state, pred, target) self.eval_metric = self.metric_fn.finalize(metric_state) tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" log( f'{tb_prefix}Eval/{self.metric_fn}', # type: ignore self.eval_metric, global_step=0) # type: ignore return False
def _eval_step(self) -> None:
    if self.teacher_translator is not None and self._step == 0:
        self.teacher_translator.initialize(self.teacher)

        iterator = self.train_sampler.sample(self.dataset.train, 1)
        batch_size = self.train_sampler.batch_size
        drop_last = self.train_sampler.drop_last
        shuffle = self.train_sampler.shuffle

        srcs = []
        new_contexts = []
        new_words = []
        for batch in iterator:
            batch = (t.to(self.device) for t in batch)
            src, tgt_context, tgt_words = batch
            with torch.no_grad():
                new_tgt, new_src = self.teacher_translator(src, src)
                new_tgt_context = new_tgt[:, :-1]
                new_tgt_words = new_tgt[:, 1:]
            srcs.append(new_src[:, :-1].cpu())
            new_contexts.append(new_tgt_context.cpu())
            new_words.append(new_tgt_words.cpu())

        srcs = torch.cat(srcs, dim=0)
        new_contexts = torch.cat(new_contexts, dim=0)
        new_words = torch.cat(new_words, dim=0)

        original_data = (srcs, new_contexts, new_words)
        train_sets = [original_data]
        self.train_sampler = TensorSampler(train_sets,
                                           probs=[1.0],
                                           batch_size=batch_size,
                                           drop_last=drop_last,
                                           shuffle=shuffle)
        self._create_train_iterator()

    super()._eval_step()

    tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

    # Log beta
    if not self.use_iter:
        self.beta = self.get_beta(self._step)
    log(f'{tb_prefix}Training/Beta', self.beta, self._step)

@staticmethod
def _log_metrics(log_prefix: str,
                 metrics_with_states: List[Tuple],
                 global_step: int) -> None:
    """Log all provided metrics.

    Iterates through the provided list of metrics with states,
    finalizes each metric, and logs it.

    Parameters
    ----------
    log_prefix: str
        A string, such as a tensorboard prefix
    metrics_with_states: List[Tuple[Metric, Dict]]
        A list of metric-state tuples
    global_step: int
        The global step for logging

    """
    for metric, state in metrics_with_states:
        log(f'{log_prefix}/{metric}', metric.finalize(state), global_step)

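# The (metric, state) pairs used throughout these trainers assume each Metric
# exposes aggregate(state, preds, targets), which updates a running state dict
# in place, and finalize(state), which reduces it to a single number. Below is
# a minimal standalone sketch of that protocol with a hypothetical
# MeanAbsoluteError stand-in (not part of the original codebase):
import torch


class MeanAbsoluteError:
    """Toy metric following the aggregate/finalize state protocol."""

    def aggregate(self, state: dict, preds: torch.Tensor, targets: torch.Tensor) -> dict:
        # Accumulate the running sum of absolute errors and the element count.
        state['abs_error'] = state.get('abs_error', 0.0) + (preds - targets).abs().sum().item()
        state['count'] = state.get('count', 0) + targets.numel()
        return state

    def finalize(self, state: dict) -> float:
        # Reduce the accumulated state to a single scalar.
        return state['abs_error'] / max(state['count'], 1)

    def __str__(self) -> str:
        return 'MAE'


# Usage mirroring the loops above: aggregate per batch, finalize when logging.
metric, state = MeanAbsoluteError(), {}
for _ in range(3):  # three fake batches
    preds, targets = torch.rand(8), torch.rand(8)
    metric.aggregate(state, preds, targets)
print(f'{metric}: {metric.finalize(state):.4f}')
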
def _train_step(self) -> None: """Run a training step over the training data.""" self.model.train() tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" with torch.enable_grad(): for i in range(self.iter_per_step): # Zero the gradients and clear the accumulated loss self.optimizer.zero_grad() accumulated_loss = 0.0 for _ in range(self.batches_per_iter): # Get next batch try: batch = next(self._train_iterator) except StopIteration: self._create_train_iterator() batch = next(self._train_iterator) batch = self._batch_to_device(batch) # Compute loss loss = self._compute_loss(batch) / self.batches_per_iter accumulated_loss += loss.item() loss.backward() # Log loss global_step = (self.iter_per_step * self._step) + i # Clip gradients if necessary if self.max_grad_norm: clip_grad_norm_(self.model.parameters(), self.max_grad_norm) if self.max_grad_abs_val: clip_grad_value_(self.model.parameters(), self.max_grad_abs_val) log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step) log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step) log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step) # Optimize self.optimizer.step() # Update iter scheduler if self.iter_scheduler is not None: learning_rate = self.iter_scheduler.get_lr()[ 0] # type: ignore log(f'{tb_prefix}Training/LR', learning_rate, global_step) self.iter_scheduler.step() # type: ignore # Zero the gradients when exiting a train step self.optimizer.zero_grad()
def _train_step(self) -> None: """Run a training step over the training data.""" self.model.train() tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" with torch.enable_grad(): for i in range(self.iter_per_step): # Zero the gradients and clear the accumulated loss self.optimizer.zero_grad() accumulated_loss = 0.0 for _ in range(self.batches_per_iter): # Get next batch batch = next(self._train_iterator) batch = self._batch_to_device(batch) # Compute loss loss = self._compute_loss(batch) / self.batches_per_iter accumulated_loss += loss.item() loss.backward() # Log loss global_step = (self.iter_per_step * self._step) + i log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step) log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step) log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step) # Optimize self.optimizer.step() # Zero the gradients when exiting a train step self.optimizer.zero_grad()
def _eval_step(self) -> None: """Run an evaluation step over the validation data.""" self.model.eval() metric_fn_state: Dict[Metric, Dict] = {} metrics_with_states: List[Tuple] = \ [(metric, {}) for metric in self.validation_metrics] # Initialize a 1-epoch iteration through the validation set val_iterator = self.val_sampler.sample(self.dataset.val) with torch.no_grad(): loss = [] for batch in val_iterator: _, _, batch_loss = self._compute_batch( batch, [(self.metric_fn, metric_fn_state), *metrics_with_states]) loss.append(batch_loss.item()) val_loss = np.NaN if loss == [] else sum(loss) / len(loss) val_metric = self.metric_fn.finalize(metric_fn_state) # Update best model sign = (-1)**(self.lower_is_better) if self._best_metric is None or (sign * val_metric > sign * self._best_metric): self._best_metric = val_metric best_model_state = self.model.state_dict() for k, t in best_model_state.items(): best_model_state[k] = t.cpu().detach() self._best_model = best_model_state # Update scheduler if self.scheduler is not None: if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(val_loss) else: # torch's _LRScheduler.step DOES have a default value # so passing in no args is fine; it will automatically # compute the current epoch self.scheduler.step() # type: ignore tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" # Log metrics log(f'{tb_prefix}Validation/Loss', val_loss, self._step) log(f'{tb_prefix}Validation/{self.metric_fn}', val_metric, self._step) log(f'{tb_prefix}Best/{self.metric_fn}', self._best_metric, self._step) # type: ignore for (metric, state) in metrics_with_states: log(f'{tb_prefix}Validation/{metric}', metric.finalize(state), self._step) # type: ignore
def _eval_step(self) -> None: """Run an evaluation step over the validation data.""" self.model.eval() # Initialize a 1-epoch iteration through the validation set val_iterator = self.val_sampler.sample(self.dataset.val) with torch.no_grad(): preds, targets = self._aggregate_preds(val_iterator) val_loss = self.loss_fn(preds, targets).item() val_metric = self.metric_fn(preds, targets).item() # Update best model sign = (-1)**(self.lower_is_better) if self._best_metric is None or (sign * val_metric > sign * self._best_metric): self._best_metric = val_metric best_model_state = self.model.state_dict() for k, t in best_model_state.items(): best_model_state[k] = t.cpu().detach() self._best_model = best_model_state # Update scheduler if self.scheduler is not None: if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(val_loss) else: # torch's _LRScheduler.step DOES have a default value # so passing in no args is fine; it will automatically # compute the current epoch self.scheduler.step() # type: ignore tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" # Log metrics log(f'{tb_prefix}Validation/Loss', val_loss, self._step) log(f'{tb_prefix}Validation/{self.metric_fn}', val_metric, self._step) log(f'{tb_prefix}Best/{self.metric_fn}', self._best_metric, self._step) # type: ignore for metric_name, metric in self.extra_validation_metrics.items(): log(f'{tb_prefix}Validation/{metric_name}', metric(preds, targets).item(), self._step) # type: ignore
def _compute_loss(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
    if self.use_iter:
        tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log(f'{tb_prefix}Training/Beta', self.beta, self.global_iter)

    if self.top_k != 'None':
        loss = self._compute_loss_top_k(batch)
        if self.use_iter:
            self.global_iter += 1
            self.beta = self.get_beta(self.global_iter)
        return loss

    if self.encode_during_sampling:
        loss = self._encode_and_sample_rnn(batch)
        if self.use_iter:
            self.global_iter += 1
            self.beta = self.get_beta(self.global_iter)
        return loss

    src, tgt_context, tgt_words = batch

    # Sample from translator
    self.model.eval()
    dist = Categorical(torch.tensor([self.beta, 1 - self.beta]))
    samp_mask = (dist.sample((src.size(0),)) == 1)
    if torch.sum(samp_mask).item() > 0:
        samp_src = src[samp_mask]
        with torch.no_grad():
            samp_tgt_context = self.translator(samp_src)
            samp_tgt_logits = self.teacher(samp_src, samp_tgt_context)
            _, samp_tgt_words = samp_tgt_logits.max(dim=-1)
            eos_mask = (samp_tgt_context == self.translator.tgt_eos_idx)
            eos_mask |= (samp_tgt_context == self.translator.tgt_pad_idx)
            samp_tgt_words[eos_mask] = self.translator.tgt_pad_idx

        # Merge original and sampled data
        orig_src = src[~samp_mask]
        orig_tgt_context = tgt_context[~samp_mask]
        orig_tgt_words = tgt_words[~samp_mask]

        # Add padding if necessary
        tgt_pad_idx = torch.tensor(self.translator.tgt_pad_idx)
        diff = samp_tgt_context.size(1) - orig_tgt_context.size(1)
        if diff > 0:
            extra = tgt_pad_idx.repeat(
                (orig_tgt_context.size(0), diff)).to(self.device)
            orig_tgt_context = torch.cat([orig_tgt_context, extra], dim=1)
        diff = samp_tgt_words.size(1) - orig_tgt_words.size(1)
        if diff > 0:
            extra = tgt_pad_idx.repeat(
                (orig_tgt_words.size(0), diff)).to(self.device)
            orig_tgt_words = torch.cat([orig_tgt_words, extra], dim=1)

        new_src = torch.cat([orig_src, samp_src], dim=0)
        new_tgt_context = torch.cat([orig_tgt_context, samp_tgt_context], dim=0)
        new_tgt_words = torch.cat([orig_tgt_words, samp_tgt_words], dim=0)
    else:
        new_src = src
        new_tgt_context = tgt_context
        new_tgt_words = tgt_words

    # Train model on merged data
    self.model.train()
    pred, target = self.model(new_src, new_tgt_context, new_tgt_words)
    loss = self.loss_fn(pred, target)

    if self.use_iter:
        self.global_iter += 1
        self.beta = self.get_beta(self.global_iter)
    return loss

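# The per-example sampling mask above draws, for each source sequence, a
# choice between keeping the ground-truth prefix (probability beta) and
# replacing it with model samples (probability 1 - beta) via a two-way
# Categorical. A tiny standalone sketch of that masking (toy values, not
# tied to the trainer):
import torch
from torch.distributions import Categorical

beta, batch_size = 0.7, 10
dist = Categorical(torch.tensor([beta, 1 - beta]))
samp_mask = dist.sample((batch_size,)) == 1   # True => replace with model samples
print(samp_mask)                              # roughly 30% of entries are True
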
def _train_step(self) -> None: """Run a training step over the training data.""" self.model.train() metrics_with_states: List[Tuple] = [ (metric, {}) for metric in self.training_metrics ] self._last_train_log_step = 0 log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" log_prefix += 'Training' with torch.enable_grad(): for i in range(self.iter_per_step): # print('MEMORY ALLOCATED %f' % float(torch.cuda.memory_allocated() / BYTES_IN_GB)) # print('MEMORY CACHED %f' % float(torch.cuda.memory_cached() / BYTES_IN_GB)) t = time.time() # Zero the gradients and clear the accumulated loss self.optimizer.zero_grad() accumulated_loss = 0.0 for _ in range(self.batches_per_iter): # Get next batch try: batch = next(self._train_iterator) except StopIteration: self._create_train_iterator() batch = next(self._train_iterator) if self.device_count == 1: batch = self._batch_to_device(batch) # Compute loss if self.top_k == 'None': _, _, loss = self._compute_batch( batch, metrics_with_states) else: loss = self._compute_kl_loss(batch) print('LOSS') print(loss) accumulated_loss += loss.item() / self.batches_per_iter loss.backward() # try: # loss.backward() # except RuntimeError: # torch.cuda.empty_cache() # print('EMPTIED CACHE FOR LOSS') # continue # Log loss global_step = (self.iter_per_step * self._step) + i self.beta = self.get_beta(global_step) # Clip gradients if necessary if self.max_grad_norm: clip_grad_norm_(self.model.parameters(), self.max_grad_norm) if self.max_grad_abs_val: clip_grad_value_(self.model.parameters(), self.max_grad_abs_val) log(f'{log_prefix}/Loss', accumulated_loss, global_step) if self.device_count > 1: log(f'{log_prefix}/Gradient_Norm', self.model.module.gradient_norm, global_step) log(f'{log_prefix}/Parameter_Norm', self.model.module.parameter_norm, global_step) else: log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm, global_step) log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm, global_step) log(f'{log_prefix}/Beta', self.beta, global_step) # Optimize self.optimizer.step() # Update iter scheduler if self.iter_scheduler is not None: lr = self.optimizer.param_groups[0]['lr'] # type: ignore log(f'{log_prefix}/LR', lr, global_step) self.iter_scheduler.step() # type: ignore # Zero the gradients when exiting a train step self.optimizer.zero_grad() # logging train metrics if self.extra_training_metrics_log_interval > self._last_train_log_step: self._log_metrics(log_prefix, metrics_with_states, global_step) self._last_train_log_step = i print('TOTAL TIME: %f' % (time.time() - t)) if self._last_train_log_step != i: # log again at end of step, if not logged at the end of # step before self._log_metrics(log_prefix, metrics_with_states, global_step)
def _train_step(self) -> None: """Run a training step over the training data.""" self.model.train() metrics_with_states: List[Tuple] = [ (metric, {}) for metric in self.training_metrics ] self._last_train_log_step = 0 log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" log_prefix += 'Training' with torch.enable_grad(): for i in range(self.iter_per_step): # Zero the gradients and clear the accumulated loss self.optimizer.zero_grad() accumulated_loss = 0.0 for _ in range(self.batches_per_iter): # Get next batch try: batch = next(self._train_iterator) except StopIteration: self._create_train_iterator() batch = next(self._train_iterator) batch = self._batch_to_device(batch) # Compute loss _, _, loss = self._compute_batch(batch, metrics_with_states) accumulated_loss += loss.item() / self.batches_per_iter loss.backward() # Log loss global_step = (self.iter_per_step * self._step) + i # Clip gradients if necessary if self.max_grad_norm: clip_grad_norm_(self.model.parameters(), self.max_grad_norm) if self.max_grad_abs_val: clip_grad_value_(self.model.parameters(), self.max_grad_abs_val) log(f'{log_prefix}/Loss', accumulated_loss, global_step) log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm, global_step) log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm, global_step) # Optimize self.optimizer.step() # Update iter scheduler if self.iter_scheduler is not None: learning_rate = self.iter_scheduler.get_lr()[ 0] # type: ignore log(f'{log_prefix}/LR', learning_rate, global_step) self.iter_scheduler.step() # type: ignore # Zero the gradients when exiting a train step self.optimizer.zero_grad() # logging train metrics if self.extra_training_metrics_log_interval > self._last_train_log_step: self._log_metrics(log_prefix, metrics_with_states, global_step) self._last_train_log_step = i if self._last_train_log_step != i: # log again at end of step, if not logged at the end of # step before self._log_metrics(log_prefix, metrics_with_states, global_step)
def sample(self,
           data: Sequence[Sequence[torch.Tensor]],
           n_epochs: int = 1) -> Iterator[Tuple[torch.Tensor, ...]]:
    if self.translator is None or self.teacher is None:
        raise ValueError(
            'Cannot sample because one of student or teacher is missing!')

    samp_per_epoch = math.ceil(self.length(data) / self.sample_factor)
    N = n_epochs * samp_per_epoch
    beta = 1.
    if self.scheduler_type == 'linear':
        get_beta = lambda t: 1. - t / N
    elif self.scheduler_type == 'exponential':
        get_beta = lambda t: 200 ** (-t / N)
    elif self.scheduler_type == 'reverse_sigmoid':
        get_beta = lambda t: 1 / (1 + np.exp((t / N - 0.5) * 20))
    elif self.scheduler_type == 'ones':
        get_beta = lambda t: 1.
    elif self.scheduler_type == 'zeros':
        get_beta = lambda t: 0.
    else:
        raise ValueError('Not implemented!')

    if len(data) == 0:
        raise ValueError("No examples provided")

    collate_fn_p = partial(collate_fn, pad=self.pad)

    def collate_with_filter(batch):
        batch = filter(
            lambda lst: all([len(x) <= self.max_seq_len for x in lst]), batch)
        return collate_fn_p(batch)

    sample_batch_size = self.batch_size * self.sample_factor
    loader = DataLoader(dataset=data,  # type: ignore
                        shuffle=self.shuffle,
                        batch_size=sample_batch_size,
                        collate_fn=collate_with_filter,
                        num_workers=self.n_workers,
                        pin_memory=self.pin_memory,
                        drop_last=self.drop_last)

    for epoch in range(n_epochs):
        for samp_count, batch in enumerate(loader):
            batch = [x.clone() for x in batch]
            src, tgt_context, tgt_words = batch
            max_seq_len = self.translator.max_seq_len
            pad_idx = self.translator.tgt_pad_idx
            tgt_context = pad_to_len(tgt_context, max_seq_len, pad_idx, dim=1)
            tgt_words = pad_to_len(tgt_words, max_seq_len, pad_idx, dim=1)

            beta = get_beta(epoch * samp_per_epoch + samp_count)
            dist = Categorical(torch.tensor([beta, 1 - beta]))
            samp_mask = (dist.sample((src.size(0),)) == 1)
            if torch.sum(samp_mask).item() > 0:
                samp_src = src[samp_mask].to(self.device)
                self.translator.model.eval()
                with torch.no_grad():
                    samp_tgt_context = self.translator(samp_src)
                    samp_tgt_logits = self.teacher(samp_src, samp_tgt_context)
                    _, samp_tgt_words = samp_tgt_logits.max(dim=-1)

                # Pad words correctly
                eos_idx = self.translator.tgt_eos_idx
                pad_idx = self.translator.tgt_pad_idx
                eos_mask = (samp_tgt_context == eos_idx)
                eos_mask |= (samp_tgt_context == pad_idx)
                samp_tgt_words[eos_mask] = pad_idx

                samp_src = samp_src.cpu()
                samp_tgt_context = samp_tgt_context.cpu()
                samp_tgt_words = samp_tgt_words.cpu()

                tgt_context[samp_mask] = samp_tgt_context
                tgt_words[samp_mask] = samp_tgt_words

                max_len = (tgt_words != pad_idx).sum(dim=1).max().item()
                tgt_context = tgt_context[:, :max_len]
                tgt_words = tgt_words[:, :max_len]

                self.translator.model.train()

            B = self.batch_size
            for i in range(self.sample_factor):
                step = self.sample_factor * (epoch * samp_per_epoch + samp_count) + i
                log('Training/Beta', beta, step)
                src_slice = src[i * B:(i + 1) * B]
                tgt_context_slice = tgt_context[i * B:(i + 1) * B]
                tgt_words_slice = tgt_words[i * B:(i + 1) * B]
                yield src_slice, tgt_context_slice, tgt_words_slice

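# The scheduled-sampling schedules above decay beta (the probability of
# keeping the ground-truth prefix) from 1 toward 0 over N sampling steps.
# A standalone sketch of the same formulas, useful for eyeballing the decay.
# The constants 200, 20, and 0.5 come from the code above; N and the printed
# steps are illustrative:
import numpy as np

def make_beta_schedule(scheduler_type: str, N: int):
    if scheduler_type == 'linear':
        return lambda t: 1. - t / N
    elif scheduler_type == 'exponential':
        return lambda t: 200 ** (-t / N)
    elif scheduler_type == 'reverse_sigmoid':
        return lambda t: 1 / (1 + np.exp((t / N - 0.5) * 20))
    elif scheduler_type == 'ones':
        return lambda t: 1.
    elif scheduler_type == 'zeros':
        return lambda t: 0.
    raise ValueError(f'Unknown scheduler_type: {scheduler_type}')

N = 1000
for name in ('linear', 'exponential', 'reverse_sigmoid'):
    get_beta = make_beta_schedule(name, N)
    print(name, [round(float(get_beta(t)), 3) for t in (0, N // 2, N)])
# linear:          1.0 -> 0.5   -> 0.0
# exponential:     1.0 -> 0.071 -> 0.005
# reverse_sigmoid: 1.0 -> 0.5   -> 0.0
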
def _eval_step(self) -> None:
    super()._eval_step()
    log_masks(self.model, self.hard_concrete_masks, self._step)

    num_parameters = get_num_params(self.hard_concrete_modules, train=False)
    num_non_prunable = self.init_num_params - self.max_prunable
    total_num_params = int(num_parameters) + num_non_prunable
    relative_sparsity = 1. - (num_parameters / self.max_prunable)

    log("Num_Params", int(num_parameters), self._step)
    log("Relative_sparsity", float(relative_sparsity), self._step)
    log("True_sparsity", 1. - total_num_params / self.init_num_params, self._step)
    log("Total_num_params", total_num_params, self._step)
    log('LambdaLR', self.optimizer_lambdas.param_groups[0]['lr'], self._step)  # type: ignore
    log('AlphaLR', self.optimizer_alphas.param_groups[0]['lr'], self._step)  # type: ignore

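# Relative vs. true sparsity as logged above: relative sparsity counts only
# the prunable (hard-concrete-gated) parameters, while true sparsity divides
# the remaining total (alive prunable + untouched non-prunable) by the
# original parameter count. A worked example with made-up numbers:
init_num_params = 60_000_000   # all parameters at initialization (illustrative)
max_prunable = 45_000_000      # parameters covered by hard-concrete gates
num_parameters = 18_000_000    # prunable parameters still alive after pruning

num_non_prunable = init_num_params - max_prunable        # 15,000,000
total_num_params = num_parameters + num_non_prunable     # 33,000,000
relative_sparsity = 1. - num_parameters / max_prunable   # 0.6
true_sparsity = 1. - total_num_params / init_num_params  # 0.45

print(relative_sparsity, true_sparsity)
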
def __init__(self,
             dataset: Dataset,
             train_sampler: Sampler,
             val_sampler: Sampler,
             model: Module,
             loss_fn: Metric,
             metric_fn: Metric,
             optimizer,
             scheduler=None,
             device: Optional[str] = None,
             max_steps: int = 10,
             epoch_per_step: float = 1.0,
             iter_per_step: Optional[int] = None,
             batches_per_iter: int = 1,
             lower_is_better: bool = False,
             max_grad_norm: Optional[float] = None,
             max_grad_abs_val: Optional[float] = None,
             extra_validation_metrics: Optional[List[Metric]] = None,
             lr_warmup: int = 100,
             model_dim: int = 512,
             iter_before_pruning: int = 0,
             init_mean: float = 0.5,
             init_std: float = 0.01,
             alphas_lr: float = 0.001,
             lambdas_lr: float = 1.0,
             target_sparsity: float = 0.8,
             target_sparsity_warmup: int = 80000,
             weight_decay: float = 0) -> None:
    """Initialize the Trainer.

    Parameters
    ----------
    dataset: Dataset
        The dataset containing the first N columns of data for the
        student model, and the last N columns for the target.
    train_sampler : Sampler
        The sampler to use over training examples
    val_sampler : Sampler
        The sampler to use over validation examples
    model : Module
        The model to train
    optimizer : torch.optim.Optimizer
        The optimizer to use
    scheduler : torch.optim.lr_scheduler._LRScheduler, optional
        An optional learning rate scheduler
    device: str, optional
        The device to use in the computation. Only used by compile.
    max_steps : int, optional
        The maximum number of training steps to run
    epoch_per_step : float, optional
        Fraction of an epoch to perform in a single training step
        (i.e. before a checkpoint). Defaults to 1.
        Overridden by `iter_per_step`, if given.
    iter_per_step : int, optional
        Number of iterations to perform in a single training step.
        Overrides `epoch_per_step` if given.
    batches_per_iter : int, optional
        Number of batches to pass through the model before calling
        optimizer.step. Requires the sampler to have drop_last set
        to True. (default set to 1 so optimizer.step is called after
        every batch)
    lower_is_better : bool, optional
        If true, the lowest dev metric is considered best,
        otherwise the highest. Defaults to False.
    max_grad_norm : float, optional
        Maximum Euclidean norm of gradient after clipping.
    max_grad_abs_val: float, optional
        Maximum absolute value of all gradient vector components
        after clipping.
    extra_validation_metrics: Optional[List[Metric]]
        A list with extra metrics to show in each step but which
        don't guide the training procedures (i.e. model selection
        through early stopping)

    """
    super().__init__(dataset, train_sampler,  # type: ignore
                     val_sampler,
                     model,
                     loss_fn,
                     metric_fn,
                     optimizer,
                     scheduler,
                     device,
                     max_steps,
                     epoch_per_step,
                     iter_per_step,
                     batches_per_iter,
                     lower_is_better,
                     max_grad_norm,
                     max_grad_abs_val,
                     extra_validation_metrics)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    self.lambda_1 = nn.Parameter(torch.tensor(0.).to(device))  # type: ignore
    self.lambda_2 = nn.Parameter(torch.tensor(0.).to(device))  # type: ignore
    self.optimizer_lambdas = Adam([self.lambda_1, self.lambda_2], weight_decay=0)
    self.optimizer_lambdas.param_groups[0]['lr'] = -lambdas_lr  # type: ignore
    self.lambdas_scheduler = NoamScheduler(self.optimizer_lambdas,
                                           warmup=lr_warmup,
                                           d_model=model_dim)

    self.iter_before_pruning = iter_before_pruning
    self.init_num_params = sum([len(p.view(-1)) for p in self.model.parameters()])

    make_hard_concrete(self.model, in_place=True, init_mean=init_mean, init_std=init_std)
    self.hard_concrete_modules = get_hardconcrete_proj_linear_modules(self.model)
    self.hard_concrete_masks = get_hardconcrete_modules(self.model)
    self.max_prunable = get_num_prunable_params(self.hard_concrete_modules)
    self.target_sparsity = max(min(target_sparsity, 1.0), 0.0)
    self.target_sparsity_warmup = target_sparsity_warmup

    self.model.to(device)

    model_params = (p for n, p in self.model.named_parameters() if 'log_alpha' not in n)
    alpha_params = (p for n, p in self.model.named_parameters() if 'log_alpha' in n)
    self.optimizer_alphas = Adam(alpha_params, lr=alphas_lr)  # type: ignore
    self.alphas_scheduler = NoamScheduler(self.optimizer_alphas,
                                          warmup=lr_warmup,
                                          d_model=model_dim)
    self.optimizer = Adam(model_params,
                          lr=self.optimizer.param_groups[0]['lr'],  # type: ignore
                          weight_decay=weight_decay)
    self.lr_scheduler = NoamScheduler(self.optimizer,
                                      warmup=lr_warmup,
                                      d_model=model_dim)

    log('Total_params', int(self.init_num_params), 0)
    log('Max_prunable', int(self.max_prunable), 0)

def _train_step(self) -> None: """Run a training step over the training data.""" self.model.train() tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else "" with torch.enable_grad(): for i in range(self.iter_per_step): # Zero the gradients and clear the accumulated loss self.optimizer.zero_grad() self.optimizer_alphas.zero_grad() self.optimizer_lambdas.zero_grad() accumulated_loss = 0.0 for _ in range(self.batches_per_iter): # Get next batch batch = next(self._train_iterator) batch = self._batch_to_device(batch) # Compute loss loss = self._compute_loss(batch) / self.batches_per_iter accumulated_loss += loss.item() loss.backward() # Log loss global_step = (self.iter_per_step * self._step) + i # Clip gradients if necessary if self.max_grad_norm: clip_grad_norm_(self.model.parameters(), self.max_grad_norm) if self.max_grad_abs_val: clip_grad_value_(self.model.parameters(), self.max_grad_abs_val) log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step) log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step) log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step) if global_step >= self.iter_before_pruning: pruning_step = global_step - self.iter_before_pruning num_parameters = get_num_params(self.hard_concrete_modules, train=True) expected_sparsity = 1. - (num_parameters / self.max_prunable) if self.target_sparsity_warmup > 0: factor = min(1.0, pruning_step / self.target_sparsity_warmup) target_sparsity = self.target_sparsity * factor else: target_sparsity = self.target_sparsity lagrangian_loss = self.lambda_1 * (target_sparsity - expected_sparsity) lagrangian_loss += self.lambda_2 * (target_sparsity - expected_sparsity) ** 2 lagrangian_loss.backward() log("Expected_sparsity", float(expected_sparsity), global_step) log("Lagrangian_loss", lagrangian_loss.item(), global_step) log("Target_sparsity", target_sparsity, global_step) log("lambda_1", self.lambda_1.item(), global_step) log("lambda_2", self.lambda_2.item(), global_step) self.optimizer_lambdas.step() self.lambdas_scheduler.step(pruning_step) self.optimizer_alphas.step() self.alphas_scheduler.step(pruning_step) # Optimize self.optimizer.step() self.lr_scheduler.step(global_step) # Zero the gradients when exiting a train step self.optimizer.zero_grad() self.optimizer_lambdas.zero_grad() self.optimizer_alphas.zero_grad()
def run(self, block_name: str = None) -> bool:
    """Run the evaluation.

    Returns
    -------
    bool
        Whether the component should continue running.

    """
    self.model.to(self.device)
    self.model.eval()
    if self.teacher is not None:
        self.teacher.to(self.device)
        self.teacher.eval()

    with torch.no_grad():
        preds, targets = [], []
        if self.save_preds:
            sources = []
            if self.teacher is not None:
                teacher_preds = []

        for batch in self._eval_iterator:
            batch = [t.to(self.device) for t in batch]
            pred = self.model(batch[0], gen_style=self.gen_style)
            if self.save_preds:
                source = batch[0]
                if source.size(1) < 50:
                    extra = torch.zeros(
                        (source.size(0), 50 - source.size(1))).long().to(self.device)
                    source = torch.cat([source, extra], dim=1)
                sources.append(source)
                if self.teacher is not None:
                    student_context = pred[:, :-1].to(self.device)
                    _, teacher_pred = self.teacher(
                        batch[0], student_context).max(dim=2)
                    teacher_pred[student_context == self.model.eos_idx] = self.model.pad_idx
                    teacher_pred[student_context == self.model.pad_idx] = self.model.pad_idx
                    teacher_preds.append(teacher_pred.cpu())

            target = batch[1]
            if target.size(1) < 50:
                extra = torch.zeros(
                    (target.size(0), 50 - target.size(1))).long().to(self.device)
                target = torch.cat([target, extra], dim=1)
            preds.append(pred.cpu())
            targets.append(target.cpu())

        preds = torch.cat(preds, dim=0)  # type: ignore
        if self.save_preds:
            sources = torch.cat(sources, dim=0)
            if self.teacher is not None:
                teacher_preds = torch.cat(teacher_preds, dim=0)
                self.decode_data = (sources, preds[:, :-1], teacher_preds)
            else:
                self.decode_data = (sources, preds[:, :-1], preds[:, 1:])

        targets = torch.cat(targets, dim=0)  # type: ignore
        if self.save_targets:
            self.targets = targets
        self.eval_metric = self.metric_fn(preds, targets).item()

    tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

    log(f'{tb_prefix}Eval {self.metric_fn}',  # type: ignore
        self.eval_metric,
        global_step=0)  # type: ignore

    return False