def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del eval_results
  del global_step
  del model_state
  del loss_type
  del hyperparameters

  optimizer_state.zero_grad()
  # Forward pass: the model returns log-probabilities and output lengths,
  # as expected by the CTC loss.
  (log_y, output_lengths), _ = workload.model_fn(
      current_param_container, batch, None, spec.ForwardPassMode.TRAIN, rng,
      False)
  train_ctc_loss = torch.mean(
      workload.loss_fn(batch, (log_y, output_lengths)))
  train_ctc_loss.backward()
  optimizer_state.step()
  return optimizer_state, current_param_container, None
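# A minimal sketch of an `init_optimizer_state` compatible with the function
# above, which treats `optimizer_state` as a bare torch.optim.Optimizer
# (it calls `.zero_grad()` and `.step()` on it directly). The Adam settings
# here are illustrative assumptions, not the submission's tuned values.
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del workload, model_state, hyperparameters, rng
  return torch.optim.Adam(model_params.parameters(), lr=1e-3)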
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  input_batch: spec.Tensor,
                  label_batch: spec.Tensor,
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  batch = {'image': input_batch, 'label': label_batch}
  optimizer_state, opt_update_fn = optimizer_state
  new_model_state, new_optimizer_state, new_params = pmapped_train_step(
      workload, opt_update_fn, model_state, optimizer_state,
      current_param_container, hyperparameters, batch, rng)

  steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
  if (global_step + 1) % steps_per_epoch == 0:
    # Sync batch statistics across replicas once per epoch.
    new_model_state = workload.sync_batch_stats(new_model_state)

  return (new_optimizer_state, opt_update_fn), new_params, new_model_state
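# A minimal sketch of what `pmapped_train_step` might look like, assuming an
# optax optimizer and the {'image', 'label'} batch built above. The gradient
# averaging via `jax.lax.pmean`, the `model_fn` argument order, and the loss
# details are illustrative assumptions, not the submission's exact code.
import functools

import jax
import jax.numpy as jnp
import optax


@functools.partial(
    jax.pmap,
    axis_name='batch',
    in_axes=(None, None, 0, 0, 0, None, 0, 0),
    static_broadcasted_argnums=(0, 1))
def pmapped_train_step(workload, opt_update_fn, model_state, optimizer_state,
                       current_param_container, hyperparameters, batch, rng):
  del hyperparameters

  def loss_fn(params):
    logits, new_model_state = workload.model_fn(
        params, batch['image'], model_state, spec.ForwardPassMode.TRAIN, rng,
        update_batch_norm=True)
    loss = jnp.mean(workload.loss_fn(batch['label'], logits))
    return loss, new_model_state

  (_, new_model_state), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(current_param_container)
  # Average gradients across devices before the optimizer update.
  grads = jax.lax.pmean(grads, axis_name='batch')
  updates, new_optimizer_state = opt_update_fn(grads, optimizer_state,
                                               current_param_container)
  new_params = optax.apply_updates(current_param_container, updates)
  return new_model_state, new_optimizer_state, new_params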
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    input_batch: spec.Tensor,
    label_batch: spec.Tensor,
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del hyperparameters
  del loss_type
  del eval_results

  input_batch, label_batch = workload.preprocess_for_train(
      input_batch, label_batch, None, None, None)

  current_model = current_param_container
  current_param_container.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
      params=current_model,
      input_batch=input_batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=True)

  loss = workload.loss_fn(
      label_batch=label_batch, logits_batch=logits_batch).mean()
  loss.backward()
  optimizer_state['optimizer'].step()

  # Step the LR scheduler once per epoch.
  steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet')
  if (global_step + 1) % steps_per_epoch == 0:
    optimizer_state['scheduler'].step()

  return (optimizer_state, current_param_container, new_model_state)
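# A minimal sketch of an `init_optimizer_state` producing the
# {'optimizer', 'scheduler'} dictionary the function above indexes into.
# The SGD settings and hyperparameter attribute names are illustrative
# assumptions.
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del workload, model_state, rng
  optimizer = torch.optim.SGD(
      model_params.parameters(),
      lr=hyperparameters.learning_rate,
      momentum=hyperparameters.momentum,
      weight_decay=hyperparameters.l2)
  # Stepped once per epoch by `update_params` above.
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
      optimizer, T_max=hyperparameters.num_epochs)
  return {'optimizer': optimizer, 'scheduler': scheduler}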
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    input_batch: spec.Tensor,
    label_batch: spec.Tensor,
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del eval_results
  del global_step
  del model_state
  del loss_type
  del hyperparameters
  del label_batch

  # Unpack the batch tuple and move its tensors to the module-level `device`.
  _, features, transcripts, input_lengths = input_batch
  features = features.float().to(device)
  features = features.transpose(1, 2).unsqueeze(1)
  transcripts = transcripts.long().to(device)
  input_lengths = input_lengths.long().to(device)

  optimizer_state.zero_grad()
  (log_y, output_lengths), _ = workload.model_fn(
      current_param_container, (features, transcripts, input_lengths), None,
      spec.ForwardPassMode.TRAIN, rng, False)
  train_ctc_loss = torch.mean(
      workload.loss_fn(transcripts, (log_y, output_lengths)))
  train_ctc_loss.backward()
  optimizer_state.step()
  return optimizer_state, current_param_container, None
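# The function above references a module-level `device` that is not shown in
# this snippet; a typical definition (an assumption here):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')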
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del eval_results
  del loss_type
  del hyperparameters

  current_model = current_param_container
  current_param_container.train()
  optimizer = optimizer_state['optimizer']
  optimizer.zero_grad()

  logits, _ = workload.model_fn(
      params=current_model,
      augmented_and_preprocessed_input_batch=batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=False)

  # Mask out padding tokens (id 0) when averaging the per-token loss.
  targets = batch['targets']
  weights = torch.where(targets > 0, 1.0, 0.0)
  loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
  loss.backward()

  # Look up this step's learning rate and apply it before the optimizer step.
  lr = optimizer_state['scheduler'](global_step).item()
  for g in optimizer.param_groups:
    g['lr'] = lr
  optimizer.step()

  return (optimizer_state, current_param_container, None)
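# A minimal sketch of the callable scheduler the function above assumes:
# `optimizer_state['scheduler'](global_step)` must return a 0-d tensor, since
# `.item()` is called on the result. The linear-warmup/cosine-decay shape and
# all parameter names are illustrative assumptions.
import math


def make_lr_schedule(base_lr: float, warmup_steps: int, total_steps: int):

  def schedule(step: int) -> torch.Tensor:
    if step < warmup_steps:
      lr = base_lr * step / max(1, warmup_steps)
    else:
      frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
      lr = base_lr * 0.5 * (1.0 + math.cos(math.pi * frac))
    return torch.tensor(lr)

  return schedule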
def train_once(workload: spec.Workload,
               batch_size: int,
               data_dir: str,
               init_optimizer_state: spec.InitOptimizerFn,
               update_params: spec.UpdateParamsFn,
               data_selection: spec.DataSelectionFn,
               hyperparameters: Optional[spec.Hyperparameters],
               rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
  """Run one training run of a workload and return (timing, metrics)."""
  data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

  # Workload setup.
  logging.info('Initializing dataset.')
  input_queue = workload.build_input_queue(
      data_rng, 'train', data_dir=data_dir, batch_size=batch_size)
  logging.info('Initializing model.')
  model_params, model_state = workload.init_model_fn(model_init_rng)
  logging.info('Initializing optimizer.')
  optimizer_state = init_optimizer_state(workload, model_params, model_state,
                                         hyperparameters, opt_init_rng)

  # Bookkeeping.
  goal_reached = False
  is_time_remaining = True
  last_eval_time = 0
  accumulated_submission_time = 0
  eval_results = []
  global_step = 0
  training_complete = False
  global_start_time = time.time()

  logging.info('Starting training loop.')
  while is_time_remaining and not goal_reached and not training_complete:
    step_rng = prng.fold_in(rng, global_step)
    data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
    start_time = time.time()
    selected_train_input_batch, selected_train_label_batch = data_selection(
        workload, input_queue, optimizer_state, model_params, hyperparameters,
        global_step, data_select_rng)
    try:
      optimizer_state, model_params, model_state = update_params(
          workload=workload,
          current_param_container=model_params,
          current_params_types=workload.model_params_types(),
          model_state=model_state,
          hyperparameters=hyperparameters,
          input_batch=selected_train_input_batch,
          label_batch=selected_train_label_batch,
          loss_type=workload.loss_type,
          optimizer_state=optimizer_state,
          eval_results=eval_results,
          global_step=global_step,
          rng=update_rng)
    except spec.TrainingCompleteError:
      training_complete = True
    global_step += 1
    current_time = time.time()
    accumulated_submission_time += current_time - start_time
    is_time_remaining = (
        accumulated_submission_time < workload.max_allowed_runtime_sec)
    # Check if the submission is eligible for an untimed eval.
    if (current_time - last_eval_time >= workload.eval_period_time_sec or
        training_complete):
      latest_eval_result = workload.eval_model(model_params, model_state,
                                               eval_rng, data_dir)
      logging.info(f'{current_time - global_start_time:.2f}s\t{global_step}'
                   f'\t{latest_eval_result}')
      last_eval_time = current_time
      eval_results.append((global_step, latest_eval_result))
      goal_reached = workload.has_reached_goal(latest_eval_result)

  metrics = {'eval_results': eval_results, 'global_step': global_step}
  return accumulated_submission_time, metrics
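# A hypothetical driver showing how `train_once` might be invoked with a
# submission's functions; the `submission` module, the data path, and the
# `prng.PRNGKey` seeding call are illustrative assumptions.
import submission

rng = prng.PRNGKey(0)
timing, metrics = train_once(
    workload=workload,
    batch_size=get_batch_size('imagenet'),
    data_dir='/data/imagenet',
    init_optimizer_state=submission.init_optimizer_state,
    update_params=submission.update_params,
    data_selection=submission.data_selection,
    hyperparameters=hyperparameters,
    rng=rng)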