Example #1
def update_params(
        workload: spec.Workload,
        current_param_container: spec.ParameterContainer,
        current_params_types: spec.ParameterTypeTree,
        model_state: spec.ModelAuxiliaryState,
        hyperparameters: spec.Hyperparameters,
        batch: Dict[str, spec.Tensor],
        # This will define the output activation via `output_activation_fn`.
        loss_type: spec.LossType,
        optimizer_state: spec.OptimizerState,
        eval_results: List[Tuple[int, float]],
        global_step: int,
        rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del eval_results
    del global_step
    del model_state
    del loss_type
    del hyperparameters

    optimizer_state.zero_grad()

    (log_y, output_lengths), _ = workload.model_fn(current_param_container,
                                                   batch, None,
                                                   spec.ForwardPassMode.TRAIN,
                                                   rng, False)

    train_ctc_loss = torch.mean(
        workload.loss_fn(batch, (log_y, output_lengths)))
    train_ctc_loss.backward()
    optimizer_state.step()

    return optimizer_state, current_param_container, None
Example #2
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  input_batch: spec.Tensor,
                  label_batch: spec.Tensor,
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  batch = {'image': input_batch, 'label': label_batch}
  optimizer_state, opt_update_fn = optimizer_state
  new_model_state, new_optimizer_state, new_params = pmapped_train_step(
      workload, opt_update_fn, model_state, optimizer_state,
      current_param_container, hyperparameters, batch, rng)

  steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
  if (global_step + 1) % steps_per_epoch == 0:
    # sync batch statistics across replicas once per epoch
    new_model_state = workload.sync_batch_stats(new_model_state)

  return (new_optimizer_state, opt_update_fn), new_params, new_model_state
Example #3
def update_params(
        workload: spec.Workload,
        current_param_container: spec.ParameterContainer,
        current_params_types: spec.ParameterTypeTree,
        model_state: spec.ModelAuxiliaryState,
        hyperparameters: spec.Hyperparameters,
        input_batch: spec.Tensor,
        label_batch: spec.Tensor,
        # This will define the output activation via `output_activation_fn`.
        loss_type: spec.LossType,
        optimizer_state: spec.OptimizerState,
        eval_results: List[Tuple[int, float]],
        global_step: int,
        rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del hyperparameters
    del loss_type
    del eval_results
    input_batch, label_batch = (workload.preprocess_for_train(
        input_batch, label_batch, None, None, None))

    current_model = current_param_container
    current_param_container.train()
    optimizer_state['optimizer'].zero_grad()

    logits_batch, new_model_state = workload.model_fn(
        params=current_model,
        input_batch=input_batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=True)

    loss = workload.loss_fn(label_batch=label_batch,
                            logits_batch=logits_batch).mean()

    loss.backward()
    optimizer_state['optimizer'].step()

    steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
    if (global_step + 1) % steps_per_epoch == 0:
        optimizer_state['scheduler'].step()

    return (optimizer_state, current_param_container, new_model_state)
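Example #3 above unpacks `optimizer_state` as a dict holding an 'optimizer' and a 'scheduler' that is stepped once per epoch. A minimal sketch of an `init_optimizer_state` that could build such a dict is given below; the hyperparameter names (`learning_rate`, `momentum`) and the choice of StepLR are illustrative assumptions, not part of the excerpt.

import torch
from torch.optim.lr_scheduler import StepLR


def init_optimizer_state(workload, model_params, model_state, hyperparameters,
                         rng):
    """Hypothetical sketch: build the {'optimizer', 'scheduler'} dict used above."""
    del workload, model_state, rng
    optimizer = torch.optim.SGD(
        model_params.parameters(),
        lr=hyperparameters.learning_rate,   # Assumed hyperparameter name.
        momentum=hyperparameters.momentum)  # Assumed hyperparameter name.
    # Decay the learning rate once per epoch, matching the per-epoch
    # `scheduler.step()` call in the example above.
    scheduler = StepLR(optimizer, step_size=1, gamma=0.97)
    return {'optimizer': optimizer, 'scheduler': scheduler}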
Example #4
def update_params(
        workload: spec.Workload,
        current_param_container: spec.ParameterContainer,
        current_params_types: spec.ParameterTypeTree,
        model_state: spec.ModelAuxiliaryState,
        hyperparameters: spec.Hyperparameters,
        input_batch: spec.Tensor,
        label_batch: spec.Tensor,
        # This will define the output activation via `output_activation_fn`.
        loss_type: spec.LossType,
        optimizer_state: spec.OptimizerState,
        eval_results: List[Tuple[int, float]],
        global_step: int,
        rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del eval_results
    del global_step
    del model_state
    del loss_type
    del hyperparameters
    del label_batch

    _, features, transcripts, input_lengths = input_batch
    # `device` is assumed to be defined at module level (e.g. torch.device('cuda')).
    features = features.float().to(device)
    features = features.transpose(1, 2).unsqueeze(1)
    transcripts = transcripts.long().to(device)
    input_lengths = input_lengths.long().to(device)

    optimizer_state.zero_grad()

    (log_y, output_lengths), _ = workload.model_fn(
        current_param_container, (features, transcripts, input_lengths), None,
        spec.ForwardPassMode.TRAIN, rng, False)

    train_ctc_loss = torch.mean(
        workload.loss_fn(transcripts, (log_y, output_lengths)))

    train_ctc_loss.backward()
    optimizer_state.step()

    return optimizer_state, current_param_container, None
Example #5
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor], loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]], global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
    """Return (updated_optimizer_state, updated_params)."""
    del current_params_types
    del eval_results
    del loss_type
    del hyperparameters

    current_model = current_param_container
    current_param_container.train()
    optimizer = optimizer_state['optimizer']
    optimizer.zero_grad()

    logits, _ = workload.model_fn(params=current_model,
                                  augmented_and_preprocessed_input_batch=batch,
                                  model_state=model_state,
                                  mode=spec.ForwardPassMode.TRAIN,
                                  rng=rng,
                                  update_batch_norm=False)

    targets = batch['targets']
    weights = torch.where(targets > 0, 1.0, 0.0)
    loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
    loss.backward()

    lr = optimizer_state['scheduler'](global_step).item()
    for g in optimizer.param_groups:
        g['lr'] = lr
    optimizer.step()

    return (optimizer_state, current_param_container, None)
Example #6
def train_once(workload: spec.Workload, batch_size: int, data_dir: str,
               init_optimizer_state: spec.InitOptimizerFn,
               update_params: spec.UpdateParamsFn,
               data_selection: spec.DataSelectionFn,
               hyperparameters: Optional[spec.Hyperparameters],
               rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
    data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

    # Workload setup.
    logging.info('Initializing dataset.')
    input_queue = workload.build_input_queue(data_rng,
                                             'train',
                                             data_dir=data_dir,
                                             batch_size=batch_size)
    logging.info('Initializing model.')
    model_params, model_state = workload.init_model_fn(model_init_rng)
    logging.info('Initializing optimizer.')
    optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                           hyperparameters, opt_init_rng)

    # Bookkeeping.
    goal_reached = False
    is_time_remaining = True
    last_eval_time = 0
    accumulated_submission_time = 0
    eval_results = []
    global_step = 0
    training_complete = False
    global_start_time = time.time()

    logging.info('Starting training loop.')
    while (is_time_remaining and not goal_reached and not training_complete):
        step_rng = prng.fold_in(rng, global_step)
        data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
        start_time = time.time()
        selected_train_input_batch, selected_train_label_batch = data_selection(
            workload, input_queue, optimizer_state, model_params,
            hyperparameters, global_step, data_select_rng)
        try:
            optimizer_state, model_params, model_state = update_params(
                workload=workload,
                current_param_container=model_params,
                current_params_types=workload.model_params_types(),
                model_state=model_state,
                hyperparameters=hyperparameters,
                input_batch=selected_train_input_batch,
                label_batch=selected_train_label_batch,
                loss_type=workload.loss_type,
                optimizer_state=optimizer_state,
                eval_results=eval_results,
                global_step=global_step,
                rng=update_rng)
        except spec.TrainingCompleteError:
            training_complete = True
        global_step += 1
        current_time = time.time()
        accumulated_submission_time += current_time - start_time
        is_time_remaining = (accumulated_submission_time <
                             workload.max_allowed_runtime_sec)
        # Check if submission is eligible for an untimed eval.
        if (current_time - last_eval_time >= workload.eval_period_time_sec
                or training_complete):
            latest_eval_result = workload.eval_model(model_params, model_state,
                                                     eval_rng, data_dir)
            logging.info(
                f'{current_time - global_start_time:.2f}s\t{global_step}'
                f'\t{latest_eval_result}')
            last_eval_time = current_time
            eval_results.append((global_step, latest_eval_result))
            goal_reached = workload.has_reached_goal(latest_eval_result)
    metrics = {'eval_results': eval_results, 'global_step': global_step}
    return accumulated_submission_time, metrics
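For context, a rough sketch of how `train_once` above might be driven is shown below, assuming a hypothetical `submission` module that exposes `init_optimizer_state`, `update_params`, and `data_selection`, plus the `get_batch_size` helper used in the examples above; these names are illustrative and not defined in this excerpt.

import submission  # Hypothetical module providing the submission functions.


def run(workload, data_dir, rng):
    # `train_once` returns the accumulated submission time and a metrics dict
    # with the recorded eval results and the final global step.
    batch_size = get_batch_size('imagenet')
    timing, metrics = train_once(workload, batch_size, data_dir,
                                 submission.init_optimizer_state,
                                 submission.update_params,
                                 submission.data_selection,
                                 hyperparameters=None,
                                 rng=rng)
    return timing, metrics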