Example #1
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(
                    *args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        # return outputs

        #
        # Lie to torch that we have no None items, to avoid the assert
        #
        result = []
        for o in outputs:
            if o is None:
                o = torch.zeros(0)
            result.append(o)

        return tuple(result)
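
For context, a minimal sketch of how the state captured above is typically replayed in the matching backward, using the public helpers from torch.utils.checkpoint; the function name and the recompute step are illustrative, not taken from the example:

    import torch
    from torch.utils.checkpoint import set_device_states

    def restore_rng_and_recompute(ctx, *saved_inputs):
        # Replay the generators inside a forked RNG scope so recomputation
        # does not disturb the caller's current random state.
        # had_cuda_in_fwd is only set when preserve_rng_state was True.
        had_cuda = getattr(ctx, 'had_cuda_in_fwd', False)
        devices = ctx.fwd_gpu_devices if had_cuda else []
        with torch.random.fork_rng(devices=devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if had_cuda:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            with torch.enable_grad():
                return ctx.run_function(*saved_inputs)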
Example #2
    def _init_feedforward_seed(self, *args):
        """
                    This function sets a new seed for the
                    feed forward layer to make dropout deterministic
                    for both forward calls: 1 normal forward
                    call and 1 forward call in backward
                    to recalculate activations.
                """

        self.ffn_cpu_state = torch.get_rng_state()
        self.ffn_gpu_devices, self.ffn_gpu_states = get_device_states(*args)
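
A hedged sketch of the counterpart step: before the second forward call in backward, the recorded state would be pushed back into the generators so dropout draws the same mask again. The method name below is illustrative, not taken from the example:

    import torch
    from torch.utils.checkpoint import set_device_states

    def _restore_feedforward_seed(self):
        # Restore the CPU generator and the per-device CUDA generators that
        # were recorded above, so the recomputed forward pass is
        # bit-identical to the original one.
        torch.set_rng_state(self.ffn_cpu_state)
        set_device_states(self.ffn_gpu_devices, self.ffn_gpu_states)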
Example #3
    def forward(ctx, rev_block_stack, preserve_rng_state, *inputs):
        '''
        :param ctx:                 context, like self
        :param rev_block_stack:     module with multiple reversible blocks stacked, such as RevSequential
        :param preserve_rng_state:  whether to save the random number state so the random numbers can be
                                    reproduced; only torch random numbers are covered, numpy is not included
        :param inputs:              multiple tensors; at least one tensor must have requires_grad=True,
                                    otherwise this function will not compute backpropagation,
                                    which is a limitation of torch.autograd.Function
        :return:
        '''
        # Warn if none of the input tensors has requires_grad=True
        check_backward_validity(inputs)
        # Make sure the input is a list of modules and supports the inverse operation
        assert isinstance(rev_block_stack, nn.ModuleList)
        assert hasattr(rev_block_stack, 'inverse') and callable(
            rev_block_stack.inverse)

        ctx.rev_block_stack = rev_block_stack
        ctx.preserve_rng_state = preserve_rng_state

        # rng state save
        # Note that the state should be saved and restored layer by layer
        if preserve_rng_state:
            ctx.status_stack = []
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.cuda_status_stack = []

            # Since the execution order of the modules is reversed during the backward pass,
            # each sub-module needs to save its random number state separately.
            outputs = inputs
            for m in ctx.rev_block_stack:
                fwd_cpu_state = torch.get_rng_state()
                ctx.status_stack.append(fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    fwd_gpu_devices, fwd_gpu_states = get_device_states(
                        *outputs)
                    ctx.cuda_status_stack.append(
                        [fwd_gpu_devices, fwd_gpu_states])
                # Run under torch.no_grad because intermediate variables are not saved
                with torch.no_grad():
                    outputs = m(*outputs)

        else:
            # If you don't need to save the random number state, you can run the entire module directly to get the output
            with torch.no_grad():
                outputs = rev_block_stack(*inputs)

        # Save output for backward function
        ctx.save_for_backward(*outputs)
        return outputs
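
A minimal sketch of the reverse-order replay this per-layer bookkeeping enables in backward; the real backward would also reconstruct each block's input via its inverse, which is omitted here, and the function name is illustrative:

    import torch
    from torch.utils.checkpoint import set_device_states

    def replay_layer_states(ctx):
        # Walk the blocks in reverse and restore the state that was captured
        # just before each block ran in the forward pass.
        for idx in reversed(range(len(ctx.rev_block_stack))):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.status_stack[idx])
                if ctx.had_cuda_in_fwd:
                    devices, states = ctx.cuda_status_stack[idx]
                    set_device_states(devices, states)
            # ... invert block idx, then recompute its forward under enable_grad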
Example #4
    def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes,
                preserve_rng_state, num_inputs, *inputs_and_weights):
        # store in context
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.keep_input = keep_input
        ctx.weights = inputs_and_weights[num_inputs:]
        ctx.num_bwd_passes = num_bwd_passes
        ctx.preserve_rng_state = preserve_rng_state
        ctx.num_inputs = num_inputs
        inputs = inputs_and_weights[:num_inputs]

        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(
                    *inputs)

        ctx.input_requires_grad = [element.requires_grad for element in inputs]

        with torch.no_grad():
            # Makes a detached copy which shares the storage
            x = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    x.append(element.detach())
                else:
                    x.append(element)
            outputs = ctx.fn(*x)

        if not isinstance(outputs, tuple):
            outputs = (outputs, )

        # Detaches y in-place (in-between computations can now be discarded)
        detached_outputs = tuple([element.detach_() for element in outputs])

        # clear memory from inputs
        # only clear memory of node features
        if not ctx.keep_input:
            # PyTorch 1.0+ way to clear storage for node features
            inputs[0].storage().resize_(0)

        # store these tensor nodes for backward pass
        ctx.inputs = [inputs] * num_bwd_passes
        ctx.outputs = [detached_outputs] * num_bwd_passes

        return detached_outputs
Example #5
File: checkpoint.py  Project: pytorch/xla
  def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    if preserve_rng_state:
      ctx.fwd_cpu_state = torch.get_rng_state()
      # Don't eagerly initialize the cuda context by accident.
      # (If the user intends that the context is initialized later, within their
      # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
      # we have no way to anticipate this will happen before we run the function.)
      ctx.had_cuda_in_fwd = False
      if torch.cuda._initialized:
        ctx.had_cuda_in_fwd = True
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    tensor_outputs = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
      else:
        ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
      outputs = run_function(*args)

    return outputs
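
A hedged sketch of how the stashed autocast kwargs are usually re-applied when run_function is replayed in backward; this mirrors the stock torch.utils.checkpoint recipe, but the exact autocast entry points vary across PyTorch versions, and the function name is illustrative:

    import torch

    def recompute_with_autocast(ctx, *detached_inputs):
        # Re-enter the same autocast configuration that was active during the
        # original forward, then rerun the user function with grad enabled.
        with torch.enable_grad(), \
             torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
             torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
            return ctx.run_function(*detached_inputs)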
Example #6
 def forward(ctx, structure, block_size, body_fn, *state):
     ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*state)
     with torch.enable_grad():
         ctx.devices = [s.device for s in state]
         cpu_state = nest.map_structure(
             lambda x: x.to('cpu', non_blocking=True), state)
     ctx.save_for_backward(*cpu_state)
     ctx.structure = structure
     ctx.block_size = block_size
     ctx.body_fn = body_fn
     state = nest.pack_sequence_as(ctx.structure, state)
     ctx.fwd_cpu_state = torch.get_rng_state()
     with torch.no_grad():
         for _ in range(block_size):
             state = body_fn(state)
     state = nest.flatten(state)
     return tuple(state)
Example #7
 def forward(ctx, run_function, preserve_rng_state, *args):
     check_backward_validity(args)
     ctx.run_function = run_function
     ctx.preserve_rng_state = preserve_rng_state
     ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
     if preserve_rng_state:
         ctx.fwd_cpu_state = torch.get_rng_state()
         # Don't eagerly initialize the cuda context by accident.
         # (If the user intends that the context is initialized later, within their
         # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
         # we have no way to anticipate this will happen before we run the function.)
         ctx.had_cuda_in_fwd = False
         if torch.cuda._initialized:
             ctx.had_cuda_in_fwd = True
             ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
     ctx.save_for_backward(*args)
     with torch.no_grad():
         outputs = run_function(*args)
     return outputs
Example #8
    def forward(ctx, run_function, callback, preserve_rng_state, *args):
        print('Fwd', args)
        chkpt.check_backward_validity(args)
        ctx.callback = callback
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = chkpt.get_device_states(
                    *args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs
Example #9
 def record_rng(self, *args):
     self.cpu_state = torch.get_rng_state()
     if torch.cuda._initialized:
         self.cuda_in_fwd = True
         self.gpu_devices, self.gpu_states = get_device_states(*args)
Example #10
 def __init__(self, *tensors):
     self.fwd_cpu_state = torch.get_rng_state()
     self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
Example #11
 def record_rng(self, *args):
     # rng refers to "random number generator"; the state is recorded so the same random draws can be reproduced later.
     self.cpu_state = torch.get_rng_state()
     if torch.cuda._initialized:
         self.cuda_in_fwd = True
         self.gpu_devices, self.gpu_states = get_device_states(*args)
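
A hedged sketch of the counterpart to record_rng: when the block is run a second time during backward, the recorded generators can be replayed inside a forked RNG scope. The method name is illustrative and assumes the attributes set above:

    import torch
    from torch.utils.checkpoint import set_device_states

    def replay_rng(self, fn, *args):
        # cuda_in_fwd is only set when CUDA was initialized in forward, so
        # fall back to False if the attribute is absent.
        cuda_in_fwd = getattr(self, 'cuda_in_fwd', False)
        devices = self.gpu_devices if cuda_in_fwd else []
        with torch.random.fork_rng(devices=devices):
            torch.set_rng_state(self.cpu_state)
            if cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return fn(*args)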