def eval_model(self,
               params: spec.ParameterContainer,
               model_state: spec.ModelAuxiliaryState,
               rng: spec.RandomState,
               data_dir: str):
  """Run a full evaluation of the model."""
  params.eval()
  total_error = 0.0
  total_length = 0.0
  with torch.no_grad():
    for (_, features, transcripts, input_lengths) in self._valid_loader:
      features = features.float().to(self._device)
      features = features.transpose(1, 2).unsqueeze(1)
      transcripts = transcripts.long().to(self._device)
      input_lengths = input_lengths.int()
      log_y, _ = params(features, input_lengths, transcripts)
      out, _, _, seq_lens = self._decoder.decode(
          torch.exp(log_y).detach().cpu(), input_lengths)
      for hyp, trn, length in zip(out, transcripts, seq_lens):  # iterate batch
        best_hyp = hyp[0, :length[0]]
        hh = "".join([self._rev_label_dict[i.item()] for i in best_hyp])
        t = trn.detach().cpu().tolist()
        t = [ll for ll in t if ll != 0]
        tlength = len(t)
        tt = "".join([self._rev_label_dict[i] for i in t])
        error = Levenshtein.distance(tt, hh)
        total_error += error
        total_length += tlength
  return total_error / total_length
def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: spec.Tensor,
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  del model_state
  del rng
  del update_batch_norm
  params.train(mode == spec.ForwardPassMode.TRAIN)
  features, transcripts, input_lengths = augmented_and_preprocessed_input_batch
  log_y, output_lengths = params(features, input_lengths, transcripts)
  return (log_y.transpose(0, 1), output_lengths), None
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del workload
  del model_state
  del rng
  base_lr = hyperparameters.learning_rate * get_batch_size('imagenet') / 256.
  optimizer_state = {
      'optimizer':
          torch.optim.SGD(model_params.parameters(),
                          lr=base_lr,
                          momentum=hyperparameters.momentum,
                          weight_decay=hyperparameters.l2)
  }

  scheduler1 = LinearLR(optimizer_state['optimizer'],
                        start_factor=1e-5,
                        end_factor=1.,
                        total_iters=hyperparameters.warmup_epochs)
  cosine_epochs = max(
      hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
  scheduler2 = CosineAnnealingLR(optimizer_state['optimizer'],
                                 T_max=cosine_epochs)

  optimizer_state['scheduler'] = SequentialLR(
      optimizer_state['optimizer'],
      schedulers=[scheduler1, scheduler2],
      milestones=[hyperparameters.warmup_epochs])

  return optimizer_state
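
# Illustrative sketch (not part of the submission): how the composed schedule
# above behaves when stepped once per epoch. The hyperparameter values here
# (warmup_epochs=5, num_epochs=20, learning_rate=0.1) are assumptions chosen
# only to make the example self-contained and runnable.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

_params = [torch.nn.Parameter(torch.zeros(1))]
_opt = torch.optim.SGD(_params, lr=0.1)
_warmup = LinearLR(_opt, start_factor=1e-5, end_factor=1., total_iters=5)
_cosine = CosineAnnealingLR(_opt, T_max=15)
_sched = SequentialLR(_opt, schedulers=[_warmup, _cosine], milestones=[5])
for _epoch in range(20):
  # The LR ramps up linearly for the first 5 epochs, then follows a cosine
  # decay over the remaining 15 epochs.
  print(_epoch, _opt.param_groups[0]['lr'])
  _sched.step()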
def _eval_model_on_split(self,
                         split: str,
                         num_examples: int,
                         global_batch_size: int,
                         params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         rng: spec.RandomState,
                         data_dir: str):
  del model_state
  if split not in self._eval_iters:
    data_loader = self.build_input_queue(rng, split, data_dir,
                                         global_batch_size)
    # Note that this saves the entire dataset split in memory.
    self._eval_iters[split] = itertools.cycle(data_loader)

  num_batches = int(math.ceil(num_examples / global_batch_size))
  params.eval()
  total_error = 0.0
  total_length = 0.0
  with torch.no_grad():
    for (bi, batch) in enumerate(self._eval_iters[split]):
      # Stop after exactly `num_batches` batches; `ceil` above already
      # accounts for a final partial batch of the cycled iterator.
      if bi >= num_batches:
        break
      features = batch['features'].float().to(DEVICE)
      features = features.transpose(1, 2).unsqueeze(1)
      transcripts = batch['transcripts'].long().to(DEVICE)
      input_lengths = batch['input_lengths'].int()
      log_y, _ = params(features, input_lengths, transcripts)
      out, _, _, seq_lens = self._decoder.decode(
          torch.exp(log_y).detach().cpu(), input_lengths)
      for hyp, trn, length in zip(out, transcripts, seq_lens):  # iterate batch
        best_hyp = hyp[0, :length[0]]
        hh = "".join([self._rev_label_dict[i.item()] for i in best_hyp])
        t = trn.detach().cpu().tolist()
        t = [ll for ll in t if ll != 0]
        tlength = len(t)
        tt = "".join([self._rev_label_dict[i] for i in t])
        error = Levenshtein.distance(tt, hh)
        total_error += error
        total_length += tlength
  wer = total_error / total_length
  return {'word_error_rate': wer}
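
# Illustrative sketch (assumption, not from the workload): the metric above is
# a Levenshtein edit distance between the decoded string and the reference,
# accumulated over the split and normalized by the total reference length.
# Per utterance, the arithmetic with the python-Levenshtein package looks like:
import Levenshtein

_ref = "the cat sat"
_hyp = "the cat sad"
_errors = Levenshtein.distance(_ref, _hyp)  # 1 substitution ('t' -> 'd')
_rate = _errors / len(_ref)                 # 1 / 11 ~= 0.09
print(_errors, _rate)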
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    input_batch: spec.Tensor,
    label_batch: spec.Tensor,
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del hyperparameters
  del loss_type
  del eval_results

  input_batch, label_batch = workload.preprocess_for_train(
      input_batch, label_batch, None, None, None)

  current_model = current_param_container
  current_param_container.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
      params=current_model,
      input_batch=input_batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=True)

  loss = workload.loss_fn(label_batch=label_batch,
                          logits_batch=logits_batch).mean()
  loss.backward()
  optimizer_state['optimizer'].step()

  steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
  if (global_step + 1) % steps_per_epoch == 0:
    optimizer_state['scheduler'].step()

  return (optimizer_state, current_param_container, new_model_state)
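
# Illustrative arithmetic (assumptions: the standard 1,281,167-image ImageNet
# train split and a batch size of 128, both chosen only for this example).
# The per-epoch schedulers from `init_optimizer_state` are advanced only at
# epoch boundaries:
#   steps_per_epoch = 1_281_167 // 128 = 10_009
# so `scheduler.step()` fires once every 10_009 update steps, i.e. at global
# steps 10_008, 20_017, 30_026, ... (0-indexed).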
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del workload
  del model_state
  del rng
  optimizer = torch.optim.Adam(model_params.parameters(),
                               hyperparameters.learning_rate)
  return optimizer
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params)."""
  del current_params_types
  del eval_results
  del loss_type
  del hyperparameters

  current_model = current_param_container
  current_param_container.train()
  optimizer = optimizer_state['optimizer']
  optimizer.zero_grad()

  logits, _ = workload.model_fn(
      params=current_model,
      augmented_and_preprocessed_input_batch=batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=False)

  targets = batch['targets']
  weights = torch.where(targets > 0, 1.0, 0.0)
  loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
  loss.backward()

  lr = optimizer_state['scheduler'](global_step).item()
  for g in optimizer.param_groups:
    g['lr'] = lr
  optimizer.step()

  return (optimizer_state, current_param_container, None)
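
# Illustrative sketch (assumption, not part of the submission): the weighting
# above averages the per-token loss only over non-padding targets, where
# padding positions are encoded as 0. With a toy per-token loss tensor:
import torch

_targets = torch.tensor([[5, 7, 0, 0], [3, 0, 0, 0]])       # 0 = padding
_per_token_loss = torch.tensor([[0.2, 0.4, 9.0, 9.0],
                                [0.6, 9.0, 9.0, 9.0]])       # garbage on pads
_weights = torch.where(_targets > 0, 1.0, 0.0)
_loss = (_per_token_loss * _weights).sum() / _weights.sum()  # (0.2+0.4+0.6)/3
print(_loss)  # tensor(0.4000)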
def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  del model_state
  del rng
  del update_batch_norm
  features = augmented_and_preprocessed_input_batch['features']
  transcripts = augmented_and_preprocessed_input_batch['transcripts']
  input_lengths = augmented_and_preprocessed_input_batch['input_lengths']
  features = features.float().to(DEVICE)
  features = features.transpose(1, 2).unsqueeze(1)
  transcripts = transcripts.long().to(DEVICE)
  input_lengths = input_lengths.long().to(DEVICE)
  params.train(mode == spec.ForwardPassMode.TRAIN)
  log_y, output_lengths = params(features, input_lengths, transcripts)
  return (log_y.transpose(0, 1), output_lengths), None
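
# Illustrative sketch (the [batch, time, freq] input layout is an assumption):
# the transpose/unsqueeze above converts a padded spectrogram batch into the
# [batch, channel=1, freq, time] layout expected by a 2-D convolutional
# front end.
import torch

_features = torch.randn(8, 1200, 80)               # [batch, time, freq]
_features = _features.transpose(1, 2).unsqueeze(1)
print(_features.shape)                             # torch.Size([8, 1, 80, 1200])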
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del workload
  del model_state
  del rng
  optimizer_state = {
      'optimizer':
          torch.optim.Adam(
              model_params.parameters(),
              lr=hyperparameters.learning_rate,
              betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98),
              eps=hyperparameters.epsilon)
  }
  optimizer_state['scheduler'] = create_learning_rate_scheduler(
      base_learning_rate=hyperparameters.learning_rate)
  return optimizer_state
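
# Illustrative sketch (assumption): `create_learning_rate_scheduler` is not
# shown here; `update_params` above only requires a callable mapping a global
# step to a learning rate that supports `.item()`. A hypothetical stand-in,
# not the schedule used by the actual submission, is linear warmup followed by
# inverse square-root decay:
import torch

def example_create_learning_rate_scheduler(base_learning_rate: float,
                                           warmup_steps: int = 1000):
  """Hypothetical schedule factory: linear warmup, then rsqrt decay."""

  def schedule(step: int) -> torch.Tensor:
    warmup = min(1.0, (step + 1) / warmup_steps)
    decay = max(step + 1, warmup_steps) ** -0.5 * warmup_steps ** 0.5
    return torch.tensor(base_learning_rate * warmup * decay)

  return schedule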