Example #1
 def decorate_fwd(*args, **kwargs):
     if cast_inputs is None:
         args[0]._fwd_used_autocast = torch.is_autocast_enabled()
         return fwd(*args, **kwargs)
     else:
         autocast_context = torch.is_autocast_enabled()
         args[0]._fwd_used_autocast = False
         if autocast_context:
             with autocast(enabled=False):
                 return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
         else:
             return fwd(*args, **kwargs)
Example #2
    def forward(self,
                graph: DGLGraph,
                node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'],
                                   max_degree=self.max_degree,
                                   compute_gradients=False,
                                   use_pad_trick=self.tensor_cores
                                   and not self.low_memory,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(
            basis,
            self.max_degree,
            use_pad_trick=self.tensor_cores and not self.low_memory,
            fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)

        edge_feats = get_populated_edge_features(graph.edata['rel_pos'],
                                                 edge_feats)

        node_feats = self.graph_modules(node_feats,
                                        edge_feats,
                                        graph=graph,
                                        basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats
Example #3
def criterion_parallel_apply(modules, inputs, targets, devices, kwargs_tup=None):

	if kwargs_tup is None:
		kwargs_tup = ({},) * len(modules)

	lock = Lock()
	results = {}
	grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

	def _worker(i, module, input, target, kwargs, device):

		if not isinstance(input, (list, tuple)):
			input = (input,)
		if not isinstance(target, (list, tuple)):
			target = (target,)
		with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled):
			output = module(*(input + target), **kwargs)
		with lock:
			results[i] = output

	threads = [Thread(target=_worker, args=(i, module, input, target, kwargs, device)) for i, (module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

	for thread in threads:
		thread.start()
	for thread in threads:
		thread.join()

	outputs = []
	for i in range(len(inputs)):
		output = results[i]
		outputs.append(output)
	return outputs
Example #4
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(
                    *args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs
Example #5
    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)
Example #6
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.has_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.device_states = get_device_states(*args)

        def replace_tensors(arg):
            if torch.is_tensor(arg):
                return TensorPlaceholder(None)
            return arg

        tensor_inputs = []
        ctx.inputs = recursive_apply(replace_tensors, args)
        for input_arg, all_arg in zip(recursive_walk(ctx.inputs),
                                      recursive_walk(args)):
            if isinstance(input_arg, TensorPlaceholder):
                assert torch.is_tensor(all_arg)
                assert input_arg.tensor_index is None
                tensor_inputs.append(all_arg)
                input_arg.tensor_index = len(tensor_inputs) - 1
        ctx.save_for_backward(*tensor_inputs)
        with torch.no_grad():
            outputs = run_function(*args)
        for output in recursive_walk(outputs):
            if torch.is_tensor(output):
                output.requires_grad_(True)
        return outputs
Example #7
def batch_norm(g, input, weight, bias, running_mean, running_var, training,
               momentum, eps, cudnn_enabled):

    if torch.is_autocast_enabled() and \
            not args_have_same_dtype([input, weight, bias, running_mean, running_var]) and \
            sym_help._export_onnx_opset_version < 15:
        return sym_help._onnx_opset_unsupported_detailed(
            "BatchNormalization", 14, 15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.")

    sym_help.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = sym_help._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var)
    out = g.op("BatchNormalization",
               input,
               weight,
               bias,
               running_mean,
               running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,
               training_mode_i=0 if not training else 1,
               outputs=1 if not training else 3)
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res
Example #8
 def fn(x):
     if torch.is_autocast_enabled():
         y = True
     else:
         y = False
     with torch.cuda.amp.autocast(enabled=True):
         z = x.relu()
     return y, z
Example #9
 def forward(self, *args, **kwargs):
     if torch.is_autocast_enabled():
         with torch.cuda.amp.autocast(enabled=False):
             result = self.module(*tensor_to(args, torch.float32),
                                  **tensor_to(kwargs, torch.float32))
     else:
         result = self.module(*args, **kwargs)
     return result
Example #10
 def fix_inf(self, x):
     if not self.detect_inf:
         return x
     # to enable fp16 training
     is_fp16 = x.dtype == torch.float16 or torch.is_autocast_enabled()
     if is_fp16 and torch.isinf(x).any():
         clamp_value = torch.finfo(torch.float16).max - 1000
         x = torch.clamp(x, min=-clamp_value, max=clamp_value)
     return x
Example #11
    def on_batch_start(self, runner: IRunner) -> None:
        """On batch start event

        Args:
            runner: current runner
        """
        self.prev_autocast_state = torch.is_autocast_enabled()
        torch.set_autocast_enabled(True)
        torch.autocast_increment_nesting()
Example #12
def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs
Example #13
 def __enter__(self):
     if self.device == 'cpu':
         self.prev = torch.is_autocast_cpu_enabled()
         self.prev_fastdtype = torch.get_autocast_cpu_dtype()
         torch.set_autocast_cpu_enabled(self._enabled)
         torch.set_autocast_cpu_dtype(self.fast_dtype)
         torch.autocast_increment_nesting()
     else:
         self.prev = torch.is_autocast_enabled()
         self.prev_fastdtype = torch.get_autocast_gpu_dtype()
         torch.set_autocast_gpu_dtype(self.fast_dtype)
         torch.set_autocast_enabled(self._enabled)
         torch.autocast_increment_nesting()
Example #14
    def _validate_inputs(self, a, b):
        if a.device != b.device:
            raise ValueError(
                f"Inputs must be on the same device; got {a.device} for tensor A "
                f"and {b.device} for tensor B")
        if not a.is_cuda:
            raise ValueError("Only GPU devices are supported for now")

        # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
        if torch.is_autocast_enabled():
            a, b = a.half(), b.half()
        elif a.dtype != b.dtype:
            raise ValueError(
                f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B"
            )

        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
        if mode != 'sdd':
            # One input is sparse
            dense, dense_name, sparse, sparse_name = (
                a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
            dense_inner = dense.shape[self.dense_inner_dim]
            if dense_inner != self.dense_inner_size:
                raise ValueError(
                    f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
                    f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")

            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
                raise ValueError(
                    f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
                    f"{sparse_name}, got {sparse.shape}")

        def add_extra_dims(x):
            # Add extra leading singleton dimensions if needed
            dims_needed = 4 - x.ndim
            if dims_needed > 0:
                singletons = [1] * dims_needed
                x = x.view(*singletons, *x.shape)
            elif dims_needed < 0:
                raise ValueError(
                    "Tensors with more than 4 dimensions are not currently supported"
                )

            return x

        # Pad shapes with leading singleton dimensions
        a = add_extra_dims(a)
        b = add_extra_dims(b)

        return a, b
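A minimal check of the comment in _validate_inputs above (a sketch, assuming a CUDA device is available): under autocast, torch.matmul runs in float16, which is why both inputs are cast with .half() before the dtype comparison.

import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")
# Inside the autocast region, matmul picks the autocast dtype (float16 by default).
with torch.cuda.amp.autocast():
    print(torch.matmul(a, b).dtype)  # torch.float16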
Example #15
def causal_linear_attention(q, k, v, eps = 1e-6):
    from fast_transformers.causal_product import CausalDotProduct
    autocast_enabled = torch.is_autocast_enabled()
    is_half = isinstance(q, torch.cuda.HalfTensor)
    assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available'
    cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False)

    causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply

    k_cumsum = k.cumsum(dim=-2) + eps
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))

    with cuda_context():
        if autocast_enabled:
            q, k, v = map(lambda t: t.float(), (q, k, v))

        out = causal_dot_product_fn(q, k, v)

    out = torch.einsum('...nd,...n->...nd', out, D_inv)
    return out
Example #16
 def forward(ctx, run_function, preserve_rng_state, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.preserve_rng_state = preserve_rng_state
     ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
     if preserve_rng_state:
         ctx.fwd_cpu_state = torch.get_rng_state()
         # Don't eagerly initialize the cuda context by accident.
         # (If the user intends that the context is initialized later, within their
         # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
         # we have no way to anticipate this will happen before we run the function.)
         ctx.had_cuda_in_fwd = False
         if torch.cuda._initialized:
             ctx.had_cuda_in_fwd = True
             ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(
                 *args)
     ctx.save_for_backward(*args)
     with torch.no_grad():
         outputs = run_function(*args)
     return outputs
Example #17
def attention(query, key, value, mask=None, key_padding_mask=None, dropout=None, fixed=False, att_bias=None):
    "Compute 'Scaled Dot Product Attention'"
    # query/key/value shapes: (batch, heads, len, feats)
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # scores shape: (batch, heads, queries, keys)
    if att_bias is not None:
        scores += att_bias
    if mask is not None:
        if torch.is_autocast_enabled():
            scores = scores.masked_fill(mask == 0, -1e4)
        else:
            scores = scores.masked_fill(mask == 0, -1e9)
    if key_padding_mask is not None:
        scores = scores.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float("-inf"),
        )

    p_attn = F.softmax(scores, dim=-1)

    if mask is not None and fixed:
        # Needed when a node has no neighbors: softmax spreads attention evenly,
        # so zero out those positions instead of averaging over all nodes.
        p_attn = p_attn.masked_fill(mask == 0, 0)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
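The two mask fill values above differ because -1e9 lies outside the float16 range (roughly +/-65504) and overflows once autocast runs the scores in half precision, while -1e4 is representable and still effectively -inf after softmax. A minimal illustration:

import torch

# -1e9 overflows to -inf in float16; -1e4 is representable.
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))  # tensor(-10000., dtype=torch.float16)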
Example #18
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    if not torch.is_autocast_enabled():
        return dtype or torch.float
    else:
        return torch.get_autocast_gpu_dtype()
Example #19
    def _run_autocast_outofplace(self,
                                 op,
                                 args,
                                 run_as_type,
                                 out_type=None,
                                 module=torch,
                                 add_kwargs=None):
        # helper to cast args
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_enabled())
        with autocast():
            self.assertTrue(torch.is_autocast_enabled())

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        "autocast for torch.{} produced {}, should produce {}".
                        format(op, output.dtype, out_type))

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        "autocast for torch.{} produced {}, should produce torch.{}"
                        .format(op, output_method.dtype, out_type))

            self.assertTrue((output is not None) or (
                output_method is not None
            ), "{} not found as an attribute on either Tensor or the requested module {}"
                            .format(op, module))

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison,
                    "torch.{0} result did not match Tensor.{0} result".format(
                        op))

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with autocast(enabled=False):
                self.assertFalse(torch.is_autocast_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(*cast(args, run_as_type),
                                                  **add_kwargs)
                else:
                    control = getattr(args[0].to(run_as_type),
                                      op)(*cast(args[1:], run_as_type),
                                          **add_kwargs)
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(
                    comparison,
                    "torch.{} result did not match control".format(op))
            self.assertTrue(torch.is_autocast_enabled())
        self.assertFalse(torch.is_autocast_enabled())
Example #20
 def __enter__(self):
     self.prev = torch.is_autocast_enabled()
     torch.set_autocast_enabled(self._enabled)
     torch.autocast_increment_nesting()
Example #21
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
    """Checkpointining without re-entrant autograd
    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Stash and restore
            the RNG state during each checkpoint; set to ``False`` to omit this.
        *args: Arguments to pass in to the given ``function``.
    """
    had_autocast_in_fwd = torch.is_autocast_enabled()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    storage: List[Union[torch.Tensor, None]] = []
    counter = 0

    def pack(x):
        nonlocal counter
        counter += 1
        # TODO(varal7): Instead of returning indices, we could return tensor metadata (such as
        # size, device, ...) to catch certain cases of non-deterministic behavior in the forward
        return counter - 1

    def unpack(x):
        if len(storage) == 0:

            def inner_pack(inner):
                storage.append(inner)
                return None

            def inner_unpack(packed):
                raise RuntimeError(
                    "You are calling backwards on a tensor that is never exposed. Please open an issue."
                )

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices,
                                       enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                with torch.enable_grad(), torch.cuda.amp.autocast(
                        had_autocast_in_fwd):
                    with torch.autograd.graph.saved_tensors_hooks(
                            inner_pack, inner_unpack):
                        _unused = function(*args)

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args)
        if torch.cuda._initialized and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
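A minimal usage sketch for the helper above (the module and input are hypothetical, not part of the original source): the forward runs once under the pack/unpack hooks, and the stashed autocast and RNG state is replayed when backward recomputes the segment.

import torch
import torch.nn as nn

layer = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# preserve_rng_state=True: the CPU (and, if initialized, CUDA) RNG state is saved for the recompute.
out = _checkpoint_without_reentrant(layer, True, x)
out.sum().backward()  # the forward is re-run here to rebuild the saved activations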
Example #22
def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
Example #23
def is_autocast_enabled() -> bool:
    """Similar to torch.is_autocast_enabled, but compatible with torch 1.5.1"""
    if hasattr(torch, "is_autocast_enabled"):
        return torch.is_autocast_enabled()
    return False
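A minimal sketch of how such a shim might be used (maybe_upcast is a hypothetical helper, not from the original source): callers can branch on the autocast state without breaking on torch releases that predate torch.is_autocast_enabled.

def maybe_upcast(x):
    # Run numerically sensitive math in float32 when autocast is active; no-op otherwise.
    return x.float() if is_autocast_enabled() else x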
Example #24
 def forward(self, x):
     """Forward pass"""
     if torch.is_autocast_enabled():
         return self.mod.forward(x.to(torch.float).to(x.dtype))
     return self.mod.forward(x)
Example #25
 def _assert_autocast_enabled(self):
     if self.trainer.precision_plugin.device == "cpu":
         assert torch.is_autocast_cpu_enabled()
     else:
         assert torch.is_autocast_enabled()
Example #26
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({}, ) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(
    ), torch.is_autocast_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input, )
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, input, kwargs, device))
            for i, (
                module, input, kwargs,
                device) in enumerate(zip(modules, inputs, kwargs_tup, devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
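A minimal usage sketch (the replicas, inputs, and device ids are hypothetical; assumes two CUDA devices are available): each worker thread re-enters the caller's grad mode and autocast state, so the outputs come back in the autocast dtype.

import torch
import torch.nn as nn

replicas = [nn.Linear(8, 8).to("cuda:0"), nn.Linear(8, 8).to("cuda:1")]
inputs = [torch.randn(4, 8, device="cuda:0"), torch.randn(4, 8, device="cuda:1")]

with torch.cuda.amp.autocast():
    outputs = parallel_apply(replicas, inputs, devices=[0, 1])
print([o.dtype for o in outputs])  # float16 outputs, since autocast propagated into the threads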
Example #27
 def _step(self, batch, batch_idx):
     assert torch.is_autocast_enabled()
     output = self(batch)
     assert output.dtype == torch.float16
     loss = self.loss(batch, output)
     return loss
Example #28
 def check_autocast(forward_input):
     assert precision != 16 or torch.is_autocast_enabled()
     return forward_input
Example #29
 def predict(self, batch, batch_idx, dataloader_idx=None):
     assert torch.is_autocast_enabled()
     output = self(batch)
     assert output.dtype == torch.float16
     return output
Example #30
def to_dtype(x, tensor=None, dtype=None):
    if not torch.is_autocast_enabled():
        dt = dtype if dtype is not None else tensor.dtype
        if x.dtype != dt:
            x = x.type(dt)
    return x