Example #1
    def on_batch_end(self, runner: IRunner) -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        # Drop the cache when we exit to a nesting level
        # that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev_autocast_state)

        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        self.scaler.scale(loss).backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                grad_clip_fn=self.grad_clip_fn,
            )

            utils.maybe_recursive_call(self._optimizer, "zero_grad")
            self._accumulation_counter = 0
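
For comparison, the AMP gradient-accumulation pattern this callback implements can be written as a plain training loop. This is a minimal sketch, assuming a CUDA device; the tiny model, optimizer and random data are placeholders, and ``scaler.step()`` followed by ``scaler.update()`` stands in for the handler's ``grad_step``:

import torch

# Sketch: mixed-precision forward + scaled backward, stepping every N batches.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4

for step in range(1, 17):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randn(8, 1, device="cuda")

    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)

    # Scale the loss so small fp16 gradients don't underflow, then accumulate.
    scaler.scale(loss).backward()

    if step % accumulation_steps == 0:
        scaler.step(optimizer)   # unscales gradients and runs optimizer.step()
        scaler.update()          # adjusts the scale factor for later iterations
        optimizer.zero_grad()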
Example #2
    def __exit__(self, *args):
        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_cpu_enabled(self.prev)
        torch.set_autocast_cpu_dtype(self.prev_dtype)
        return False
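
This ``__exit__`` restores state that a matching ``__enter__`` saved earlier. A minimal sketch of that counterpart, assuming the same pre-``torch.amp`` bindings (``torch.is_autocast_cpu_enabled``, ``torch.get_autocast_cpu_dtype``, ``torch.autocast_increment_nesting``) and that the object carries ``self._enabled`` and ``self.fast_dtype`` attributes:

    def __enter__(self):
        # Save the current CPU autocast state so __exit__ can restore it.
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_dtype = torch.get_autocast_cpu_dtype()
        # Enable CPU autocast with the requested dtype and enter one nesting level.
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)
        torch.autocast_increment_nesting()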
Example #3
    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        if self.use_amp:
            # Drop the cache when we exit to a nesting level
            # that's outside any instance of autocast.
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev_autocast_state)

        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        # @TODO: speedup with re-definition ``on_stage_start``
        if self.use_apex:
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss,
                                self._optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        elif self.use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                grad_clip_fn=self.grad_clip_fn,
            )
            if not self.use_fast_zero_grad:
                maybe_recursive_call(self._optimizer, "zero_grad")
            else:
                # ``zero_grad`` here (no quotes) is a module-level helper,
                # presumably a fast variant that sets each param's ``.grad`` to ``None``.
                maybe_recursive_call(self._optimizer, zero_grad)
            self._accumulation_counter = 0
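
The ``delay_unscale`` flag used above follows the pattern from the linked apex documentation: keep gradients scaled on accumulation-only iterations and unscale only on the iteration that will step the optimizer. A minimal sketch, assuming NVIDIA apex is installed and a CUDA device is available; the model, optimizer and data are placeholders:

import torch
from apex import amp  # requires NVIDIA apex

model = torch.nn.Linear(32, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
accumulation_steps = 4

for i in range(16):
    x = torch.randn(8, 32, device="cuda")
    y = torch.randn(8, 1, device="cuda")
    loss = torch.nn.functional.mse_loss(model(x), y)

    # Unscale gradients only on iterations that will actually step the optimizer.
    delay_unscale = (i + 1) % accumulation_steps != 0
    with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:
        scaled_loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()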
Example #4
    def __exit__(self, exc_type: Any, exc_val: Any,
                 exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False
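
The nesting check in these ``__exit__`` methods is easiest to see with two autocast regions stacked inside each other: the autocast weight cache is dropped only when the outermost region exits, i.e. when ``torch.autocast_decrement_nesting()`` returns 0. A minimal sketch, assuming a CUDA device; the model and input are placeholders:

import torch

model = torch.nn.Linear(32, 32).cuda()
x = torch.randn(4, 32, device="cuda")

with torch.cuda.amp.autocast():      # nesting level 1
    with torch.cuda.amp.autocast():  # nesting level 2
        y = model(x)
    # inner exit: autocast_decrement_nesting() returns 1 -> cache is kept
# outer exit: autocast_decrement_nesting() returns 0 -> cache is cleared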
Example #5
            loss /= (h * w)

        # loss will be divided by (n)
        if self.batch_average:
            loss /= n

        return loss


if __name__ == '__main__':
    size = (256, 256)
    num_class = 9
    batch_size = 5

    # torch.manual_seed(1)
    torch.clear_autocast_cache()

    print('=' * 30 + ' CrossEntropy2D ' + '=' * 30)
    z = torch.randn(batch_size, num_class, *size, requires_grad=True).cuda()
    y = torch.randint(num_class, (batch_size, *size), dtype=torch.float).cuda()
    print('z.shape: {}, y.shape: {}'.format(z.shape, y.shape))
    l = CrossEntropy2D()(z, y)
    print(l.detach().cpu().numpy())

    print('=' * 30 + ' BinaryCrossEntropy2D ' + '=' * 30)
    z = torch.randn(batch_size, *size, requires_grad=True).cuda()
    y = torch.randint(2, (batch_size, *size), dtype=torch.float).cuda()
    print('z.shape: {}, y.shape: {}'.format(z.shape, y.shape))
    l = BinaryCrossEntropy2D()(z, y)
    print(l.detach().cpu().numpy())
Example #6
def train(model, trainD, evalD, checkpt=None):
    global ndecs
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.99),
                             weight_decay=args.wd)
    #  sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.nepochs * trainD.N)

    if checkpt is not None:
        optim.load_state_dict(checkpt['optim'])
        ndecs = checkpt['ndecs']

    batch_time = utils.RunningAverageMeter(0.98)
    cg_meter = utils.RunningAverageMeter(0.98)
    gnorm_meter = utils.RunningAverageMeter(0.98)
    train_est_meter = utils.RunningAverageMeter(0.98**args.train_est_freq)

    best_logp = -float('inf')
    itr = 0 if checkpt is None else checkpt['iters']
    n_vals_without_improvement = 0
    model.train()
    while True:
        if itr >= args.nepochs * math.ceil(trainD.N / args.batch_size):
            break
        if 0 < args.early_stopping < n_vals_without_improvement:
            break
        for x in batch_iter(trainD.x, shuffle=True):
            if 0 < args.early_stopping < n_vals_without_improvement:
                break
            end = time.time()
            optim.zero_grad()

            x = cvt(x)
            train_est = [0] if itr % args.train_est_freq == 0 else None
            loss = -model.logp(x, extra=train_est).mean()
            if train_est is not None:
                train_est = train_est[0].mean().detach().item()

            if loss != loss:
                raise ValueError('NaN encountered @ training logp!')

            loss.backward()

            if args.clip_grad == 0:
                parameters = [
                    p for p in model.parameters() if p.grad is not None
                ]
                grad_norm = torch.norm(
                    torch.stack([
                        torch.norm(p.grad.detach(), 2.0) for p in parameters
                    ]), 2.0)
            else:
                grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), args.clip_grad)

            optim.step()
            #  sch.step()

            gnorm_meter.update(float(grad_norm))
            cg_meter.update(sum(flows.CG_ITERS_TRACER))
            flows.CG_ITERS_TRACER.clear()
            batch_time.update(time.time() - end)
            if train_est is not None:
                train_est_meter.update(train_est)

            del loss
            gc.collect()
            torch.clear_autocast_cache()

            if itr % args.log_freq == 0:
                log_message = (
                    'Iter {:06d} | Epoch {:.2f} | Time {batch_time.val:.3f} | '
                    'GradNorm {gnorm_meter.avg:.2f} | CG iters {cg_meter.val} ({cg_meter.avg:.2f}) | '
                    'Train logp {train_logp.val:.6f} ({train_logp.avg:.6f})'.
                    format(itr,
                           float(itr) / (trainD.N / float(args.batch_size)),
                           batch_time=batch_time,
                           gnorm_meter=gnorm_meter,
                           cg_meter=cg_meter,
                           train_logp=train_est_meter))
                logger.info(log_message)

            # Validation loop.
            if itr % args.val_freq == 0:
                with eval_ctx(model, bruteforce=args.brute_val):
                    val_logp = utils.AverageMeter()
                    with tqdm(total=evalD.N) as pbar:
                        # noinspection PyAssignmentToLoopOrWithParameter
                        for x in batch_iter(evalD.x,
                                            batch_size=args.val_batch_size):
                            x = cvt(x)
                            val_logp.update(
                                model.logp(x).mean().item(), x.size(0))
                            pbar.update(x.size(0))
                    if val_logp.avg > best_logp:
                        best_logp = val_logp.avg
                        utils.makedirs(args.save)
                        torch.save(
                            {
                                'args': args,
                                'model': model.state_dict(),
                                'optim': optim.state_dict(),
                                'iters': itr + 1,
                                'ndecs': ndecs,
                            }, save_path)
                        n_vals_without_improvement = 0
                    else:
                        n_vals_without_improvement += 1
                        update_lr(optim, n_vals_without_improvement)

                    log_message = ('[VAL] Iter {:06d} | Val logp {:.6f} | '
                                   'NoImproveEpochs {:02d}/{:02d}'.format(
                                       itr, val_logp.avg,
                                       n_vals_without_improvement,
                                       args.early_stopping))
                    logger.info(log_message)

            itr += 1

    logger.info('Training has finished, yielding the best model...')
    best_checkpt = torch.load(save_path)
    model.load_state_dict(best_checkpt['model'])
    return model
Example #7
def make_graphed_callables(callables, sample_args, num_warmup_iters=3):
    r"""
    Accepts callables (functions or :class:`nn.Module<torch.nn.Module>`\ s)
    and returns graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate.  If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
            11 iterations for warm up. Default: ``3``.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must be a tuple of Tensors. Other types and keyword args
        are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        All Tensor outputs of graphed callables must require grad.
    """
    just_one_callable = False

    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables, )
        sample_args = (sample_args, )

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert len(c._backward_hooks) == 0 and len(c._forward_hooks) == 0 and len(c._forward_pre_hooks) == 0, \
                "Modules must not have hooks registered at the time they are passed. However, registering hooks " + \
                "on modules after passing them through make_graphed_callables is allowed."
            assert all(b.requires_grad is False for b in c.buffers()), "In any :class:`~torch.nn.Module` passed to " + \
                ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have " + \
                "``requires_grad=False``."
        assert all(isinstance(arg, torch.Tensor) for arg in args), "In the beta API, sample_args " + \
            "for each callable must be a tuple of Tensors. Other types and keyword args are not allowed."

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle()

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
                callables, sample_args, per_callable_static_input_surfaces):
            for _ in range(num_warmup_iters):
                outputs = func(*args)
                outputs = (outputs, ) if isinstance(outputs,
                                                    torch.Tensor) else outputs
                grad_inputs = torch.autograd.grad(
                    outputs=outputs,
                    inputs=tuple(i for i in static_input_surface
                                 if i.requires_grad),
                    grad_outputs=tuple(torch.empty_like(o) for o in outputs),
                    only_inputs=True,
                    allow_unused=False)
            del outputs, grad_inputs
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Clear AMP autocast cache before capturing the graphs
    torch.clear_autocast_cache()

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_was_tensor = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        # Assumes model output is a tensor or tuple of tensors
        if isinstance(outputs, torch.Tensor):
            per_callable_output_was_tensor.append(True)
            outputs = (outputs, )
        else:
            per_callable_output_was_tensor.append(False)

        per_callable_static_outputs.append(outputs)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in \
            zip(reversed(per_callable_static_input_surfaces),
                reversed(per_callable_static_outputs),
                reversed(bwd_graphs),
                reversed(per_callable_module_params)):

        # For now, assumes all static_outputs require grad
        assert all(
            o.requires_grad for o in
            static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(
            torch.empty_like(o) for o in static_outputs)

        with torch.cuda.graph(bwd_graph, pool=mempool):
            grad_inputs = torch.autograd.grad(
                outputs=static_outputs,
                inputs=tuple(i for i in static_input_surface
                             if i.requires_grad),
                grad_outputs=static_grad_outputs,
                only_inputs=True,
                allow_unused=False)

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(
            static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs = list(
        reversed(per_callable_static_grad_outputs))
    per_callable_static_grad_inputs = list(
        reversed(per_callable_static_grad_inputs))
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    # Clear AMP autocast cache after both forward and backward graphs are captured
    torch.clear_autocast_cache()

    def make_graphed_autograd_function(fwd_graph, bwd_graph, module_params,
                                       len_user_args, output_was_tensor,
                                       static_input_surface, static_outputs,
                                       static_grad_outputs,
                                       static_grad_inputs):
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr(
                    ) != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                for g, grad in zip(static_grad_outputs, grads):
                    if g is None:
                        assert grad is None
                    else:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(b.detach() if b is not None else b
                             for b in static_grad_inputs)

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            out = Graphed.apply(*(user_args + module_params))
            return out[0] if output_was_tensor else out

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i], bwd_graphs[i], per_callable_module_params[i],
            per_callable_len_user_args[i], per_callable_output_was_tensor[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i])

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed,
                                     orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(
                func, func.training, graphed,
                func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
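
A minimal end-to-end usage sketch of the API documented above, assuming PyTorch with CUDA graph support (``torch.cuda.make_graphed_callables``) and a CUDA device; the module, optimizer and random inputs are placeholders. Note that the sample input's ``requires_grad`` state must match the real inputs used later:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 64)
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Sample args: same shapes, dtypes and requires_grad as the real training inputs.
sample_input = torch.randn(8, 64, device="cuda", requires_grad=True)
graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))

for _ in range(3):
    x = torch.randn(8, 64, device="cuda", requires_grad=True)
    out = graphed_model(x)   # replays the captured forward CUDA graph
    out.sum().backward()     # replays the captured backward CUDA graph
    optimizer.step()
    optimizer.zero_grad()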