Example #1
    def eval_model(self, params: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   rng: spec.RandomState, data_dir: str):
        """Run a full evaluation of the model."""

        params.eval()
        total_error = 0.0
        total_length = 0.0
        with torch.no_grad():
            for (_, features, transcripts,
                 input_lengths) in self._valid_loader:
                features = features.float().to(self._device)
                features = features.transpose(1, 2).unsqueeze(1)
                transcripts = transcripts.long().to(self._device)
                input_lengths = input_lengths.int()

                log_y, _ = params(features, input_lengths, transcripts)

                out, _, _, seq_lens = self._decoder.decode(
                    torch.exp(log_y).detach().cpu(), input_lengths)
                for hyp, trn, length in zip(out, transcripts,
                                            seq_lens):  # iterate over the batch
                    # Best beam hypothesis, truncated to its decoded length.
                    best_hyp = hyp[0, :length[0]]
                    hh = "".join(
                        [self._rev_label_dict[i.item()] for i in best_hyp])
                    # Reference transcript with the padding index (0) stripped.
                    t = trn.detach().cpu().tolist()
                    t = [ll for ll in t if ll != 0]
                    tlength = len(t)
                    tt = "".join([self._rev_label_dict[i] for i in t])
                    # Accumulate character-level edit distance and ref length.
                    error = Levenshtein.distance(tt, hh)
                    total_error += error
                    total_length += tlength

        return total_error / total_length
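The value returned above is the total character-level Levenshtein distance normalized by total reference length. A toy illustration of that metric, assuming the python-Levenshtein package (the strings are made up):

import Levenshtein

ref, hyp = "hello world", "helo wrld"        # toy reference / hypothesis
errors = Levenshtein.distance(ref, hyp)      # 2 edits ('l' and 'o' dropped)
print(errors / len(ref))                     # 2 / 11 ~ 0.18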
Example #2
    def model_fn(
        self, params: spec.ParameterContainer,
        augmented_and_preprocessed_input_batch: spec.Tensor,
        model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode,
        rng: spec.RandomState, update_batch_norm: bool
    ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
        del model_state
        del rng
        del update_batch_norm

        params.train(mode == spec.ForwardPassMode.TRAIN)
        features, transcripts, input_lengths = augmented_and_preprocessed_input_batch
        log_y, output_lengths = params(features, input_lengths, transcripts)

        return (log_y.transpose(0, 1), output_lengths), None
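The final transpose puts the log-probabilities into the (time, batch, vocab) layout that PyTorch's CTC loss expects. A minimal, self-contained sketch of that consumer (the shapes and vocabulary size are assumptions, not taken from the workload):

import torch
import torch.nn.functional as F

log_probs = torch.randn(50, 4, 29).log_softmax(-1)     # (time, batch, vocab)
targets = torch.randint(1, 29, (4, 20))                # batch of label ids
input_lengths = torch.full((4,), 50, dtype=torch.long)
target_lengths = torch.full((4,), 20, dtype=torch.long)
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)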
Example #3
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
    del workload
    del model_state
    del rng

    # Linear scaling rule: scale the base learning rate with the batch size.
    base_lr = hyperparameters.learning_rate * get_batch_size('imagenet') / 256.
    optimizer_state = {
        'optimizer':
        torch.optim.SGD(model_params.parameters(),
                        lr=base_lr,
                        momentum=hyperparameters.momentum,
                        weight_decay=hyperparameters.l2)
    }

    scheduler1 = LinearLR(optimizer_state['optimizer'],
                          start_factor=1e-5,
                          end_factor=1.,
                          total_iters=hyperparameters.warmup_epochs)
    cosine_epochs = max(
        hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
    scheduler2 = CosineAnnealingLR(optimizer_state['optimizer'],
                                   T_max=cosine_epochs)

    optimizer_state['scheduler'] = SequentialLR(
        optimizer_state['optimizer'],
        schedulers=[scheduler1, scheduler2],
        milestones=[hyperparameters.warmup_epochs])

    return optimizer_state
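The warmup-then-cosine schedule above can be exercised in isolation; a minimal sketch on a dummy parameter, with hypothetical epoch counts standing in for the hyperparameters:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
warmup = LinearLR(opt, start_factor=1e-5, end_factor=1., total_iters=5)
cosine = CosineAnnealingLR(opt, T_max=95)
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[5])
for epoch in range(100):
    opt.step()
    sched.step()   # stepped once per epoch, as in update_params (Example #5)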
Example #4
  def _eval_model_on_split(self,
                           split: str,
                           num_examples: int,
                           global_batch_size: int,
                           params: spec.ParameterContainer,
                           model_state: spec.ModelAuxiliaryState,
                           rng: spec.RandomState,
                           data_dir: str):
    del model_state
    if split not in self._eval_iters:
      data_loader = self.build_input_queue(rng,
                                           split,
                                           data_dir,
                                           global_batch_size)
      # Note that this saves the entire dataset split in memory.
      self._eval_iters[split] = itertools.cycle(data_loader)
    num_batches = int(math.ceil(num_examples / global_batch_size))
    params.eval()
    total_error = 0.0
    total_length = 0.0
    with torch.no_grad():
      for (bi, batch) in enumerate(self._eval_iters[split]):
        # ceil() above already covers the final partial batch, so stop after
        # exactly num_batches draws from the cycled iterator.
        if bi >= num_batches:
          break
        features = batch['features'].float().to(DEVICE)
        features = features.transpose(1, 2).unsqueeze(1)
        transcripts = batch['transcripts'].long().to(DEVICE)
        input_lengths = batch['input_lengths'].int()

        log_y, _ = params(features, input_lengths, transcripts)

        out, _, _, seq_lens = self._decoder.decode(
            torch.exp(log_y).detach().cpu(), input_lengths)
        for hyp, trn, length in zip(out, transcripts,
                                    seq_lens):  # iterate over the batch
          best_hyp = hyp[0, :length[0]]
          hh = "".join([self._rev_label_dict[i.item()] for i in best_hyp])
          t = trn.detach().cpu().tolist()
          t = [ll for ll in t if ll != 0]
          tlength = len(t)
          tt = "".join([self._rev_label_dict[i] for i in t])
          error = Levenshtein.distance(tt, hh)
          total_error += error
          total_length += tlength

    wer = total_error / total_length
    return {'word_error_rate': wer}
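Because the eval loop draws from an itertools.cycle over the loader, it must stop itself after num_batches = ceil(num_examples / batch_size) batches. A toy sketch of that pattern (the numbers are made up):

import itertools
import math

num_examples, global_batch_size = 10, 4
num_batches = math.ceil(num_examples / global_batch_size)   # -> 3
loader = itertools.cycle([{'batch': i} for i in range(3)])
for bi, batch in enumerate(loader):
    if bi >= num_batches:
        break                     # without this, the cycle never terminates
    print(batch)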
Example #5
def update_params(
        workload: spec.Workload,
        current_param_container: spec.ParameterContainer,
        current_params_types: spec.ParameterTypeTree,
        model_state: spec.ModelAuxiliaryState,
        hyperparameters: spec.Hyperparameters,
        input_batch: spec.Tensor,
        label_batch: spec.Tensor,
        # This will define the output activation via `output_activation_fn`.
        loss_type: spec.LossType,
        optimizer_state: spec.OptimizerState,
        eval_results: List[Tuple[int, float]],
        global_step: int,
        rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del hyperparameters
    del loss_type
    del eval_results
    input_batch, label_batch = (workload.preprocess_for_train(
        input_batch, label_batch, None, None, None))

    current_model = current_param_container
    current_param_container.train()
    optimizer_state['optimizer'].zero_grad()

    logits_batch, new_model_state = workload.model_fn(
        params=current_model,
        input_batch=input_batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=True)

    loss = workload.loss_fn(label_batch=label_batch,
                            logits_batch=logits_batch).mean()

    loss.backward()
    optimizer_state['optimizer'].step()

    steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
    if (global_step + 1) % steps_per_epoch == 0:
        optimizer_state['scheduler'].step()

    return (optimizer_state, current_param_container, new_model_state)
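Here the scheduler is stepped once per epoch even though update_params runs every step; the epoch-boundary test can be sanity-checked in isolation (the numbers below are hypothetical):

num_train_examples = 1_281_167     # assumed ImageNet-sized train split
batch_size = 256                   # hypothetical batch size
steps_per_epoch = num_train_examples // batch_size
epoch_ends = [s for s in range(2 * steps_per_epoch)
              if (s + 1) % steps_per_epoch == 0]
assert len(epoch_ends) == 2        # scheduler.step() fires once per epoch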
Example #6
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
    del workload
    del model_state
    del rng

    optimizer = torch.optim.Adam(model_params.parameters(),
                                 hyperparameters.learning_rate)
    return optimizer
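Note that the learning rate is passed positionally here: torch.optim.Adam's second positional parameter is lr, so the call is equivalent to the keyword form (dummy parameter and hypothetical value below):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)   # same as Adam(params, 1e-3)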
Example #7
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor], loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]], global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del eval_results
    del loss_type
    del hyperparameters

    current_model = current_param_container
    current_param_container.train()
    optimizer = optimizer_state['optimizer']
    optimizer.zero_grad()

    logits, _ = workload.model_fn(params=current_model,
                                  augmented_and_preprocessed_input_batch=batch,
                                  model_state=model_state,
                                  mode=spec.ForwardPassMode.TRAIN,
                                  rng=rng,
                                  update_batch_norm=False)

    targets = batch['targets']
    # Mask out padding positions (target id 0) so they don't affect the loss.
    weights = torch.where(targets > 0, 1.0, 0.0)
    loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
    loss.backward()

    lr = optimizer_state['scheduler'](global_step).item()
    for g in optimizer.param_groups:
        g['lr'] = lr
    optimizer.step()

    return (optimizer_state, current_param_container, None)
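Unlike Example #3, the schedule here is a plain callable mapping global_step to a learning rate, which is then written into the optimizer's param groups by hand. A minimal sketch of one such callable (an inverse-square-root schedule; this is an assumption for illustration, not the repo's scheduler):

import torch

def make_inverse_sqrt_schedule(base_lr, warmup_steps=1000):
    def schedule(step):
        step = max(step, 1)
        # Linear warmup up to warmup_steps, then 1/sqrt(step) decay.
        scale = min(step / warmup_steps, (warmup_steps / step) ** 0.5)
        return torch.tensor(base_lr * scale)   # .item() works, as used above
    return schedule

scheduler = make_inverse_sqrt_schedule(base_lr=2.0)
lr = scheduler(500).item()                      # still in the warmup region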
Example #8
  def model_fn(
      self,
      params: spec.ParameterContainer,
      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
      model_state: spec.ModelAuxiliaryState,
      mode: spec.ForwardPassMode,
      rng: spec.RandomState,
      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    del model_state
    del rng
    del update_batch_norm
    features = augmented_and_preprocessed_input_batch['features']
    transcripts = augmented_and_preprocessed_input_batch['transcripts']
    input_lengths = augmented_and_preprocessed_input_batch['input_lengths']
    features = features.float().to(DEVICE)
    # Swap the last two dims and add a channel axis for the conv front end.
    features = features.transpose(1, 2).unsqueeze(1)
    transcripts = transcripts.long().to(DEVICE)
    input_lengths = input_lengths.long().to(DEVICE)

    params.train(mode == spec.ForwardPassMode.TRAIN)
    log_y, output_lengths = params(features, input_lengths, transcripts)

    return (log_y.transpose(0, 1), output_lengths), None
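The single call params.train(mode == spec.ForwardPassMode.TRAIN) works because nn.Module.train accepts a boolean, so one line handles both modes:

import torch

model = torch.nn.Linear(4, 2)
model.train(False)              # equivalent to model.eval()
assert model.training is False
model.train(True)               # back to training mode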
Example #9
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
    del workload
    del model_state
    del rng

    optimizer_state = {
        'optimizer':
        torch.optim.Adam(model_params.parameters(),
                         lr=hyperparameters.learning_rate,
                         betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98),
                         eps=hyperparameters.epsilon)
    }

    optimizer_state['scheduler'] = create_learning_rate_scheduler(
        base_learning_rate=hyperparameters.learning_rate)
    return optimizer_state
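The betas line is a reparameterization: the search space tunes one_minus_beta_1 (naturally explored on a log scale) rather than beta_1 itself, while create_learning_rate_scheduler is a helper defined elsewhere in the submission, presumably returning a step-to-LR callable used as in Example #7. With a hypothetical value:

one_minus_beta_1 = 0.1                    # hypothetical hyperparameter value
betas = (1.0 - one_minus_beta_1, 0.98)    # -> (0.9, 0.98)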