def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
        """
        Wrap the forward step in a closure so second order methods work.

        Runs ``training_step`` via the accelerator backend plus the
        ``training_step_end`` hook, normalizes the output, runs the backward
        pass (with AMP loss scaling when ``self.precision == 16`` and not on
        TPU), synchronizes Horovod if enabled and calls ``on_after_backward``.

        Args:
            split_batch: the batch (or batch split) passed to ``build_train_args``.
            batch_idx: index of the current batch.
            opt_idx: index of the optimizer being stepped.
            optimizer: the optimizer for this step; passed to AMP scaling and
                ``backward``.
            hiddens: hidden state passed to ``build_train_args``.

        Returns:
            AttributeDict with ``loss`` (detached loss divided by
            ``accumulate_grad_batches``, captured before AMP scaling),
            ``training_step_output``, ``training_step_output_for_epoch_end``
            (graph-free copy of the raw output) and ``hiddens``.

        Raises:
            MisconfigurationException: if ``training_step`` returns an ``EvalResult``.
            Exception: if the apex AMP context returns a truthy value from ``__exit__``.
        """
        # ---------------------------
        # FORWARD (TRAINING STEP + TRAIN STEP END)
        # ---------------------------
        with self.profiler.profile('model_forward'):
            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
            training_step_output = self.accelerator_backend.training_step(args)
            training_step_output = self.call_hook('training_step_end', training_step_output)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # don't allow EvalResult in the training_step; reject it before the
            # output is mutated below (track_batch_size)
            if isinstance(training_step_output, EvalResult):
                raise MisconfigurationException('training_step cannot return EvalResult, '
                                                'use a dict or TrainResult instead')

            # format and reduce outputs accordingly
            training_step_output_for_epoch_end = training_step_output
            is_result_obj = isinstance(training_step_output, Result)

            if is_result_obj:
                # track batch size for weighted average
                training_step_output.track_batch_size(len(split_batch))
            else:
                # handle regular dicts: normalize into an AttributeDict
                training_step_output = self.process_output(training_step_output, train=True)

                training_step_output = AttributeDict(
                    batch_loss=training_step_output[0],
                    pbar_on_batch_end=training_step_output[1],
                    log_metrics=training_step_output[2],
                    callback_metrics=training_step_output[3],
                    hiddens=training_step_output[4],
                )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
            elif is_result_obj:
                # copy first: Result.detach() is used in place, and
                # training_step_output still needs its graph for backward below
                training_step_output_for_epoch_end = copy(training_step_output)
                training_step_output_for_epoch_end.detach()
            else:
                training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        if is_result_obj:
            closure_loss = training_step_output.minimize
        else:
            closure_loss = training_step_output.batch_loss
        closure_loss = closure_loss / self.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        # set only when the apex scale-loss context below is actually entered
        apex_context = None
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_backend=self.amp_backend)

                # enter amp context
                if self.amp_backend == AMPType.APEX:
                    self.dev_debugger.track_event('AMP', str(AMPType.APEX))
                    apex_context = closure_loss
                    closure_loss = closure_loss.__enter__()

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # exit amp context, only if it was entered above
            if apex_context is not None:
                error = apex_context.__exit__(None, None, None)
                if error:
                    rank_zero_warn('apex amp scale_loss context reported an error on exit')
                    raise Exception('apex unscale error')

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()

            if is_result_obj:
                training_step_output.detach()
            else:
                training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        # when in dev debugging track the losses
        self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
# Example #2
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer,
                          hiddens):
        """
        Wrap the forward step in a closure so second order methods work.

        Runs ``training_forward`` (inside ``torch.cuda.amp.autocast`` when
        ``self.use_amp`` and native AMP is available and not on TPU),
        normalizes the output with ``process_output``, runs the backward pass
        (with loss scaling when ``self.precision == 16`` and not on TPU),
        synchronizes Horovod if enabled and calls ``on_after_backward``.

        Args:
            split_batch: the batch (or batch split) passed to ``training_forward``.
            batch_idx: index of the current batch.
            opt_idx: index of the optimizer being stepped.
            optimizer: the optimizer for this step; passed to loss scaling and
                ``backward``.
            hiddens: hidden state passed to ``training_forward``.

        Returns:
            AttributeDict with ``loss`` (detached loss divided by
            ``accumulate_grad_batches``, captured before AMP scaling),
            ``training_step_output``, ``training_step_output_for_epoch_end``
            (graph-free copy of the raw output) and ``hiddens``.

        Raises:
            Exception: if the non-native AMP context returns a truthy value
                from ``__exit__``.
        """
        # ---------------------------
        # FORWARD
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(
                        split_batch, batch_idx, opt_idx, hiddens)
            else:
                training_step_output = self.training_forward(
                    split_batch, batch_idx, opt_idx, hiddens)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output_for_epoch_end = training_step_output
            training_step_output = self.process_output(training_step_output,
                                                       train=True)

            # TODO: temporary part of structured results PR
            training_step_output = AttributeDict(
                batch_loss=training_step_output[0],
                pbar_on_batch_end=training_step_output[1],
                log_metrics=training_step_output[2],
                callback_metrics=training_step_output[3],
                hiddens=training_step_output[4],
            )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
            else:
                training_step_output_for_epoch_end = recursive_detach(
                    training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        # set only when the non-native AMP scale-loss context below is entered
        apex_context = None
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(
                    closure_loss, optimizer, opt_idx)

                # enter amp context (only when native AMP is unavailable)
                if not NATIVE_AMP_AVALAIBLE:
                    apex_context = closure_loss
                    closure_loss = closure_loss.__enter__()

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # exit amp context, only if it was entered above
            if apex_context is not None:
                error = apex_context.__exit__(None, None, None)
                if error:
                    rank_zero_warn('apex amp scale_loss context reported an error on exit')
                    raise Exception('apex unscale error')

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer,
                          hiddens):
        """
        Wrap the forward step in a closure so second order methods work.

        Runs ``training_forward`` (inside ``torch.cuda.amp.autocast`` when
        ``self.use_amp`` and ``self.use_native_amp``), normalizes the output
        with ``process_output``, runs the backward pass (with loss scaling
        when ``self.precision == 16`` and not on TPU), synchronizes Horovod if
        enabled and calls ``on_after_backward``.

        Args:
            split_batch: the batch (or batch split) passed to ``training_forward``.
            batch_idx: index of the current batch.
            opt_idx: index of the optimizer being stepped.
            optimizer: the optimizer for this step; passed to loss scaling and
                ``backward``.
            hiddens: hidden state passed to ``training_forward``.

        Returns:
            AttributeDict with ``loss`` (detached loss divided by
            ``accumulate_grad_batches``, captured *before* AMP scaling so it is
            not inflated by the loss scale), ``training_step_output``,
            ``training_step_output_for_epoch_end`` (graph-free copy of the
            processed output) and ``hiddens``.
        """
        # ---------------------------
        # FORWARD
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and self.use_native_amp:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(
                        split_batch, batch_idx, opt_idx, hiddens)
            else:
                training_step_output = self.training_forward(
                    split_batch, batch_idx, opt_idx, hiddens)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output = self.process_output(training_step_output,
                                                       train=True)

            # TODO: temporary part of structured results PR
            training_step_output = AttributeDict(
                batch_loss=training_step_output[0],
                pbar_on_batch_end=training_step_output[1],
                log_metrics=training_step_output[2],
                callback_metrics=training_step_output[3],
                hiddens=training_step_output[4],
            )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            training_step_output_for_epoch_end = recursive_detach(
                training_step_output)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

        # amp_scale_loss below may scale closure_loss for 16-bit training; keep an
        # unscaled, graph-free copy so the returned loss is not multiplied by the
        # loss scale
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(
                    closure_loss, optimizer, opt_idx)

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result