Example No. 1
 def test_helper_train(self):
     """
     Tests train/eval mode helper methods
     """
     rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
     rnn.train()
     self.assertTrue(torch.is_grad_enabled())
     self.assertTrue(rnn.nn.training)
     rnn.eval()
     self.assertFalse(torch.is_grad_enabled())
     self.assertFalse(rnn.nn.training)
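For context, a minimal sketch of what such train/eval helpers typically do, assuming a hypothetical wrapper class (not kraken's actual implementation): unlike plain nn.Module.train()/.eval(), the helpers also toggle the global autograd state, which is exactly what the test checks via torch.is_grad_enabled().

import torch
from torch import nn

class ModelWrapper:
    """Hypothetical wrapper whose train()/eval() also toggle autograd (sketch only)."""

    def __init__(self, net: nn.Module):
        self.nn = net

    def train(self):
        # enable gradient tracking globally and put submodules into training mode
        torch.set_grad_enabled(True)
        self.nn.train()

    def eval(self):
        # disable gradient tracking globally and put submodules into eval mode
        torch.set_grad_enabled(False)
        self.nn.eval()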
Example No. 2
 def wrapper(ctx, *args):
     tensor_args = [arg.data if isinstance(arg, Variable) else arg
                    for arg in args]
     outputs = fn(ctx, *tensor_args)
     # XXX: this is only an approximation of these flags - there's no way
     # to figure out if fn didn't use ctx.saved_variables and as a result
     # some Variables might require grad, even if no args do.
     # Unfortunately, this leads to unexpected error messages ("no nodes
     # require computing gradients"), but I don't have a better idea.
     # These functions would raise an error in backward anyway.
     requires_grad = any(arg.requires_grad if isinstance(arg, Variable) else False
                         for arg in args)
     if not torch.is_grad_enabled():
         def err_fn(*args):
             return args
     else:
         err_fn = torch._C._functions.DelayedError(
             b"trying to differentiate twice a function that was marked"
             b"with @once_differentiable")
     if not isinstance(outputs, tuple):
         var = (Variable(outputs, requires_grad=requires_grad)
                if outputs is not None else None)
         return err_fn(var)
     return err_fn(*[Variable(o, requires_grad=requires_grad) if o is not None else None
                   for o in outputs])
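This wrapper implements the kind of logic behind the once_differentiable decorator from torch.autograd.function. A hedged usage sketch in the modern tensor API (no Variable), with a toy function name chosen for illustration:

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class ClampNoDoubleBackward(Function):
    """Toy Function whose backward is marked as only once differentiable."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        # runs with grad disabled; differentiating through it again raises an error
        (x,) = ctx.saved_tensors
        return grad_out * (x > 0).to(grad_out.dtype)

x = torch.randn(4, requires_grad=True)
ClampNoDoubleBackward.apply(x).sum().backward()  # first-order gradients work as usual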
Example No. 3
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
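A hedged sketch of how parallel_apply is typically driven inside a data-parallel forward pass, using the stock torch.nn.parallel helpers (the function above is a variant of torch's own parallel_apply); the device ids are illustrative and assume at least two GPUs:

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def naive_data_parallel_forward(module, inp, device_ids=(0, 1), output_device=0):
    # one replica and one input shard per device
    replicas = replicate(module, device_ids)
    inputs = scatter(inp, device_ids)
    replicas = replicas[:len(inputs)]
    # each replica runs in its own thread; grad mode is propagated as shown above
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)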
Example No. 4
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
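The grad_enabled capture followed by torch.set_grad_enabled(grad_enabled) inside _worker is needed because grad mode is thread-local: a worker thread does not inherit the caller's no_grad() context. A small self-contained demonstration of that behaviour:

import threading
import torch

def show_grad_mode(results, key):
    # a fresh thread starts with grad mode enabled, regardless of the caller's context
    results[key] = torch.is_grad_enabled()

results = {}
with torch.no_grad():
    t = threading.Thread(target=show_grad_mode, args=(results, "worker"))
    t.start()
    t.join()
    results["caller"] = torch.is_grad_enabled()

print(results)  # expected: {'worker': True, 'caller': False}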
Example No. 5
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_variables and as a result
        # some Variables might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(isinstance(arg, Variable) and arg.requires_grad
                            for arg in args)
        if not requires_grad:
            return outputs

        err_fn = torch._C._functions.DelayedError(
            b"trying to differentiate twice a function that was marked"
            b"with @once_differentiable")

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])
Example No. 6
def parallel_apply_sampling(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.
    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices
    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module.sampling(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
Example No. 7
def sinkhorn_loop(
    softmin,
    α_logs,
    β_logs,
    C_xxs,
    C_yys,
    C_xys,
    C_yxs,
    ε_s,
    ρ,
    jumps=[],
    kernel_truncation=None,
    truncate=5,
    cost=None,
    extrapolate=None,
    debias=True,
    last_extrapolation=True,
):

    Nits = len(ε_s)
    if type(α_logs) is not list:
        α_logs, β_logs = [α_logs], [β_logs]
        if debias:
            C_xxs, C_yys = [C_xxs], [C_yys]
        C_xys, C_yxs = [C_xys], [C_yxs]

    prev_autograd = torch.is_grad_enabled()
    torch.autograd.set_grad_enabled(False)

    k = 0  # Scale index; we start at the coarsest resolution available
    ε = ε_s[k]
    λ = dampening(ε, ρ)

    # Load the measures and cost matrices at the current scale:
    α_log, β_log = α_logs[k], β_logs[k]
    if debias:
        C_xx, C_yy = C_xxs[k], C_yys[k]
    C_xy, C_yx = C_xys[k], C_yxs[k]

    # Start with a decent initialization for the dual vectors:
    if debias:
        a_x = λ * softmin(ε, C_xx, α_log)  # OT(α,α)
        b_y = λ * softmin(ε, C_yy, β_log)  # OT(β,β)
    a_y = λ * softmin(ε, C_yx, α_log)  # OT(α,β) wrt. a
    b_x = λ * softmin(ε, C_xy, β_log)  # OT(α,β) wrt. b

    for i, ε in enumerate(ε_s):  # ε-scaling descent -----------------------

        λ = dampening(ε, ρ)  # ε has changed, so we should update λ too!

        # "Coordinate ascent" on the dual problems:
        if debias:
            at_x = λ * softmin(ε, C_xx, α_log + a_x / ε)  # OT(α,α)
            bt_y = λ * softmin(ε, C_yy, β_log + b_y / ε)  # OT(β,β)
        at_y = λ * softmin(ε, C_yx, α_log + b_x / ε)  # OT(α,β) wrt. a
        bt_x = λ * softmin(ε, C_xy, β_log + a_y / ε)  # OT(α,β) wrt. b

        # Symmetrized updates:
        if debias:
            a_x, b_y = 0.5 * (a_x + at_x), 0.5 * (b_y + bt_y)  # OT(α,α), OT(β,β)
        a_y, b_x = 0.5 * (a_y + at_y), 0.5 * (b_x + bt_x)  # OT(α,β) wrt. a, b

        if i in jumps:  # Jump from a coarse to a finer scale --------------

            if i == len(ε_s) - 1:  # Last iteration: just extrapolate!

                if debias:
                    C_xx_, C_yy_ = C_xxs[k + 1], C_yys[k + 1]
                C_xy_, C_yx_ = C_xys[k + 1], C_yxs[k + 1]

                last_extrapolation = False  # No need to re-extrapolate after the loop
                torch.autograd.set_grad_enabled(prev_autograd)

            else:  # It's worth investing some time on kernel truncation...

                # Kernel truncation trick (described in Bernhard Schmitzer's 2016 paper),
                # that typically relies on KeOps' block-sparse routines:
                if debias:
                    C_xx_, _ = kernel_truncation(
                        C_xx,
                        C_xx,
                        C_xxs[k + 1],
                        C_xxs[k + 1],
                        a_x,
                        a_x,
                        ε,
                        truncate=truncate,
                        cost=cost,
                    )
                    C_yy_, _ = kernel_truncation(
                        C_yy,
                        C_yy,
                        C_yys[k + 1],
                        C_yys[k + 1],
                        b_y,
                        b_y,
                        ε,
                        truncate=truncate,
                        cost=cost,
                    )
                C_xy_, C_yx_ = kernel_truncation(
                    C_xy,
                    C_yx,
                    C_xys[k + 1],
                    C_yxs[k + 1],
                    b_x,
                    a_y,
                    ε,
                    truncate=truncate,
                    cost=cost,
                )

            # Extrapolation for the symmetric problems:
            if debias:
                a_x = extrapolate(a_x, a_x, ε, λ, C_xx, α_log, C_xx_)
                b_y = extrapolate(b_y, b_y, ε, λ, C_yy, β_log, C_yy_)

            # The cross-updates should be done in parallel!
            a_y, b_x = (extrapolate(a_y, b_x, ε, λ, C_yx, α_log, C_yx_),
                        extrapolate(b_x, a_y, ε, λ, C_xy, β_log, C_xy_))

            # Update the measure weights and cost "matrices":
            k = k + 1
            α_log, β_log = α_logs[k], β_logs[k]
            if debias:
                C_xx, C_yy = C_xx_, C_yy_
            C_xy, C_yx = C_xy_, C_yx_

    torch.autograd.set_grad_enabled(prev_autograd)

    if last_extrapolation:
        # Last extrapolation, to get the correct gradients:
        if debias:
            a_x = λ * softmin(ε, C_xx, (α_log + a_x / ε).detach())
            b_y = λ * softmin(ε, C_yy, (β_log + b_y / ε).detach())

        # The cross-updates should be done in parallel!
        a_y, b_x = (λ * softmin(ε, C_yx, (α_log + b_x / ε).detach()),
                    λ * softmin(ε, C_xy, (β_log + a_y / ε).detach()))

    if debias:
        return a_x, b_y, a_y, b_x
    else:
        return None, None, a_y, b_x
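The loop assumes softmin(ε, C, f_log) and dampening(ε, ρ) helpers in the GeomLoss style. Below is a minimal dense-tensor sketch of what they compute (unbatched and simplified; not the KeOps or multiscale implementations the real code dispatches to):

import math
import torch

def dampening(eps, rho):
    # λ = 1 for balanced OT (ρ = None), otherwise 1 / (1 + ε/ρ)
    return 1.0 if rho is None else 1.0 / (1.0 + eps / rho)

def softmin(eps, C, f_log):
    # softmin(ε, C, f)_i = -ε · log Σ_j exp(f_j - C_ij / ε)
    # C: (N, M) cost matrix, f_log: (M,) log-weights (optionally plus dual / ε)
    return -eps * (f_log.view(1, -1) - C / eps).logsumexp(dim=1)

# toy usage: one dual update b_x = λ · softmin(ε, C_xy, β_log), as in the loop above
x, y = torch.randn(5, 2), torch.randn(7, 2)
C_xy = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1) / 2
beta_log = torch.full((7,), -math.log(7.0))  # log of uniform weights
b_x = dampening(0.1, None) * softmin(0.1, C_xy, beta_log)  # shape (5,)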
Example No. 8
    def backward(ctx, grad_yt):
        # grad_yt: (nt, *ny)
        nparams = ctx.nparams
        pfcn = ctx.pfcn
        param_sep = ctx.param_sep
        yt = ctx.yt
        ts_requires_grad = ctx.ts_requires_grad

        # restore the parameters
        saved_tensors = ctx.saved_tensors
        ts = saved_tensors[0]
        y0 = saved_tensors[1]
        tensor_params = list(saved_tensors[2:])
        allparams = param_sep.reconstruct_params(tensor_params)
        ntensor_params = len(tensor_params)
        params = allparams[:nparams]
        objparams = allparams[nparams:]

        grad_enabled = torch.is_grad_enabled()

        # custom function to evaluate the input `pfcn` based on whether we want
        # to connect the graph or not

        def pfunc2(t, y, tensor_params):
            if not grad_enabled:
                # if graph is not constructed, then use the default tensor_params
                ycopy = y.detach().requires_grad_()  # [yi.detach().requires_grad_() for yi in y]
                tcopy = t.detach().requires_grad_()
                f = pfcn(tcopy, ycopy, *params)
                return f, tcopy, ycopy, tensor_params
            else:
                # if graph is constructed, then use the clone of the tensor params
                # so that infinite loop of backward can be avoided
                tensor_params_copy = [
                    p.clone().requires_grad_() for p in tensor_params
                ]
                ycopy = y.clone().requires_grad_()
                tcopy = t.clone().requires_grad_()
                allparams_copy = param_sep.reconstruct_params(
                    tensor_params_copy)
                params_copy = allparams_copy[:nparams]
                objparams_copy = allparams_copy[nparams:]
                with pfcn.useobjparams(objparams_copy):
                    f = pfcn(tcopy, ycopy, *params_copy)
                return f, tcopy, ycopy, tensor_params_copy

        # slices and indices definitions on the augmented states
        y_index = 0
        dLdy_index = 1
        dLdt_index = 2
        dLdt_slice = slice(dLdt_index, dLdt_index + 1, None)  # [2:3]
        dLdp_slice = (slice(-ntensor_params, None, None)
                      if ntensor_params > 0 else slice(0, 0, None))  # [-ntensor_params:]
        state_size = 3 + ntensor_params
        states = [None for _ in range(state_size)]

        def new_pfunc(t, states, *tensor_params):
            # t: single-element
            y = states[y_index]
            dLdy = -states[dLdy_index]
            with torch.enable_grad():
                f, t2, y2, tensor_params2 = pfunc2(t, y, tensor_params)
            allgradinputs = ([y2] + [t2] + list(tensor_params2))
            allgrads = torch.autograd.grad(
                f,
                inputs=allgradinputs,
                grad_outputs=dLdy,
                retain_graph=True,
                allow_unused=True,
                create_graph=torch.is_grad_enabled())  # list of (*ny)
            allgrads = convert_none_grads_to_zeros(allgrads, allgradinputs)
            outs = (
                f,  # dydt
                *allgrads,
            )
            return outs

        ts_flip = ts.flip(0)
        t_flip_idx = -1
        states[y_index] = yt[t_flip_idx]
        states[dLdy_index] = grad_yt[t_flip_idx]
        states[dLdt_index] = torch.zeros_like(ts[0])
        states[dLdp_slice] = [torch.zeros_like(tp) for tp in tensor_params]
        grad_ts = [None for _ in range(len(ts))] if ts_requires_grad else None

        for i in range(len(ts_flip) - 1):
            if ts_requires_grad:
                feval = pfunc2(ts_flip[i], states[y_index], tensor_params)[0]
                dLdt1 = torch.dot(feval.reshape(-1),
                                  grad_yt[t_flip_idx].reshape(-1))
                states[dLdt_index] -= dLdt1
                grad_ts[t_flip_idx] = dLdt1.reshape(-1)

            t_flip_idx -= 1
            outs = solve_ivp(new_pfunc,
                             ts_flip[i:i + 2],
                             states,
                             tensor_params,
                             fwd_options=ctx.bck_config,
                             bck_options=ctx.bck_config)
            # only take the output for the earliest time
            states = [out[-1] for out in outs]
            states[y_index] = yt[t_flip_idx]
            # gyt is the contribution from the input grad_y
            # gy0 is the propagated gradients from the later time step
            states[dLdy_index] = grad_yt[t_flip_idx] + states[dLdy_index]

        if ts_requires_grad:
            grad_ts[0] = states[dLdt_index].reshape(-1)

        grad_y0 = states[dLdy_index]  # dL/dy0, (*ny)
        if ts_requires_grad:
            grad_ts = torch.cat(grad_ts).reshape(*ts.shape)
        grad_tensor_params = states[dLdp_slice]
        grad_ntensor_params = [
            None for _ in range(len(allparams) - ntensor_params)
        ]
        grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                                   grad_ntensor_params)
        return (None, grad_ts, None, None, None, grad_y0, *grad_params)
Example No. 9
def parallel_apply(modules,
                   inputs,
                   kwargs_tup=None,
                   devices=None):  # pragma: no-cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        fx_called: str = ''
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input, )

                module = module.to(device)

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)
                    fx_called = 'training_step'
                elif module.testing:
                    output = module.test_step(*input, **kwargs)
                    fx_called = 'test_step'
                else:
                    output = module.validation_step(*input, **kwargs)
                    fx_called = 'validation_step'

                if output is None:
                    warn_missing_output(fx_called)

                if output is not None and (module.use_dp or module.use_ddp2):
                    auto_squeeze_dim_zeros(output)
                # ---------------

            with lock:
                results[i] = output
        except Exception as ex:
            with lock:
                results[i] = ex

    # TODO: fix hack (maybe not a hack)
    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, input, kwargs, device))
            for i, (
                module, input, kwargs,
                device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
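For this variant, each replica is expected to expose training_step/validation_step/test_step plus training and testing flags, as in a Lightning-style module. A hedged stub of the minimal interface the worker dispatches on (illustrative, not the actual LightningModule API surface):

import torch
from torch import nn

class StepModule(nn.Module):
    """Hypothetical module exposing the per-phase hooks used by _worker above."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.testing = False   # toggled by the trainer; an assumption in this sketch
        self.use_dp = True     # flags checked before auto_squeeze_dim_zeros
        self.use_ddp2 = False

    def training_step(self, batch):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def validation_step(self, batch):
        with torch.no_grad():
            x, y = batch
            return nn.functional.cross_entropy(self.layer(x), y)

    def test_step(self, batch):
        return self.validation_step(batch)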
Example No. 10
 def __enter__(self):
     self.prev = torch.is_grad_enabled()
     torch._C.set_grad_enabled(False)
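The saved self.prev only makes sense together with the matching __exit__, which restores the previous grad mode. A hedged sketch of the full context manager written against the public API (modern PyTorch implements no_grad in C++):

import torch

class no_grad_sketch:
    """Simplified re-implementation of torch.no_grad, for illustration only."""

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # restore whatever grad mode was active before entering the block
        torch.set_grad_enabled(self.prev)
        return False

prev = torch.is_grad_enabled()
with no_grad_sketch():
    assert not torch.is_grad_enabled()
assert torch.is_grad_enabled() == prev  # previous grad mode restored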
Example No. 11
async def completed(trace_name='',
                    name='',
                    sleep_interval=0.05,
                    streams: List[torch.cuda.Stream] = None):
    """
    Async context manager that waits for work to complete on
    given CUDA streams.

    """
    if not torch.cuda.is_available():
        yield
        return

    stream_before_context_switch = torch.cuda.current_stream()
    if not streams:
        streams = [stream_before_context_switch]
    else:
        streams = [s if s else stream_before_context_switch for s in streams]

    end_events = [
        torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
    ]

    if DEBUG_COMPLETED_TIME:
        start = torch.cuda.Event(enable_timing=True)
        stream_before_context_switch.record_event(start)

        cpu_start = time.monotonic()
    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
    grad_enabled_before = torch.is_grad_enabled()
    try:
        yield
    finally:
        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_end = time.monotonic()
        for i, stream in enumerate(streams):
            event = end_events[i]
            stream.record_event(event)

        grad_enabled_after = torch.is_grad_enabled()

        # observed change of torch.is_grad_enabled() during concurrent run of
        # async_test_bboxes code
        assert (grad_enabled_before == grad_enabled_after
                ), 'Unexpected is_grad_enabled() value change'

        are_done = [e.query() for e in end_events]
        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                     are_done, streams)
        with torch.cuda.stream(stream_before_context_switch):
            while not all(are_done):
                await asyncio.sleep(sleep_interval)
                are_done = [e.query() for e in end_events]
                logger.debug(
                    '%s %s completed: %s streams: %s',
                    trace_name,
                    name,
                    are_done,
                    streams,
                )

        current_stream = torch.cuda.current_stream()
        assert current_stream == stream_before_context_switch

        if DEBUG_COMPLETED_TIME:
            cpu_time = (cpu_end - cpu_start) * 1000
            stream_times_ms = ''
            for i, stream in enumerate(streams):
                elapsed_time = start.elapsed_time(end_events[i])
                stream_times_ms += ' {} {:.2f} ms'.format(stream, elapsed_time)
            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
                        stream_times_ms)
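A hedged usage sketch, assuming completed is decorated with contextlib.asynccontextmanager (as its yield-based body suggests); the trace names and stream setup are illustrative:

import torch

async def run_model(model, batch):
    stream = torch.cuda.Stream()
    # wait for all work queued on `stream` before leaving the block
    async with completed('inference', 'forward', streams=[stream]):
        with torch.cuda.stream(stream), torch.no_grad():
            out = model(batch)
    return out

# asyncio.run(run_model(model, batch))  # requires a CUDA device and a model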
Example No. 12
def is_recording():
    return th.is_grad_enabled()
Example No. 13
def run_one_epoch_aae(model,
                      x,
                      y,
                      num_critic=1,
                      clip_value=0.01,
                      train=True,
                      optimizer=None,
                      batch_size=None,
                      return_loss=True,
                      loss_weight=[1., 1., 1.],
                      loss_fn_cls=nn.CrossEntropyLoss(),
                      loss_fn_reg=nn.MSELoss(),
                      loss_fn_critic=nn.L1Loss(),
                      epoch=0,
                      print_every=1,
                      verbose=True,
                      forward_kwargs={}):
    """Run one epoch for Adversarial AutoEncoder (AAE) model using modified Wasserstein GAN loss
    Note this implementation is based on AAE and Wasserstein GAN but have been modified
    Provide the same interface as run_one_epoch_single_loss

  Args:
    num_critic: how often do we need to update 
    loss_weight: default [1., 1., 1.], corresponding to the losses 
      for classification, reconstruction, and discriminator losses
    all other arguments are the same as run_one_epoch_single_loss
  """
    is_grad_enabled = torch.is_grad_enabled()
    if train:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    if batch_size is None:
        batch_size = len(x)
    if loss_weight is None:
        loss_weight = [1., 1., 1.]
    total_loss = 0
    acc = 0
    loss_batches = []
    for batch_idx, i in enumerate(range(0, len(x), batch_size)):
        x_batch = x[i:i + batch_size]
        y_batch = y[i:i + batch_size]
        cls_score, x_bar, critic_data, critic_prior = model(
            x_batch, **forward_kwargs)
        loss_cls = loss_fn_cls(cls_score, y_batch)
        loss_reg = loss_fn_reg(x_bar, x_batch)  # reconstruction loss
        # This is different from Wasserstein GAN; I used sigmoid_() to control the loss scale
        # I used nn.L1Loss() but Wasserstein GAN does not use the absolute value
        loss_discriminator = loss_fn_critic(critic_prior.sigmoid_(),
                                            critic_data.sigmoid_())
        loss = loss_cls * loss_weight[0] + loss_reg * loss_weight[
            1] + loss_discriminator * loss_weight[2]
        loss_batch = [
            loss_cls.item(),
            loss_reg.item(),
            loss_discriminator.item()
        ]
        loss_batches.append(loss_batch)
        total_loss += loss.item()
        acc_batch = (cls_score.topk(1)[1] == y_batch.unsqueeze(1)
                     ).float().mean().item()
        acc += acc_batch
        if verbose:
            msg = 'Epoch {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x_batch), len(x), 100. * batch_idx /
                (len(x_batch) + (batch_size - 1)) // batch_size,
                loss.item() / len(x_batch))
            msg += f' Acc={acc_batch:.2f}'
            print(msg)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            torch.nn.utils.clip_grad_value_(model.discriminator.parameters(),
                                            clip_value)
        if epoch % num_critic == 0:
            # Note this is different from Wasserstein GAN and AAE
            critic_prior = model.discriminator(
                critic_prior.new_tensor(
                    torch.randn(batch_size, model.latent_dim))).sigmoid_()
            loss_discriminator = -critic_prior.mean()
            if verbose:
                print(
                    f'Epoch {epoch}\tloss_discriminator={loss_discriminator.item()}'
                )
            if train:
                optimizer.zero_grad()
                loss_discriminator.backward()
                optimizer.step()
    torch.set_grad_enabled(is_grad_enabled)

    total_loss /= len(x)  # now total loss is average loss
    acc /= (len(x) + (batch_size - 1)) // batch_size
    if epoch % print_every == 0:
        print('Epoch{} {}: loss={:.3e}, acc={:.2f}'.format(
            epoch, 'Train' if train else ' Test', total_loss, acc))
    if return_loss:
        return total_loss, acc
Example No. 14
def run_one_epoch_single_loss(model, x, y_true, loss_fn=nn.CrossEntropyLoss(), train=True, optimizer=None, 
  batch_size=None, return_loss=True, epoch=0, print_every=10, verbose=True, forward_kwargs={}):
  """Run one epoch, i.e., model(x), but split into batches
  
  Args:
    model: torch.nn.Module
    x: torch.Tensor
    y_true: target torch.Tensor
    loss_fn: loss function
    train: if False, call model.eval() and torch.set_grad_enabled(False) to save time
    optimizer: needed when train is True
    batch_size: if None, batch_size = len(x)
    return_loss: if True, return epoch loss
    epoch: for printing only
    print_every: print epoch_loss if epoch % print_every == 0
    verbose: if True, print batch_loss
    forward_kwargs: default {}, used for model(x, **forward_kwargs), provide additional kwargs for forward pass;
      if it is sample-related, then batch_size should be None, otherwise there can be size mismatch
  """

  is_grad_enabled = torch.is_grad_enabled()
  if train:
    model.train()
    torch.set_grad_enabled(True)
  else:
    model.eval()
    torch.set_grad_enabled(False)
  loss_history = []
  is_classification = isinstance(y_true.cpu(), torch.LongTensor)
  if is_classification:
    acc_history = []
  if batch_size is None:
    batch_size = len(x)
  for i in range(0, len(x), batch_size):
    y_pred = model(x[i:i+batch_size], **forward_kwargs)
    loss = loss_fn(y_pred, y_true[i:i+batch_size])
    loss_history.append(loss.item())
    if is_classification:
      labels_pred = y_pred.topk(1, -1)[1].squeeze() # only calculate top 1 accuracy
      acc = (labels_pred == y_true[i:i+batch_size]).float().mean().item()
      acc_history.append(acc)
    if verbose:
      msg = 'Epoch{} {}/{}: loss={:.2e}'.format(
        epoch, i//batch_size, (len(x)+batch_size-1)//batch_size, loss.item())
      if is_classification:
        msg = msg + f', acc={acc:.2f}'
      print(msg)
    if train:
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
  torch.set_grad_enabled(is_grad_enabled)

  loss_epoch = np.mean(loss_history)
  if is_classification:
    acc_epoch = np.mean(acc_history)
  if epoch % print_every == 0:  
    msg = 'Epoch{} {}: loss={:.2e}'.format(epoch, 'Train' if train else 'Test', np.mean(loss_history))
    if is_classification:
      msg = msg + f', acc={np.mean(acc_history):.2f}'
    print(msg)
  if return_loss:
    if is_classification:
      return loss_epoch, acc_epoch, loss_history, acc_history
    else:
      return loss_epoch, loss_history
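A hedged usage sketch of the interface described in the docstring, with a toy model and synthetic data (the model, optimizer, and shapes below are illustrative, not part of the original module):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(256, 10)
y = torch.randint(0, 3, (256,))  # LongTensor target triggers the classification branch

for epoch in range(5):
    # training pass: the helper enables grad mode and calls model.train()
    run_one_epoch_single_loss(model, x, y, train=True, optimizer=optimizer,
                              batch_size=64, epoch=epoch, verbose=False)
    # evaluation pass: the helper calls model.eval() and disables grad to save time
    run_one_epoch_single_loss(model, x, y, train=False,
                              batch_size=64, epoch=epoch, verbose=False)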
Example No. 15
    def backward(ctx, grad_epf):
        # restore the parameters
        alltensors = ctx.saved_tensors
        nftensorparams = ctx.nftensorparams
        nptensorparams = ctx.nptensorparams
        epf = alltensors[0]
        ftensor_params = alltensors[1:1 + nftensorparams]
        ptensor_params = alltensors[1 + nftensorparams:]
        fptensor_params = alltensors[1:]

        # get the parameters and the object parameters
        nfparams = ctx.nfparams
        npparams = ctx.npparams
        fall_params = ctx.fparam_sep.reconstruct_params(ftensor_params)
        pall_params = ctx.pparam_sep.reconstruct_params(ptensor_params)
        fparams = fall_params[:nfparams]
        fobjparams = fall_params[nfparams:]
        pparams = pall_params[:npparams]
        pobjparams = pall_params[npparams:]

        # get other things from the forward
        ffcn = ctx.ffcn
        log_pfcn = ctx.log_pfcn
        xsamples = ctx.xsamples
        wsamples = ctx.wsamples
        grad_enabled = torch.is_grad_enabled()

        def function_wrap(fcn, param_sep, nparams, x, tensor_params):
            all_params = param_sep.reconstruct_params(tensor_params)
            params = all_params[:nparams]
            objparams = all_params[nparams:]
            with fcn.useobjparams(objparams):
                f = fcn(x, *params)
            return f

        def aug_function(x, *grad_and_fptensor_params):
            local_grad_enabled = torch.is_grad_enabled()
            grad_epf = grad_and_fptensor_params[0]
            epf = grad_and_fptensor_params[1]
            fptensor_params = grad_and_fptensor_params[2:]
            ftensor_params = fptensor_params[:nftensorparams]
            ptensor_params = fptensor_params[nftensorparams:]
            with torch.enable_grad():
                # if graph is constructed, then fptensor_params is a clone of
                # fptensor_params from outside, therefore, it needs to be put
                # in the pure function's objects (that's what function_wrap does)
                if grad_enabled:
                    fout = function_wrap(ffcn, ctx.fparam_sep, nfparams, x, ftensor_params)
                    pout = function_wrap(log_pfcn, ctx.pparam_sep, npparams, x, ptensor_params)
                # if graph is not constructed, then fptensor_params in this
                # function *is* fptensor_params in the outside, so we can
                # just use fparams and pparams from the outside
                else:
                    fout = ffcn(x, *fparams)
                    pout = log_pfcn(x, *pparams)
            # derivative of fparams
            dLdthetaf = []
            if len(ftensor_params) > 0:
                dLdthetaf = torch.autograd.grad(fout, ftensor_params,
                                                grad_outputs=grad_epf,
                                                retain_graph=True,
                                                create_graph=local_grad_enabled)
            # derivative of pparams
            dLdthetap = []
            if len(ptensor_params) > 0:
                dLdef = torch.dot((fout - epf).reshape(-1), grad_epf.reshape(-1))
                dLdthetap = torch.autograd.grad(pout, ptensor_params,
                                                grad_outputs=dLdef.reshape(pout.shape),
                                                retain_graph=True,
                                                create_graph=local_grad_enabled)
            # combine the states needed for backward
            outs = (
                *dLdthetaf,
                *dLdthetap,
            )
            return outs

        if grad_enabled:
            fptensor_params_copy = [y.clone().requires_grad_() for y in fptensor_params]
        else:
            fptensor_params_copy = fptensor_params

        aug_epfs = _mcquad(aug_function, log_pfcn,
                           x0=xsamples[0],  # unused because xsamples is set
                           xsamples=xsamples,
                           wsamples=wsamples,
                           fparams=(grad_epf, epf, *fptensor_params_copy),
                           pparams=pparams,
                           method=ctx.method,
                           bck_options=ctx.bck_config,
                           **ctx.bck_config)
        dLdthetaf = aug_epfs[:nftensorparams]
        dLdthetap = aug_epfs[nftensorparams:]

        # combine the gradient for all fparams
        dLdfnontensor = [None for _ in range(ctx.fparam_sep.nnontensors())]
        dLdpnontensor = [None for _ in range(ctx.pparam_sep.nnontensors())]
        dLdtf = ctx.fparam_sep.reconstruct_params(dLdthetaf, dLdfnontensor)
        dLdtp = ctx.pparam_sep.reconstruct_params(dLdthetap, dLdpnontensor)
        return (None, None, None, None, None, None, None, None, None, None, None,
                *dLdtf, *dLdtp)
Example No. 16
def run_one_epoch_vae(model,
                      x,
                      y,
                      num_cls=2,
                      train=True,
                      optimizer=None,
                      batch_size=None,
                      return_loss=True,
                      epoch=0,
                      print_every=10,
                      verbose=True,
                      forward_kwargs={}):
    """Run one epoch for VAE model
  Almost the same as run_one_epoch_single_loss

  Args:
    num_cls: if greater than 1, perform classification
    all other arguments are the same as run_one_epoch_single_loss
  """
    is_grad_enabled = torch.is_grad_enabled()
    if train:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    if batch_size is None:
        batch_size = len(x)
    total_loss = 0
    acc = 0
    for batch_idx, i in enumerate(range(0, len(x), batch_size)):
        x_batch = x[i:i + batch_size]
        y_batch = y[i:i + batch_size]
        y_pred = model(x_batch, **forward_kwargs)
        if num_cls > 1:
            loss = loss_vae(*y_pred, x_batch, y_batch)
            cls_score = y_pred[0]
            acc_batch = (cls_score.topk(1)[1] == y_batch.unsqueeze(1)
                         ).float().mean().item()
            acc += acc_batch
        else:
            loss = loss_vae(None, *y_pred, x_batch, y_batch)
        total_loss += loss.item()
        if verbose:
            msg = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x_batch), len(x), 100. * batch_idx /
                (len(x_batch) + (batch_size - 1)) // batch_size,
                loss.item() / len(x_batch))
            if num_cls > 1:
                msg += f' Acc={acc_batch:.2f}'
            print(msg)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.set_grad_enabled(is_grad_enabled)

    total_loss /= len(x)  # now total loss is average loss
    acc /= (len(x) + (batch_size - 1)) // batch_size
    if epoch % print_every == 0:
        print('Epoch{} {}: loss={:.3e}, acc={:.2f}'.format(
            epoch, 'Train' if train else ' Test', total_loss, acc))
    if return_loss:
        return total_loss, acc
Example No. 17
 def _save_input(self, module, input):
     """Hook for saving layer input"""
     if torch.is_grad_enabled() and self.steps % self.fac_update_freq == 0:
         self.m_a[module] = input[0].data
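This method is a forward pre-hook in the K-FAC style: it is registered on each layer so that inputs are captured only when gradients are enabled (i.e., during training forward passes) and only every fac_update_freq steps. A hedged sketch of how such a hook is typically wired up; the class is illustrative, only the attribute names mirror the snippet:

import torch
from torch import nn

class InputRecorder:
    """Hypothetical helper that records layer inputs for factor updates."""

    def __init__(self, model: nn.Module, fac_update_freq: int = 10):
        self.steps = 0
        self.fac_update_freq = fac_update_freq
        self.m_a = {}
        for module in model.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module.register_forward_pre_hook(self._save_input)

    def _save_input(self, module, input):
        # skipped under torch.no_grad() (e.g. evaluation) and between factor updates
        if torch.is_grad_enabled() and self.steps % self.fac_update_freq == 0:
            self.m_a[module] = input[0].data

recorder = InputRecorder(nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)))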
Example No. 18
 def is_training(self):
     return self._original_module.training and torch.is_grad_enabled()
Example No. 19
 def _compute_forward_factor(self, module, input):
     if torch.is_grad_enabled() and self.steps % self.fac_update_freq == 0:
         self.m_a[module] = input[0].data
         self._update_module_A(module)
Example No. 20
    def returned_function(*args, **kwargs):
        global compile_cache
        nonlocal cached_res

        # Separate out static args if static_argnums is present
        tensor_args = args
        static_args = []
        # TODO - move the hashing part of static_args to C++.
        static_args_hashed = []
        if static_argnums is not None:
            (
                tensor_args,
                static_args,
                static_args_hashed,
            ) = filter_tensor_and_static_args(args, static_argnums)

        # Now flatten the tensor args
        if HAS_TREE:
            flat_tensor_args = tree.flatten((tensor_args, kwargs))
        else:
            flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))

        # Check if the fn is already compiled
        num_tensor_args = len(flat_tensor_args)
        flat_args_for_cache = flat_tensor_args + static_args_hashed
        cached_res = compile_cache.at(
            fn_id,
            fw_compiler_id,
            bw_compiler_id,
            num_tensor_args,
            hasher_type,
            *flat_args_for_cache,
        )

        # Compile the function and save it in the cache
        if cached_res is None:
            # Save the args_spec for flat_tensor_args to unflatten while tracing
            _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs))
            out_spec = PytreeThunk()

            def flat_fn(*flat_tensor_args):
                # The input are flattened tensor args. Prepare the args in the
                # order that original function expects. Add static args as well.
                # They will appear as tensor constants in the traced graph.
                nonlocal out_spec, static_args

                tensor_args, kwargs = pytree.tree_unflatten(
                    flat_tensor_args, tensor_args_spec
                )
                if static_argnums is None:
                    args = tensor_args
                else:
                    args = rearrange(tensor_args, static_args, static_argnums)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                for i in flat_out:
                    is_known_type = False
                    for j in KNOWN_TYPES:
                        if isinstance(i, j):
                            is_known_type = True
                            break
                    if not is_known_type:
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with"
                        )
                out_spec.set(spec)
                return flat_out

            compiled_fn = create_aot_autograd_function(
                flat_fn,
                fw_compiler,
                bw_compiler,
                partition_fn,
                decompositions,
                grad_state=torch.is_grad_enabled(),
            ).apply
            cached_res = (compiled_fn, out_spec)

            # Save the compiled_fn in the cache
            compile_cache.insert(
                fn_id,
                fw_compiler_id,
                bw_compiler_id,
                num_tensor_args,
                hasher_type,
                cached_res,
                *flat_args_for_cache,
            )

        cached_fn, out_spec = cached_res
        out = cached_fn(*flat_tensor_args)
        return out_spec.unflatten(out)
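This closure is the kind of wrapper returned by functorch's ahead-of-time compilation entry points; note that torch.is_grad_enabled() is captured once as grad_state, so the compiled function is specialized on the grad mode at compile time. A hedged usage sketch in the style of functorch.compile.aot_function, with trivial compiler callbacks:

import torch
from functorch.compile import aot_function

def print_and_return(fx_graph, example_inputs):
    # trivial "compiler": print the traced graph, then run it unchanged
    print(fx_graph.code)
    return fx_graph

def fn(x, y):
    return (x * y).sin().sum()

compiled_fn = aot_function(fn, fw_compiler=print_and_return, bw_compiler=print_and_return)

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
compiled_fn(x, y).backward()  # both forward and backward graphs pass through the compilers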
Example No. 21
def run_one_epoch_multiloss(model, x, targets, heads=[0,1], loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], 
  loss_weights=[1,0], other_loss_fns=[], other_loss_weights=[], return_loss=True, batch_size=None, 
  train=True, optimizer=None, epoch=0, print_every=10, verbose=True):
  """Calculate a multi-head model with multiple losses including losses from the outputs and targets (head losses) 
  and regularizers on model parameters (non-head losses).
  
  Args:
    model: A model with multihead; for example, an AutoEncoder classifier, returns classification scores 
      (or regression target) and decoder output (reconstruction of input)
    x: input
    targets: a list of targets associated with multi-head output specified by argument heads; 
      e.g., for an autoencoder with two heads, targets = [y_labels, x]
      targets need not pair one-to-one with all head outputs;
      use the heads argument to specify which heads are paired with targets;
      The elements of targets can be None, too; 
      the length of targets must be compatible with that of loss_weights, loss_fns, and heads
    heads: the index for the heads paired with targets for calculating losses; 
      if None, set heads = list(range(len(targets)))
    loss_fns: a list of loss functions for the corresponding head
    loss_weights: the (non-negative) weights for the above head-losses;
      heads, loss_fns, and loss_weights are closely related to each other; need to handle it carefully
    other_loss_fns: a list of loss functions as regularizers on model parameters
    other_loss_weights: the corresponding weights for other_loss_fns
    return_loss: default True, return all losses
    batch_size: default None; split data into batches
    train: default True; if False, call model.eval() and torch.set_grad_enabled(False) to save time
    optimizer: when train is True, optimizer must be given; default None, do not use for evaluation
    epoch: for print only
    print_every: print epoch losses if epoch % print_every == 0
    verbose: if True, print losses for each batch
  """

  is_grad_enabled = torch.is_grad_enabled()
  if train:
    model.train()
    torch.set_grad_enabled(True)
  else:
    model.eval()
    torch.set_grad_enabled(False)
  if batch_size is None:
    batch_size = len(x)
  
  if len(targets) < len(loss_weights):
    # Some losses do not require targets (using 'implicit' targets in the objective)
    # Append None so that targets can be indexed for every loss later
    targets = targets + [None]*(len(loss_weights) - len(targets))
  is_classification = [] # record the indices of targets that are for classification
  has_unequal_size = [] # record the indices of targets that have a different size from the input
  is_none = [] # record the indices of targets that are None
  for j, y_true in enumerate(targets):
    if y_true is not None:
      if len(y_true) == len(x):
        if isinstance(y_true.cpu(), torch.LongTensor):
          # if targets[j] is LongTensor, treat it as classification task
          is_classification.append(j)
      else:
        # Here is a bug: I use len(y_true)!=len(x) to decide y_true (target) is not 1-1 paired with input instances;
        # however, sometimes even if len(y_true)==len(x), y_true still may not be 1-1 paired with input instances;
        # since it rarely happens, I have not taken care of this bug
        has_unequal_size.append(j)
    else:
      is_none.append(j)
  loss_history = []
  if len(is_classification) > 0:
    acc_history = []

  if heads is None: # If heads is not given, assume the targets are paired with the model outputs in order
    heads = list(range(len(targets)))
  for i in range(0, len(x), batch_size):
    y_pred = model(x[i:i+batch_size])
    loss_batch = []
    for j, w in enumerate(loss_weights):
      if w>0: # only execute when w>0
        if j in is_none:
          loss_j = loss_fns[j](y_pred[heads[j]]) * w
        elif j in has_unequal_size:
          loss_j = loss_fns[j](y_pred[heads[j]], targets[j]) * w # targets[j] is the same for all batches
        else:
          loss_j = loss_fns[j](y_pred[heads[j]], targets[j][i:i+batch_size]) * w
        loss_batch.append(loss_j)
    for j, w in enumerate(other_loss_weights):
      if w>0:
        # The implicit 'target' is encoded in the loss function itself
        # todo: in addition to argument model, make loss_fns handle other 'dynamic' arguments as well
        loss_j = other_loss_fns[j](model) * w 
        loss_batch.append(loss_j)
    loss = sum(loss_batch)
    loss_batch = [v.item() for v in loss_batch]
    loss_history.append(loss_batch)
    # Calculate accuracy
    if len(is_classification) > 0:
      acc_batch = []
      for k, j in enumerate(is_classification):
        labels_pred = y_pred[heads[j]].topk(1, -1)[1].squeeze()
        acc = (labels_pred == targets[j][i:i+batch_size]).float().mean().item()
        acc_batch.append(acc)
      acc_history.append(acc_batch)
    if verbose:
      msg = 'Epoch{} {}/{}: loss:{}'.format(epoch, i//batch_size, (len(x)+batch_size-1)//batch_size, 
        ', '.join(map(lambda x: f'{x:.2e}', loss_batch)))
      if len(is_classification) > 0:
        msg = msg + ', acc={}'.format(', '.join(map(lambda x: f'{x:.2f}', acc_batch)))
      print(msg)
    if train:
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
  torch.set_grad_enabled(is_grad_enabled)

  loss_epoch = np.mean(loss_history, axis=0)
  if len(is_classification) > 0:
    acc_epoch = np.mean(acc_history, axis=0)
  if epoch % print_every == 0:
    msg = 'Epoch{} {}: loss:{}'.format(epoch, 'Train' if train else 'Test', 
      ', '.join(map(lambda x: f'{x:.2e}', loss_epoch)))
    if len(is_classification) > 0:
      msg = msg + ', acc={}'.format(', '.join(map(lambda x: f'{x:.2f}', acc_epoch)))
    print(msg)
    
  if return_loss:
    if len(is_classification) > 0:
      return loss_epoch, acc_epoch, loss_history, acc_history
    else:
      return loss_epoch, loss_history
Example No. 22
def _criterion_parallel_apply(modules,
                              inputs,
                              targets,
                              kwargs_tup=None,
                              devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input, )
                if not isinstance(target, (list, tuple)):
                    target = (target, )
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [
            threading.Thread(
                target=_worker,
                args=(i, module, input, target, kwargs, device),
            ) for i, (module, input, target, kwargs, device) in enumerate(
                zip(modules, inputs, targets, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
Example No. 23
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens,
                                                      token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        fc_results = []

        if return_all_hiddens:
            encoder_states.append(x)

        # nested tensor and BT enable
        layer = self.layers[0]
        BT_flag = False
        NT_flag = False
        # torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613
        # internal format is '1.13.0a0+fb'
        # external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable
        BT_version = False
        NT_version = False
        if "fb" in torch.__version__:
            BT_version = True
            NT_version = True
        else:
            if "+" in torch.__version__:
                torch_version = torch.__version__.split("+")[0]
            else:
                torch_version = torch.__version__

            torch_version = torch_version.split(".")
            int_version = (int(torch_version[0]) * 1000 +
                           int(torch_version[1]) * 10 + int(torch_version[2]))
            if len(torch_version) == 3:
                if int_version >= 1120:
                    BT_version = True
                if int_version >= 1131:
                    NT_version = True
            elif len(torch_version) == 4:
                if int_version >= 1130:
                    BT_version = True
                # Consider _nested_tensor_from_mask_left_aligned is landed after "20220613"
                if int_version >= 1131 or (int_version == 1130 and
                                           torch_version[3][3:] >= "20220613"):
                    NT_version = True

        if (BT_version and x.dim() == 3 and layer.load_to_BT
                and not layer.return_fc and layer.can_use_fastpath
                and not layer.training and not layer.ever_training
                and not layer.cfg_checkpoint_activations):
            # Batch first can not be justified but needs user to make sure
            x = x.transpose(0, 1)
            # Check mask conditions for nested tensor
            if NT_version:
                if (encoder_padding_mask is not None
                        and torch._nested_tensor_from_mask_left_aligned(
                            x, encoder_padding_mask.logical_not())):
                    if not torch.is_grad_enabled() or not x.requires_grad:
                        x = torch._nested_tensor_from_mask(
                            x, encoder_padding_mask.logical_not())
                        NT_flag = True
            BT_flag = True

        # encoder layers
        if NT_flag:
            processing_mask = None
        else:
            processing_mask = encoder_padding_mask
        encoder_padding_mask_out = processing_mask if has_pads else None
        for layer in self.layers:
            lr = layer(x, encoder_padding_mask=encoder_padding_mask_out)

            if isinstance(lr, tuple) and len(lr) == 2:
                x, fc_result = lr
            else:
                x = lr
                fc_result = None

            if return_all_hiddens and not torch.jit.is_scripting():
                assert encoder_states is not None
                encoder_states.append(x)
                fc_results.append(fc_result)

        # change back to non-nested and Batch second
        if NT_flag:
            x = x.to_padded_tensor(0.0)

        if NT_flag or BT_flag:
            x = x.transpose(0, 1)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The PyTorch Mobile lite interpreter does not support returning a NamedTuple from
        # `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (src_tokens.ne(self.padding_idx).sum(
            dim=1, dtype=torch.int32).reshape(-1, 1).contiguous())
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "fc_results": fc_results,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }
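A minimal standalone sketch of the torch-version gating used above, assuming only the version-string formats described in the comments (this helper is not part of the original module):

import torch

def bt_nt_flags(version: str = torch.__version__):
    # BT needs torch >= 1.12.0; NT needs torch >= 1.13.1, or the 1.13.0
    # nightly dated 20220613 or later.
    if "fb" in version:                              # internal builds such as '1.13.0a0+fb'
        return True, True
    parts = version.split("+")[0].split(".")         # drop local suffixes such as '+cu102'
    int_version = int(parts[0]) * 1000 + int(parts[1]) * 10 + int(parts[2])
    if len(parts) == 3:                              # stable release, e.g. '1.12.0'
        return int_version >= 1120, int_version >= 1131
    # nightly, e.g. '1.13.0.dev20220613'
    nt = int_version >= 1131 or (int_version == 1130 and parts[3][3:] >= "20220613")
    return int_version >= 1130, nt

print(bt_nt_flags("1.12.0"))                         # (True, False)
print(bt_nt_flags("1.13.0.dev20220613"))             # (True, True)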
Example No. 24
0
 def _is_training(self):
     return self._flattened_module.training and torch.is_grad_enabled()
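This helper treats the model as "training" only when both the module flag and the global autograd state agree. A small usage sketch; the Wrapper class below is illustrative and only mirrors how the attribute is used above:

import torch
import torch.nn as nn

class Wrapper:
    # Hypothetical stand-in for the flattened-module wrapper referenced above.
    def __init__(self, module: nn.Module):
        self._flattened_module = module

    def _is_training(self):
        return self._flattened_module.training and torch.is_grad_enabled()

w = Wrapper(nn.Linear(4, 2))
print(w._is_training())          # True: module defaults to train mode, grad is enabled
with torch.no_grad():
    print(w._is_training())      # False: grad is globally disabled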
Example No. 25
0
    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src_mask is not None:
            why_not_sparsity_fast_path = "src_mask is not supported for fastpath"
        elif src.is_nested and src_key_padding_mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask is not supported with NestedTensor input for fastpath"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all(
                (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad
                                                 for x in tensor_args):
                why_not_sparsity_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                if not torch.jit.is_scripting():
                    return torch._transformer_encoder_layer_fwd(
                        src,
                        self.self_attn.embed_dim,
                        self.self_attn.num_heads,
                        self.self_attn.in_proj_weight,
                        self.self_attn.in_proj_bias,
                        self.self_attn.out_proj.weight,
                        self.self_attn.out_proj.bias,
                        self.activation_relu_or_gelu == 2,
                        self.norm_first,
                        self.norm1.eps,
                        self.norm1.weight,
                        self.norm1.bias,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.linear1.weight,
                        self.linear1.bias,
                        self.linear2.weight,
                        self.linear2.bias,
                        # TODO: if src_mask and src_key_padding_mask merge to single 4-dim mask
                        src_mask
                        if src_mask is not None else src_key_padding_mask,
                        1 if src_key_padding_mask is not None else
                        0 if src_mask is not None else None,
                    )
                elif src_mask is None:
                    # hack until 9/26/2022 for TS jit compatibility window
                    return torch._transformer_encoder_layer_fwd(
                        src,
                        self.self_attn.embed_dim,
                        self.self_attn.num_heads,
                        self.self_attn.in_proj_weight,
                        self.self_attn.in_proj_bias,
                        self.self_attn.out_proj.weight,
                        self.self_attn.out_proj.bias,
                        self.activation_relu_or_gelu == 2,
                        self.norm_first,
                        self.norm1.eps,
                        self.norm1.weight,
                        self.norm1.bias,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.linear1.weight,
                        self.linear1.bias,
                        self.linear2.weight,
                        self.linear2.bias,
                        src_mask
                        if src_mask is not None else src_key_padding_mask,
                    )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask,
                                   src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x +
                           self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x
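The fast path above only fires for batched 3-D input when the layer is in eval mode, batch_first is set, and none of the listed tensors require grad. A small usage sketch of those conditions with the public nn.TransformerEncoderLayer API (sizes are arbitrary):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
layer.eval()                       # "training is enabled" would otherwise block the fast path
src = torch.randn(2, 10, 64)       # (batch, seq, embed) because batch_first=True

with torch.no_grad():              # avoids the requires_grad rejection
    out = layer(src)
print(out.shape)                   # torch.Size([2, 10, 64])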
Example No. 26
0
    def forward(self,
                src: Tensor,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src
        convert_to_nested = False
        first_layer = self.layers[0]
        src_key_padding_mask_for_layers = src_key_padding_mask
        why_not_sparsity_fast_path = ''
        str_first_layer = "self.layers[0]"
        if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
            why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
        elif first_layer.norm_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.norm_first was True"
        elif first_layer.training:
            why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
        elif not first_layer.self_attn.batch_first:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn.batch_first was not True"
        elif not first_layer.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = f"{str_first_layer}.self_attn._qkv_same_embed_dim was not True"
        elif not first_layer.activation_relu_or_gelu:
            why_not_sparsity_fast_path = f"{str_first_layer}.activation_relu_or_gelu was not True"
        elif not (first_layer.norm1.eps == first_layer.norm2.eps):
            why_not_sparsity_fast_path = f"{str_first_layer}.norm1.eps was not equal to {str_first_layer}.norm2.eps"
        elif not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif not self.enable_nested_tensor:
            why_not_sparsity_fast_path = "enable_nested_tensor was not True"
        elif src_key_padding_mask is None:
            why_not_sparsity_fast_path = "src_key_padding_mask was None"
        elif (((not hasattr(self, "mask_check")) or self.mask_check)
              and not torch._nested_tensor_from_mask_left_aligned(
                  src, src_key_padding_mask.logical_not())):
            why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
        elif output.is_nested:
            why_not_sparsity_fast_path = "NestedTensor input is not supported"
        elif mask is not None:
            why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
        elif first_layer.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                first_layer.self_attn.in_proj_weight,
                first_layer.self_attn.in_proj_bias,
                first_layer.self_attn.out_proj.weight,
                first_layer.self_attn.out_proj.bias,
                first_layer.norm1.weight,
                first_layer.norm1.bias,
                first_layer.norm2.weight,
                first_layer.norm2.bias,
                first_layer.linear1.weight,
                first_layer.linear1.bias,
                first_layer.linear2.weight,
                first_layer.linear2.bias,
            )

            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not (src.is_cuda or 'cpu' in str(src.device)):
                why_not_sparsity_fast_path = "src is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad
                                                 for x in tensor_args):
                why_not_sparsity_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad")

            if (not why_not_sparsity_fast_path) and (src_key_padding_mask
                                                     is not None):
                convert_to_nested = True
                # simplify on or after on 8/16/2022 to unconditionally call with mask_check=False
                # we have established that either (1) the mask is OK with the check above,
                # or (2) that we don't need a mask check with mask_check=False in the init
                if not torch.jit.is_scripting():
                    output = torch._nested_tensor_from_mask(
                        output,
                        src_key_padding_mask.logical_not(),
                        mask_check=False)
                else:
                    # When scripting, make a simpler call until the FC bar passes on 8/16/2022
                    output = torch._nested_tensor_from_mask(
                        output, src_key_padding_mask.logical_not())
                src_key_padding_mask_for_layers = None

        for mod in self.layers:
            output = mod(output,
                         src_mask=mask,
                         src_key_padding_mask=src_key_padding_mask_for_layers)

        if convert_to_nested:
            output = output.to_padded_tensor(0.)

        if self.norm is not None:
            output = self.norm(output)

        return output
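When every check passes and a padding mask is supplied, the encoder converts the input to a nested tensor so padded positions are skipped by the attention kernels. A brief usage sketch with the public API, assuming a PyTorch build recent enough to expose enable_nested_tensor (sizes are arbitrary):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True)
encoder.eval()

src = torch.randn(2, 10, 64)
# True marks padded positions; real tokens must be left-aligned to pass the mask check.
src_key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
src_key_padding_mask[0, 6:] = True

with torch.no_grad():
    out = encoder(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)                   # torch.Size([2, 10, 64])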
Example No. 27
0
 def __init__(self, mode):
     self.prev = torch.is_grad_enabled()
     torch._C.set_grad_enabled(mode)
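These two lines are the entry half of a grad-mode guard: remember the previous global state, then switch it. A minimal sketch of the full context-manager pattern they belong to, built only on the public torch.is_grad_enabled / torch.set_grad_enabled API:

import torch

class grad_mode:
    # Sketch of a set-and-restore guard around the global grad flag.
    def __init__(self, mode: bool):
        self.mode = mode

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(self.mode)

    def __exit__(self, *exc):
        torch.set_grad_enabled(self.prev)    # restore whatever was active before

with grad_mode(False):
    y = torch.ones(3, requires_grad=True) * 2
print(y.requires_grad)                       # False: y was built while grad was disabled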
Example No. 29
0
 def post_forward(module: _BatchNorm, input: Tensor,
                  result: Tensor) -> None:
     if torch.is_grad_enabled():
         return
     module.track_running_stats = module._track_running_stats_backup
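This post-forward hook only restores track_running_stats when grad is disabled, which implies a matching pre-forward hook disabled it for no-grad passes. A sketch of the paired pattern; the pre_forward counterpart and the backup attribute name are assumptions mirroring the snippet, not the original source:

import torch
from torch.nn.modules.batchnorm import _BatchNorm

def pre_forward(module: _BatchNorm, input) -> None:
    # Assumed counterpart: back up and disable running-stat tracking when grad is off,
    # so no-grad forward passes do not update the statistics.
    if torch.is_grad_enabled():
        return
    module._track_running_stats_backup = module.track_running_stats
    module.track_running_stats = False

def post_forward(module: _BatchNorm, input, result) -> None:
    if torch.is_grad_enabled():
        return
    module.track_running_stats = module._track_running_stats_backup

bn = torch.nn.BatchNorm2d(8)
bn.register_forward_pre_hook(pre_forward)
bn.register_forward_hook(post_forward)
with torch.no_grad():
    bn(torch.randn(2, 8, 4, 4))
print(bn.track_running_stats)                # True: restored by post_forward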
Example No. 30
0
    def forward(self, source=0.0, power=True, detector=None):
        """Calculate the network's response to an applied source.

        Args:
            source (Tensor): The source tensor to calculate the response for.
            power (bool): Return detected power, otherwise return complex signal.
            detector (callable): Custom detector function to use to detect the signal.

        Returns:
            Tensor: The detected tensor with shape (t, w, s, b) or with
                shape (2, t, w, s, b) in the case of power=False (in that case, dimension
                0 contains the stacked real and imaginary part of the result)

        Note:
             The source tensor should have shape (t, w, s, b), with
               * t: the number of timesteps in the simulation environment.
               * w: the number of wavelengths in the simulation environment.
               * s: the number of sources in the network.
               * b: the number of unrelated input waveforms (the batch size).

             Alternatively, two of such tensors can be stacked together in dimension 0
             to represent the real and imaginary part of a complex tensor,
             resulting in a tensor of shape (2, t, w, s, b).

             Any lower dimensional tensor should have named dimensions to remove any
             ambiguity in the broadcasting rules. Dimensions of a tensor can be named
             with the ``.rename`` method of the PyTorch Tensor class.
             Accepted dimension names are 'c', 't', 'w', 's', 'b'.

        """

        # reinitialize the network if the current environment does not correspond
        # to the previous environment or if gradients are enabled
        if self.env is not current_environment() or torch.is_grad_enabled():
            self.initialize()

        source = self._handle_source(source)

        num_batches = source.shape[-1]

        detected = torch.zeros(
            (
                self.env.num_t,
                self.env.num_wl,
                self.num_detectors,
                num_batches,
            ),
            device=self.device,
        )
        if not power:
            detected = torch.stack([detected, detected], 0)

        ## Get new simulation buffer
        buffer = self._simulation_buffer(num_batches)

        # solve
        for i, t in enumerate(self.env.t):
            det, buffer = self.step(t, source[:, i], buffer)

            if power:
                detected[i] = torch.sum(det**2, 0)
            else:
                detected[:, i] = det

        if detector is not None:
            detected = detector(detected)

        return detected
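The docstring above relies on a (t, w, s, b) shape convention, with an optional leading dimension of size 2 for real/imaginary parts, and on named dimensions for lower-rank sources. A tiny standalone sketch of those conventions only; no photontorch network is built here and the sizes are made up:

import torch

num_t, num_wl, num_sources, batch = 100, 3, 2, 4

# Full-rank real source: (t, w, s, b)
source = torch.randn(num_t, num_wl, num_sources, batch)

# Complex source: real and imaginary parts stacked in dim 0 -> (2, t, w, s, b)
complex_source = torch.stack([source, torch.zeros_like(source)], dim=0)

# Lower-dimensional source with a named time dimension, as the docstring suggests:
time_only = torch.randn(num_t).rename("t")

print(source.shape, complex_source.shape, time_only.names)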
Example No. 31
0
 def replicate(self, module, device_ids):
     return replicate(module, device_ids, not torch.is_grad_enabled())
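Here the third argument of replicate is its detach flag, so replicas built while grad is disabled carry no autograd history. A minimal usage sketch of the same call (assumes a CUDA device is available; single-device replication just to show the signature):

import torch
from torch.nn.parallel import replicate

if torch.cuda.is_available():
    module = torch.nn.Linear(4, 2).cuda()
    with torch.no_grad():
        # grad is disabled, so the replicas are created detached
        replicas = replicate(module, [0], not torch.is_grad_enabled())
        print(len(replicas))                 # 1 replica, on device 0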
Example No. 32
0
def inspect(model,
            X,
            frame_rate=4,
            insp_keys={},
            batch_size=None,
            to_numpy=True,
            verbose=False):
    """
    Get the responses of the specified layers in the model as numpy arrays. If the model
    is on cpu, operations are performed on cpu; put the model on gpu if you want them
    performed on gpu.

    model - torch Module or torch gpu Module
    X - ndarray (L,T,C,H,W)
    frame_rate - int
        step size over the T dimension when unrolling the recurrent state
    insp_keys - set of str
        names of the layers whose activations should be collected
    batch_size - int
        if given, batching is performed over the L dimension of X
    to_numpy - bool
        if true, activations are returned as ndarrays (the hooks always move them to cpu);
        otherwise they are returned as torch cpu tensors
    verbose - bool
        if true, display a progress bar over batches

    returns dict of np arrays or torch cpu tensors
    """
    layer_outs = dict()
    handles = []
    if "all" in insp_keys:
        for i in range(len(model.sequential)):
            key = "sequential." + str(i)
            hook = get_hook(layer_outs, key, to_numpy=to_numpy, to_cpu=True)
            handle = model.sequential[i].register_forward_hook(hook)
            handles.append(handle)
    else:
        for key, mod in model.named_modules():
            if key in insp_keys:
                hook = get_hook(layer_outs,
                                key,
                                to_numpy=to_numpy,
                                to_cpu=True)
                handle = mod.register_forward_hook(hook)
                handles.append(handle)
    X = torch.FloatTensor(X)

    # prev_grad_state is used to ensure we do not mess with an outer "with torch.no_grad():"
    prev_grad_state = torch.is_grad_enabled()
    if to_numpy:
        # Turns off all gradient calculations. When returning numpy arrays, the computation
        # graph is inaccessible, as such we do not need to calculate it.
        torch.set_grad_enabled(False)

    if batch_size is None:
        if next(model.parameters()).is_cuda:
            X = X.to(DEVICE)
        x_len = len(X)
        h = model.init_h(x_len)  # Recurrent state vector (B, Z)
        for t in range(0, X.shape[1], frame_rate):
            x = X[:, t].to(DEVICE)  # (B, D, H, W)
            h = model(x, h)
        preds = model.classify(h).squeeze()
        if to_numpy:
            layer_outs['outputs'] = preds.detach().cpu().numpy()
        else:
            layer_outs['outputs'] = preds.cpu()
    else:
        use_cuda = next(model.parameters()).is_cuda
        batched_outs = {key: [] for key in insp_keys}
        outputs = []
        rnge = range(0, len(X), batch_size)
        if verbose:
            rnge = tqdm(rnge)
        for batch in rnge:
            X_ = X[batch:batch + batch_size]
            if use_cuda:
                X_ = X_.to(DEVICE)
            x_len = len(X_)
            h = model.init_h(x_len)  # Recurrent state vector (B, Z)
            for t in range(0, X.shape[1], frame_rate):
                x = X_[:, t].to(DEVICE)  # (B, D, H, W)
                h = model(x, h)
            preds = model.classify(h).squeeze()
            if to_numpy:
                # move to cpu before converting; .numpy() fails on CUDA tensors
                preds = preds.detach().cpu().numpy()
            outputs.append(preds)
            for k in layer_outs.keys():
                batched_outs[k].append(layer_outs[k])
        batched_outs['outputs'] = outputs
        if to_numpy:
            layer_outs = {
                k: np.concatenate(v, axis=0)
                for k, v in batched_outs.items()
            }
        else:
            layer_outs = {
                k: torch.cat(v, dim=0)
                for k, v in batched_outs.items()
            }

    # If we turned off the grad state, this will turn it back on. Otherwise leaves it the same.
    torch.set_grad_enabled(prev_grad_state)

    # This for loop ensures we do not create a memory leak when using hooks
    for i in range(len(handles)):
        handles[i].remove()
    del handles

    return layer_outs
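The snippet depends on a get_hook helper that is not shown. A minimal sketch of what such a forward-hook factory could look like; its behaviour is an assumption inferred from how it is called above (get_hook(layer_outs, key, to_numpy=..., to_cpu=True)):

def get_hook(layer_outs, key, to_numpy=True, to_cpu=False):
    # Hypothetical helper: record each module's output under `key`.
    def hook(module, inp, out):
        out = out.detach()
        if to_numpy:
            layer_outs[key] = out.cpu().numpy()
        elif to_cpu:
            layer_outs[key] = out.cpu()
        else:
            layer_outs[key] = out
    return hook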
Example No. 33
0
 def forward(self, activations, input_lengths, labels, label_lengths):
     return CTCFunction.apply(activations, input_lengths, labels,
                              label_lengths, self.reduce, self.size_average,
                              self.length_average, self.blank_label,
                              torch.is_grad_enabled())
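The example forwards torch.is_grad_enabled() into a custom autograd Function, presumably so the kernel can skip saving backward state during inference. A minimal sketch of that pattern with a trivial Function; ScaleFunction below is illustrative and is not the CTC implementation:

import torch

class ScaleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, grad_needed):
        if grad_needed:                                  # only save state when training
            ctx.save_for_backward(torch.tensor(scale))
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        (scale,) = ctx.saved_tensors
        return grad_out * scale, None, None              # no grads for scale / grad_needed

x = torch.randn(5, requires_grad=True)
y = ScaleFunction.apply(x, 2.0, torch.is_grad_enabled())
y.sum().backward()
print(x.grad)                                            # tensor of 2s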
Example No. 34
0
 def forward(self, x):
     test.assertFalse(torch.is_grad_enabled())
     return x
Example No. 35
0
    def apply(u,
              x,
              weight_c,
              bias,
              init,
              activation_type,
              d,
              bidirectional,
              has_skip_term,
              scale_x,
              mask_c=None,
              mask_pad=None):
        """
        An SRU is a recurrent neural network cell comprised of 5 equations, described
        in "Simple Recurrent Units for Highly Parallelizable Recurrence."

        The first 3 of these equations each require a matrix-multiply component,
        i.e. the input vector x_t dotted with a weight matrix W_i, where i is in
        {0, 1, 2}.

        As each weight matrix W is dotted with the same input x_t, we can fuse these
        computations into a single matrix-multiply, i.e. `x_t <dot> stack([W_0, W_1, W_2])`.
        We call the result of this computation `U`.

        sru_compute_cpu() accepts 'u' and 'x' (along with a tensor of biases,
        an initial memory cell `c0`, and an optional dropout mask) and computes
        equations (3) - (7). It returns a tensor containing all `t` hidden states
        (where `t` is the number of elements in our input sequence) and the final
        memory cell `c_T`.
        """

        bidir = 2 if bidirectional else 1
        length = x.size(0) if x.dim() == 3 else 1
        batch = x.size(-2)
        k = u.size(-1) // d // bidir

        is_custom = len(weight_c.size()) > 1

        sru_cpu_impl = _lazy_load_cpu_kernel()
        if (sru_cpu_impl is not None) and (sru_cpu_impl != False):
            if not torch.is_grad_enabled():
                assert mask_c is None
                cpu_forward = sru_cpu_impl.cpu_bi_forward if bidirectional else \
                              sru_cpu_impl.cpu_forward
                mask_pad_ = torch.FloatTensor(
                ) if mask_pad is None else mask_pad.float()
                return cpu_forward(
                    u.contiguous(), x.contiguous(), weight_c.contiguous(),
                    bias, init, mask_pad_, length, batch, d, k,
                    activation_type, has_skip_term,
                    scale_x.item() if scale_x is not None else 1.0, is_custom)
            else:
                warnings.warn(
                    "Running SRU on CPU with grad_enabled=True. Are you sure?")
        else:
            warnings.warn("C++ kernel for SRU CPU inference was not loaded. "
                          "Use Python version instead.")

        mask_pad_ = mask_pad.view(
            length, batch, 1).float() if mask_pad is not None else mask_pad
        u = u.contiguous().view(length, batch, bidir, d, k)

        if is_custom:
            weight_c = weight_c.view(length, batch, bidir, d, 2)
            forget_wc = weight_c[..., 0]
            reset_wc = weight_c[..., 1]
        else:
            forget_wc, reset_wc = weight_c.view(2, bidir, d)

        forget_bias, reset_bias = bias.view(2, bidir, d)

        if not has_skip_term:
            x_prime = None
        elif k == 3:
            x_prime = x.view(length, batch, bidir, d)
            x_prime = x_prime * scale_x if scale_x is not None else x_prime
        else:
            x_prime = u[..., 3]

        h = x.new_zeros(length, batch, bidir, d)

        if init is None:
            c_init = x.new_zeros(size=(batch, bidir, d))
        else:
            c_init = init.view(batch, bidir, d)

        c_final = []
        for di in range(bidir):
            if di == 0:
                time_seq = range(length)
            else:
                time_seq = range(length - 1, -1, -1)

            mask_c_ = 1 if mask_c is None else mask_c.view(batch, bidir,
                                                           d)[:, di, :]
            c_prev = c_init[:, di, :]
            fb, rb = forget_bias[di], reset_bias[di]
            if is_custom:
                fw = forget_wc[:, :, di, :].chunk(length)
                rw = reset_wc[:, :, di, :].chunk(length)
            else:
                fw = forget_wc[di].expand(batch, d)
                rw = reset_wc[di].expand(batch, d)
            u0 = u[:, :, di, :, 0].chunk(length)
            u1 = (u[:, :, di, :, 1] + fb).chunk(length)
            u2 = (u[:, :, di, :, 2] + rb).chunk(length)
            if x_prime is not None:
                xp = x_prime[:, :, di, :].chunk(length)

            for t in time_seq:
                if is_custom:
                    forget_t = (u1[t] + c_prev * fw[t]).sigmoid()
                    reset_t = (u2[t] + c_prev * rw[t]).sigmoid()
                else:
                    forget_t = (u1[t] + c_prev * fw).sigmoid()
                    reset_t = (u2[t] + c_prev * rw).sigmoid()
                c_t = u0[t] + (c_prev - u0[t]) * forget_t
                if mask_pad_ is not None:
                    c_t = c_t * (1 - mask_pad_[t]) + c_prev * mask_pad_[t]
                c_prev = c_t

                if activation_type == 0:
                    g_c_t = c_t
                elif activation_type == 1:
                    g_c_t = c_t.tanh()
                else:
                    raise ValueError(
                        'Activation type must be 0 or 1, not {}'.format(
                            activation_type))

                if x_prime is not None:
                    h_t = xp[t] + (g_c_t - xp[t]) * mask_c_ * reset_t
                else:
                    h_t = g_c_t * mask_c_ * reset_t
                if mask_pad_ is not None:
                    h_t = h_t * (1 - mask_pad_[t])
                h[t, :, di, :] = h_t

            c_final.append(c_t.view(batch, d))
        return h.view(length, batch, -1), torch.stack(c_final,
                                                      dim=1).view(batch, -1)
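The docstring above describes fusing the three input projections W_0, W_1, W_2 into a single matrix multiply that produces U. A small standalone sketch of that fusion; the dimensions are arbitrary and this is not the SRU kernel itself:

import torch

d_in, d_out, batch = 8, 16, 4
W = [torch.randn(d_in, d_out) for _ in range(3)]
x_t = torch.randn(batch, d_in)

# Three separate projections ...
separate = torch.cat([x_t @ W[i] for i in range(3)], dim=-1)
# ... equal one projection against the stacked weight matrix ("U" in the docstring).
fused = x_t @ torch.cat(W, dim=-1)

print(torch.allclose(separate, fused, atol=1e-5))        # True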
Example No. 36
0
 def __init__(self):
     self.prev = torch.is_grad_enabled()
Example No. 37
0
 def _is_training(self):
     return self.training and torch.is_grad_enabled()