Example #1
def _affine_bounding_box_xyxy(
    bounding_box: torch.Tensor,
    image_size: Tuple[int, int],
    angle: float,
    translate: Optional[List[float]] = None,
    scale: Optional[float] = None,
    shear: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> torch.Tensor:
    dtype = bounding_box.dtype if torch.is_floating_point(
        bounding_box) else torch.float32
    device = bounding_box.device

    if translate is None:
        translate = [0.0, 0.0]

    if scale is None:
        scale = 1.0

    if shear is None:
        shear = [0.0, 0.0]

    if center is None:
        height, width = image_size
        center_f = [width * 0.5, height * 0.5]
    else:
        center_f = [float(c) for c in center]

    translate_f = [float(t) for t in translate]
    affine_matrix = torch.tensor(
        _get_inverse_affine_matrix(center_f,
                                   angle,
                                   translate_f,
                                   scale,
                                   shear,
                                   inverted=False),
        dtype=dtype,
        device=device,
    ).view(2, 3)
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
    points = torch.cat(
        [points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, affine_matrix.T)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.view(-1, 4, 2)
    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
        height, width = image_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, 1.0 * height, 1.0],
                [1.0 * width, 1.0 * height, 1.0],
                [1.0 * width, 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, affine_matrix.T)
        tr, _ = torch.min(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
        out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]

    return out_bboxes
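A minimal usage sketch for the helper above. It assumes the typing imports (`Tuple`, `Optional`, `List`) and torchvision's private `_get_inverse_affine_matrix` (from `torchvision.transforms.functional`) are in scope; the box coordinates are illustrative only.

import torch
from typing import List, Optional, Tuple
# private torchvision helper; its exact signature can differ between releases
from torchvision.transforms.functional import _get_inverse_affine_matrix

# Two XYXY boxes on a 100 x 200 (height x width) image, rotated by 30 degrees.
boxes = torch.tensor([[10.0, 10.0, 40.0, 30.0],
                      [50.0, 60.0, 120.0, 90.0]])
rotated = _affine_bounding_box_xyxy(boxes, image_size=(100, 200), angle=30.0)
print(rotated.shape)  # torch.Size([2, 4]): axis-aligned boxes enclosing the rotated corners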
Example #2
def perspective_bounding_box(
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
    perspective_coeffs: List[float],
) -> torch.Tensor:

    if len(perspective_coeffs) != 8:
        raise ValueError(
            "Argument perspective_coeffs should have 8 float values")

    original_shape = bounding_box.shape
    bounding_box = convert_bounding_box_format(
        bounding_box,
        old_format=format,
        new_format=features.BoundingBoxFormat.XYXY).view(-1, 4)

    dtype = bounding_box.dtype if torch.is_floating_point(
        bounding_box) else torch.float32
    device = bounding_box.device

    # perspective_coeffs are computed as endpoint -> start point
    # We have to invert perspective_coeffs for bboxes:
    # (x, y) - end point and (x_out, y_out) - start point
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # and we would like to get:
    # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and compute inv_coeffs in terms of coeffs

    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[
        1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}")

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7])
        / denom,
        (-perspective_coeffs[1] +
         perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] -
         perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] +
         perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6])
        / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] +
         perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] +
         perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] +
         perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]],
         [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor([[inv_coeffs[6], inv_coeffs[7], 1.0],
                           [inv_coeffs[6], inv_coeffs[7], 1.0]],
                          dtype=dtype,
                          device=device)

    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
    points = torch.cat(
        [points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)

    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points / denom_points

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.view(-1, 4, 2)
    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    # out_bboxes should be of shape [N boxes, 4]

    return convert_bounding_box_format(
        out_bboxes,
        old_format=features.BoundingBoxFormat.XYXY,
        new_format=format,
        copy=False).view(original_shape)
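The coefficient inversion above is the closed-form inverse of the 3x3 perspective matrix, renormalized so its bottom-right entry is 1 (`denom` is exactly that entry of the adjugate). A standalone sanity check of the same idea, with arbitrary illustrative coefficients:

import torch

coeffs = [1.2, 0.1, 5.0, -0.05, 0.9, 3.0, 0.001, 0.002]  # arbitrary invertible example
H = torch.tensor([coeffs[0:3], coeffs[3:6], [coeffs[6], coeffs[7], 1.0]], dtype=torch.float64)
H_inv = torch.linalg.inv(H)
H_inv = H_inv / H_inv[2, 2]  # renormalize so the bottom-right entry is 1, matching the inv_coeffs layout

p = H @ torch.tensor([10.0, 20.0, 1.0], dtype=torch.float64)  # forward mapping of the point (10, 20)
q = H_inv @ torch.tensor([(p[0] / p[2]).item(), (p[1] / p[2]).item(), 1.0], dtype=torch.float64)
print((q[0] / q[2]).item(), (q[1] / q[2]).item())  # ~ (10.0, 20.0): the inverse maps the point back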
Example #3
 def is_floating_point(self) -> bool:
     value = self.storage.value()
     return torch.is_floating_point(value) if value is not None else True
Example #4
def to_float(b):
    "Recursively map lists of int tensors in `b ` to float."
    return apply(lambda x: x.float() if torch.is_floating_point(x) else x, b)
Example #5
def histogram(
        image: torch.Tensor,
        nbins: Optional[int] = 256,
        source_range: Optional[str] = 'image',
        normalize: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return histogram of image.

    Unlike `numpy.histogram`, this function returns the centers of bins and
    does not rebin integer arrays. For integer arrays, each integer value has
    its own bin, which improves speed and intensity-resolution.
    The histogram is computed on the flattened image: for color images, the
    function should be used separately on each channel to obtain a histogram
    for each color channel.

    Parameters
    ----------
    image : torch.Tensor
        Input image
    nbins : Optional[int], default 256
        Number of bins used to calculate histogram. This value is ignored for
        integer arrays.
    source_range : Optional[str], default 'image'
        'image' (default) determines the range from the input image.
        'dtype' determines the range from the expected range of the images
        of that data type.
    normalize : Optional[bool], default False
        If True, normalize the histogram by the sum of its values.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        hist: The values of the histogram
        bin_centers: The values at the center of bins.

    See Also
    --------
    cumulative_distribution

    Examples
    --------
    >>> from skimage import data, exposure, img_as_float
    >>> import torch
    >>> image = img_as_float(data.camera())
    >>> np.histogram(image, bins=2)
    (array([107432, 154712]), array([ 0. ,  0.5,  1. ]))
    >>> image = torch.tensor(img_as_float(data.camera()))
    >>> exposure.histogram(image, nbins=2)
    (tensor([107432, 154712]), tensor([ 0.2500,  0.7500]))
    """
    if not isinstance(nbins, int):
        raise ValueError("Given bin cannot be non integer type")

    shape = image.size()

    if len(shape) == 3 and shape[0] < 4:
        warnings.warn("""This might be a color image. The histogram will be
             computed on the flattened image. You can instead
             apply this function to each color channel.""")

    image = image.flatten()
    min_v = torch.min(image).item()
    max_v = torch.max(image).item()

    # if the input image is normal integer type
    # like gray scale from 0-255, we implement fast histogram calculation
    # by returning bin count for each pixel value
    if not torch.is_floating_point(image):
        hist, bin_centers = _bin_count_histogram(image, source_range)
    else:
        if source_range == 'image':
            hist = torch.histc(image, nbins, min=min_v, max=max_v)
            bin_centers = _calc_bin_centers(min_v, max_v, nbins)
        elif source_range == 'dtype':
            min_v, max_v = dtype_limits(image, clip_negative=False)
            hist = torch.histc(image, nbins, min=min_v, max=max_v)
            bin_centers = _calc_bin_centers(min_v, max_v, nbins)
        else:
            raise ValueError("Wrong value for the `source_range` argument")

    if normalize:
        hist = torch.div(hist, float(torch.sum(hist).item()))
        return (hist, bin_centers)

    return (hist.long(), bin_centers)
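The floating-point branch above is essentially `torch.histc` plus bin-center bookkeeping. A self-contained sketch of the same idea that does not depend on the `_calc_bin_centers`, `_bin_count_histogram` and `dtype_limits` helpers the full function assumes:

import torch

def _simple_float_histogram(image: torch.Tensor, nbins: int = 256):
    image = image.flatten().float()
    min_v, max_v = image.min().item(), image.max().item()
    hist = torch.histc(image, bins=nbins, min=min_v, max=max_v)
    edges = torch.linspace(min_v, max_v, nbins + 1)   # nbins equal-width bins over [min_v, max_v]
    centers = 0.5 * (edges[:-1] + edges[1:])          # midpoints of those bins
    return hist, centers

hist, centers = _simple_float_histogram(torch.rand(64, 64), nbins=10)
print(int(hist.sum().item()))  # 4096: every pixel falls into some bin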
Example #6
def run_classification_explanation(root,
                                   dataset_name,
                                   split_name,
                                   model,
                                   batch,
                                   datasets_infos,
                                   nb_samples,
                                   algorithm_name,
                                   algorithm_fn,
                                   output_name,
                                   algorithm_kwargs=None,
                                   nb_explanations=1,
                                   epoch=None,
                                   average_filters=True):
    """
    Run an explanation of a classification output
    """

    # do sample by sample to simplify the export procedure
    for n in range(nb_samples):
        logger.info('sample={}'.format(n))
        batch_n = get_batch_n(batch,
                              len_batch(batch),
                              np.asarray([n]),
                              transforms=None,
                              use_advanced_indexing=True)

        for tensor in batch_n.values():
            if isinstance(tensor,
                          torch.Tensor) and torch.is_floating_point(tensor):
                # we want to back-propagate up to the inputs
                tensor.requires_grad = True

        try:
            with torch.no_grad():
                outputs = model(batch_n)
                output = outputs.get(output_name)
                assert output is not None
                output_np = to_value(output.output)[0]
                max_class_indices = (-output_np).argsort()[0:nb_explanations]
        except Exception as e:
            logger.error(
                'exception, aborted `run_classification_explanation`: {}'.format(e))
            continue

        # make sure the model is not contaminated by uncleaned hooks
        r = None
        with utilities.CleanAddedHooks(model) as context:
            algorithm_instance = algorithm_fn(model=model, **algorithm_kwargs)
            r = algorithm_instance(inputs=batch_n,
                                   target_class_name=output_name,
                                   target_class=max_class_indices[0])

        if r is None:
            # the algorithm failed, go to the next one
            return

        selected_output_name, cams_dict = r
        assert nb_explanations == 1, 'TODO handle for multiple explanations!'

        for input_name, g in cams_dict.items():
            if g is None:
                # discard this input!
                continue

            enumerate_i = 0
            c = max_class_indices[enumerate_i]  # the class output
            c_name = fill_class_name(output, c, datasets_infos, dataset_name,
                                     split_name)

            filename = 'sample-{}-output-{}-epoch-{}-rank-{}-alg-{}-explanation_for-{}'.format(
                n, input_name, epoch, enumerate_i, algorithm_name, c_name)
            filename = utilities.safe_filename(filename)
            export_path = os.path.join(root, filename)

            def format_image(g):
                if not isinstance(g, (np.ndarray, torch.Tensor)):
                    return g
                if average_filters and len(g.shape) >= 3:
                    return np.reshape(np.average(np.abs(g), axis=1),
                                      [g.shape[0], 1] + list(g.shape[2:]))
                return g

            with open(export_path + '.txt', 'w') as f:
                if isinstance(g, collections.Mapping):
                    # handle multiple explanation outputs
                    for name, value in g.items():
                        batch_n['explanation_{}'.format(name)] = format_image(
                            value)
                else:
                    # default: single tensor
                    batch_n['explanation'] = format_image(g)
                batch_n['output_found'] = str(output_np)
                batch_n['output_name_found'] = c_name

                #positive, negative = guided_back_propagation.GuidedBackprop.get_positive_negative_saliency(g)
                #batch_n['explanation_positive'] = positive
                #batch_n['explanation_negative'] = negative
                #f.write('gradient average positive={}\n'.format(np.average(g[np.where(g > 0)])))
                #f.write('gradient average negative={}\n'.format(np.average(g[np.where(g < 0)])))
                sample_export.export_sample(batch_n, 0, export_path + '-', f)
Example #7
 def tensor_general_ops(self):
     a = torch.randn(4)
     b = torch.tensor([1.5])
     x = torch.ones((2, ))
     c = torch.randn(4, dtype=torch.cfloat)
     w = torch.rand(4, 4, 4, 4)
     v = torch.rand(4, 4, 4, 4)
     return len(
         # torch.is_tensor(a),
         # torch.is_storage(a),
         torch.is_complex(a),
         torch.is_conj(a),
         torch.is_floating_point(a),
         torch.is_nonzero(b),
         # torch.set_default_dtype(torch.float32),
         # torch.get_default_dtype(),
         # torch.set_default_tensor_type(torch.DoubleTensor),
         torch.numel(a),
         # torch.set_printoptions(),
         # torch.set_flush_denormal(False),
         # https://pytorch.org/docs/stable/tensors.html#tensor-class-reference
         # x.new_tensor([[0, 1], [2, 3]]),
         x.new_full((3, 4), 3.141592),
         x.new_empty((2, 3)),
         x.new_ones((2, 3)),
         x.new_zeros((2, 3)),
         x.is_cuda,
         x.is_quantized,
         x.is_meta,
         x.device,
         x.dim(),
         c.real,
         c.imag,
         # x.backward(),
         x.clone(),
         w.contiguous(),
         w.contiguous(memory_format=torch.channels_last),
         w.copy_(v),
         w.copy_(1),
         w.copy_(0.5),
         x.cpu(),
         # x.cuda(),
         # x.data_ptr(),
         x.dense_dim(),
         w.fill_diagonal_(0),
         w.element_size(),
         w.exponential_(),
         w.fill_(0),
         w.geometric_(0.5),
         a.index_fill(0, torch.tensor([0, 2]), 1),
         a.index_put_([torch.argmax(a)], torch.tensor(1.0)),
         a.index_put([torch.argmax(a)], torch.tensor(1.0)),
         w.is_contiguous(),
         c.is_complex(),
         w.is_conj(),
         w.is_floating_point(),
         w.is_leaf,
         w.is_pinned(),
         w.is_set_to(w),
         # w.is_shared,
         w.is_coalesced(),
         w.coalesce(),
         w.is_signed(),
         w.is_sparse,
         torch.tensor([1]).item(),
         x.log_normal_(),
         # x.masked_scatter_(),
         # x.masked_scatter(),
         # w.normal(),
         w.numel(),
         # w.pin_memory(),
         # w.put_(0, torch.tensor([0, 1], w)),
         x.repeat(4, 2),
         a.clamp_(0),
         a.clamp(0),
         a.clamp_min(0),
         a.hardsigmoid_(),
         a.hardsigmoid(),
         a.hardswish_(),
         a.hardswish(),
         a.hardtanh_(),
         a.hardtanh(),
         a.leaky_relu_(),
         a.leaky_relu(),
         a.relu_(),
         a.relu(),
         a.resize_as_(a),
         a.type_as(a),
         a._shape_as_tensor(),
         a.requires_grad_(False),
     )
Example #8
def upsample(tensor: TensorNCX,
             size: ShapeX,
             mode: Literal['linear', 'nearest'] = 'linear') -> TensorNCX:
    """
    Upsample a 1D, 2D, 3D tensor

    This is a wrapper around `torch.nn.Upsample` to make it more practical. Supports integer-based tensors.

    Note:
        PyTorch as of version 1.3 doesn't support non-floating point upsampling
        (see https://github.com/pytorch/pytorch/issues/13218 and https://github.com/pytorch/pytorch/issues/5580).
        A workaround is used instead (TODO assess the speed impact!).


    Args:
        tensor: 1D (shape = b x c x n), 2D (shape = b x c x h x w) or 3D (shape = b x c x d x h x w)
        size: if 1D, shape = n, if 2D shape = h x w, if 3D shape = d x h x w
        mode: `linear` or `nearest`

    Returns:
        an up-sampled tensor with same batch size and filter size as the input
    """

    assert len(size) + 2 == len(tensor.shape), 'shape must be only the resampled components, ' \
                                               'WITHOUT the batch and filter channels'
    assert len(
        tensor.shape) >= 3, 'only 1D, 2D, 3D tensors are currently handled!'
    assert len(
        tensor.shape) <= 5, 'only 1D, 2D, 3D tensors are currently handled!'

    size = tuple(size)
    if not torch.is_floating_point(tensor):
        # Workaround for non floating point tensors. Ignore `mode`
        if len(tensor.shape) == 3:
            return _upsample_int_1d(tensor, size)
        elif len(tensor.shape) == 4:
            return _upsample_int_2d(tensor, size)
        elif len(tensor.shape) == 5:
            return _upsample_int_3d(tensor, size)
        else:
            raise NotImplementedError('dimension not implemented!')

    if mode == 'linear':
        align_corners = True
        if len(tensor.shape) == 4:
            # 2D case
            return nn.Upsample(mode='bilinear',
                               size=size,
                               align_corners=align_corners).forward(tensor)
        elif len(tensor.shape) == 5:
            # 3D case
            return nn.Upsample(mode='trilinear',
                               size=size,
                               align_corners=align_corners).forward(tensor)
        elif len(tensor.shape) == 3:
            # 1D case
            return nn.Upsample(mode='linear',
                               size=size,
                               align_corners=align_corners).forward(tensor)
        else:
            assert 0, 'impossible or bug!'

    elif mode == 'nearest':
        return nn.Upsample(mode='nearest', size=size).forward(tensor)
    else:
        assert 0, 'upsample mode ({}) is not handled'.format(mode)
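A usage sketch, assuming `upsample` is imported together with the type aliases (`TensorNCX`, `ShapeX`) and the `_upsample_int_*` helpers it relies on; shapes below are illustrative:

import torch

x = torch.rand(2, 3, 16, 16)                   # b x c x h x w, floating point
y = upsample(x, size=(32, 32), mode='linear')  # dispatched to bilinear nn.Upsample
print(y.shape)                                 # torch.Size([2, 3, 32, 32])

labels = torch.randint(0, 5, (2, 1, 16, 16))   # integer tensor: takes the workaround path, `mode` is ignored
z = upsample(labels, size=(32, 32))
print(z.dtype)                                 # stays an integer dtype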
Example #9
def interpolate_bilinear_2d_like_tensorflow1x(input,
                                              size=None,
                                              scale_factor=None,
                                              align_corners=None,
                                              method='slow'):
    r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor`

    Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x:
    https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41
    https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85
    as per proposal:
    https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319

    Related materials:
    https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35
    https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/
    https://machinethink.net/blog/coreml-upsampling/

    Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x height x width`.

    Args:
        input (Tensor): the input tensor
        size (Tuple[int, int]): output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
        align_corners (bool, optional): Same meaning as in TensorFlow 1.x.
        method (str, optional):
            'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or
            'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299)
    """
    if method not in ('slow', 'fast'):
        raise ValueError('method can only be one of "slow", "fast"')

    if input.dim() != 4:
        raise ValueError('input must be a 4-D tensor')

    if not torch.is_floating_point(input):
        raise ValueError('input must be of floating point dtype')

    if size is not None and (type(size) not in (tuple, list)
                             or len(size) != 2):
        raise ValueError('size must be a list or a tuple of two elements')

    if align_corners is None:
        raise ValueError(
            'align_corners is not specified (use this function for a complete determinism)'
        )

    def _check_size_scale_factor(dim):
        if size is None and scale_factor is None:
            raise ValueError('either size or scale_factor should be defined')
        if size is not None and scale_factor is not None:
            raise ValueError(
                'only one of size or scale_factor should be defined')
        if scale_factor is not None and isinstance(
                scale_factor, tuple) and len(scale_factor) != dim:
            raise ValueError('scale_factor shape must match input shape. '
                             'Input is {}D, scale_factor size is {}'.format(
                                 dim, len(scale_factor)))

    is_tracing = torch._C._get_tracing_state()

    def _output_size(dim):
        _check_size_scale_factor(dim)
        if size is not None:
            if is_tracing:
                return [torch.tensor(i) for i in size]
            else:
                return size
        scale_factors = _ntuple(dim)(scale_factor)
        # math.floor might return float in py2.7

        # make scale_factor a tensor in tracing so constant doesn't get baked in
        if is_tracing:
            return [(torch.floor(
                (input.size(i + 2).float() *
                 torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
                    for i in range(dim)]
        else:
            return [
                int(math.floor(float(input.size(i + 2)) * scale_factors[i]))
                for i in range(dim)
            ]

    def tf_calculate_resize_scale(in_size, out_size):
        if align_corners:
            if is_tracing:
                return (in_size - 1) / (out_size.float() - 1).clamp(min=1)
            else:
                return (in_size - 1) / max(1, out_size - 1)
        else:
            if is_tracing:
                return in_size / out_size.float()
            else:
                return in_size / out_size

    out_size = _output_size(2)
    scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1])
    scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0])

    def resample_using_grid_sample():
        grid_x = torch.arange(0,
                              out_size[1],
                              1,
                              dtype=input.dtype,
                              device=input.device)
        grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1

        grid_y = torch.arange(0,
                              out_size[0],
                              1,
                              dtype=input.dtype,
                              device=input.device)
        grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1

        grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1)
        grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1])

        grid_xy = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)),
                            dim=2).unsqueeze(0)
        grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1)

        out = F.grid_sample(input,
                            grid_xy,
                            mode='bilinear',
                            padding_mode='border',
                            align_corners=True)
        return out

    def resample_manually():
        grid_x = torch.arange(0,
                              out_size[1],
                              1,
                              dtype=input.dtype,
                              device=input.device)
        grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32)
        grid_x_lo = grid_x.long()
        grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1)
        grid_dx = grid_x - grid_x_lo.float()

        grid_y = torch.arange(0,
                              out_size[0],
                              1,
                              dtype=input.dtype,
                              device=input.device)
        grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32)
        grid_y_lo = grid_y.long()
        grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1)
        grid_dy = grid_y - grid_y_lo.float()

        # could be improved with index_select
        in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo]
        in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi]
        in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo]
        in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi]

        in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1])
        in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1])
        out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1)

        return out

    if method == 'slow':
        out = resample_manually()
    else:
        out = resample_using_grid_sample()

    return out
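A usage sketch; it assumes the module-level imports the function relies on (`math`, `torch.nn.functional as F`, and `_ntuple` from `torch.nn.modules.utils`) are present:

import torch

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
y = interpolate_bilinear_2d_like_tensorflow1x(x, size=(8, 8), align_corners=False, method='slow')
print(y.shape)  # torch.Size([1, 1, 8, 8])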
Example #10
def basic_operation():
    # REF [site] >>
    #	https://pytorch.org/docs/stable/tensors.html
    #	https://pytorch.org/docs/stable/tensor_attributes.html

    x = torch.empty(5, 3)
    print('x =', x)
    print('x.shape = {}, x.dtype = {}.'.format(x.shape, x.dtype))
    #print('x =', x.data)

    x = torch.rand(2, 3)
    print('x =', x)

    x = torch.randn(2, 3)
    print('x =', x)

    x = torch.randn(2, 3)
    print('x =', x)

    x = torch.randperm(5)
    print('x =', x)

    x = torch.FloatTensor(10, 12, 3, 3)
    print('x =', x.size())
    print('x =', x.size()[:])

    #--------------------
    y = torch.zeros(2, 3)
    print('y =', y)

    y = torch.ones(2, 3)
    print('y =', y)

    y = torch.arange(0, 3, step=0.5)
    print('y =', y)

    x = torch.tensor(1, dtype=torch.int32)
    #x = torch.tensor(1, dtype=torch.int32, device='cuda:1')
    print('x =', x)

    x = torch.tensor([5.5, 3])
    print('x =', x)

    x = x.new_ones(5, 3, dtype=torch.double)  # new_* methods take in sizes.
    print('x =', x)
    x = torch.randn_like(x, dtype=torch.float)  # Override dtype.
    print('x =', x)

    #--------------------
    y = torch.rand(5, 3)
    print('x + y =', x + y)

    print('x + y =', torch.add(x, y))

    result = torch.empty(5, 3)
    torch.add(x, y, out=result)
    print('x + y =', result)

    #--------------------
    # Any operation that mutates a tensor in-place is post-fixed with an _.
    # For example: x.copy_(y), x.t_(), will change x.

    y.add_(x)  # In-place.
    print('y =', y)

    #--------------------
    # You can use standard NumPy-like indexing with all bells and whistles!
    print(x[:, 1])

    #--------------------
    # If you have a one element tensor, use .item() to get the value as a Python number.
    x = torch.randn(1)
    print('x =', x)
    print('x.item() =', x.item())

    #--------------------
    x = torch.randn(2, 2)
    print('x.is_cuda =', x.is_cuda)
    print('x.is_complex() =', x.is_complex())
    print('x.is_contiguous() =', x.is_contiguous())
    print('x.is_distributed() =', x.is_distributed())
    print('x.is_floating_point() =', x.is_floating_point())
    print('x.is_pinned() =', x.is_pinned())
    print('x.is_quantized =', x.is_quantized)
    print('x.is_shared() =', x.is_shared())
    print('x.is_signed() =', x.is_signed())
    print('x.is_sparse =', x.is_sparse)

    print('x.contiguous() =', x.contiguous())
    print('x.storage() =', x.storage())

    #--------------------
    x = torch.randn(2, 2)
    print('torch.is_tensor(x) =', torch.is_tensor(x))
    print('torch.is_storage(x) =', torch.is_storage(x))
    print('torch.is_complex(x) =', torch.is_complex(x))
    print('torch.is_floating_point(x) =', torch.is_floating_point(x))

    # Sets the default floating point dtype to d.
    # This type will be used as default floating point type for type inference in torch.tensor().
    torch.set_default_dtype(torch.float32)
    print('torch.get_default_dtype() =', torch.get_default_dtype())
    # Sets the default torch.Tensor type to floating point tensor type.
    # This type will also be used as default floating point type for type inference in torch.tensor().
    torch.set_default_tensor_type(torch.FloatTensor)

    #--------------------
    # REF [site] >> https://pytorch.org/docs/stable/tensor_view.html
    # View tensor shares the same underlying data with its base tensor.
    # Supporting View avoids explicit data copy, thus allows us to do fast and memory efficient reshaping, slicing and element-wise operations.

    # If you want to resize/reshape tensor, you can use torch.view.
    x = torch.randn(4, 4)
    y = x.view(16)
    z = x.view(-1, 8)  # The size -1 is inferred from other dimensions.
    print('x.size() = {}, y.size() = {}, z.size() = {}.'.format(
        x.size(), y.size(), z.size()))

    t = torch.rand(4, 4)
    b = t.view(2, 8)
    print('t.storage().data_ptr() == b.storage().data_ptr()?',
          t.storage().data_ptr() == b.storage().data_ptr())
Example #11
    def get_p_cond_dur_chFlat_vectorized(
        p_dim_cond_td_chDim: torch.Tensor,
        unabs_dim_td_cond_chDim: torch.Tensor,
        dur_buffer_fr: torch.Tensor,
        dur_stim_frs: torch.Tensor,
        p1st_dim0: torch.Tensor,
    ) -> torch.Tensor:

        if torch.is_floating_point(dur_buffer_fr):
            bufs = torch.cat([
                dur_buffer_fr.floor().long().reshape([1]),
                dur_buffer_fr.floor().long().reshape([1]) + 1
            ], 0)
            prop_buf = torch.tensor(1.) - torch.abs(dur_buffer_fr - bufs)
            ps = []
            for buf in bufs:
                ps.append(
                    Dtb2DVDBufSerial.get_p_cond_dur_chFlat(
                        p_dim_cond_td_chDim,
                        unabs_dim_td_cond_chDim,
                        buf.long(),
                        dur_stim_frs=dur_stim_frs,
                        p1st_dim0=p1st_dim0))
            ps = torch.stack(ps)
            p_cond_dur_chFlat = (ps * prop_buf[:, None, None, None]).sum(0)
            return p_cond_dur_chFlat

        # vectorized version
        p1st_dim = torch.stack([p1st_dim0, torch.tensor(1.) - p1st_dim0])

        n_cond = p_dim_cond_td_chDim.shape[1]
        ndur = len(dur_stim_frs)
        p_cond_dur_chFlat = torch.zeros([n_cond, ndur, consts.N_CH_FLAT])

        p_cond_dim_td_chDim = p_dim_cond_td_chDim.transpose(0, 1)
        unabs_dim_cond_td_chDim = unabs_dim_td_cond_chDim.transpose(1, 2)
        unabs_cond_dim_td_chDim = unabs_dim_cond_td_chDim.transpose(0, 1)
        ichs = torch.arange(consts.N_CH_FLAT)
        dim1sts = torch.arange(consts.N_DIM)

        for idur, dur_stim in enumerate(dur_stim_frs):
            p0 = torch.zeros([n_cond, consts.N_CH_FLAT])

            for td1st in torch.arange(dur_stim):
                max_td2nd = dur_stim - max([td1st - dur_buffer_fr, 0])
                td2nds = torch.arange(max_td2nd)

                dim1st, td2nd, ich = torch.meshgrid([dim1sts, td2nds, ichs])
                dim2nd = consts.get_odim(dim1st)
                ch1st = consts.CHS_TENSOR[dim1st, ich]
                ch2nd = consts.CHS_TENSOR[dim2nd, ich]

                # When both dim1st and dim2nd are absorbed
                p0 = p0 + (
                    (p_cond_dim_td_chDim[:, dim1st,
                                         td1st.expand_as(td2nd), ch1st] *
                     p_cond_dim_td_chDim[:, dim2nd, td2nd, ch2nd]).sum(
                         -2)  # sum across td2nd
                    * p1st_dim[None, :, None]).sum(1)  # sum across p1st

                # When only dim1st is absorbed,
                t2nd = max_td2nd

                dim1st, ich = torch.meshgrid([dim1sts, ichs])
                dim2nd = consts.get_odim(dim1st)
                ch1st = consts.CHS_TENSOR[dim1st, ich]
                ch2nd = consts.CHS_TENSOR[dim2nd, ich]

                p0 = p0 + ((p_cond_dim_td_chDim[:, dim1st, td1st, ch1st] *
                            unabs_cond_dim_td_chDim[:, dim2nd, t2nd, ch2nd]) *
                           p1st_dim[None, :, None]).sum(1)  # sum across dim1st

            dim1st, ich = torch.meshgrid([dim1sts, ichs])
            dim2nd = consts.get_odim(dim1st)

            ch1st = consts.CHS_TENSOR[dim1st, ich]
            ch2nd = consts.CHS_TENSOR[dim2nd, ich]

            # When neither dim is absorbed,
            # then dim2nd is certainly not absorbed,
            # and stays at the state at t = min([dur_stim, dur_buffer_fr])
            t2nd = min([dur_stim, dur_buffer_fr])
            p0 = p0 + (unabs_cond_dim_td_chDim[:, dim1st, dur_stim, ch1st] *
                       unabs_cond_dim_td_chDim[:, dim2nd, t2nd, ch2nd] *
                       p1st_dim[None, :, None]).sum(1)

            p_cond_dur_chFlat[:,
                              idur, :] = (p_cond_dur_chFlat[:, idur, :] + p0)
        return p_cond_dur_chFlat
Example #12
    def get_p_cond_dur_chFlat(
        p_dim_cond_td_chDim: torch.Tensor,
        unabs_dim_td_cond_chDim: torch.Tensor,
        dur_buffer_fr: torch.Tensor,
        dur_stim_frs: torch.Tensor,
        p1st_dim0: torch.Tensor,
    ) -> torch.Tensor:
        """

        :param p_dim_cond_td_chDim: [dim, cond, td, chDim]
        :param unabs_dim_td_cond_chDim: [dim, td, cond, chDim]
        :param dur_buffer_fr: scalar
        :param dur_stim_frs: [idur]
        :return: p_cond_dur_chFlat[cond, dur, chFlat]
        """
        if torch.is_floating_point(dur_buffer_fr):
            bufs = torch.cat([
                dur_buffer_fr.floor().long().reshape([1]),
                dur_buffer_fr.floor().long().reshape([1]) + 1
            ], 0)
            prop_buf = torch.tensor(1.) - torch.abs(dur_buffer_fr - bufs)
            ps = []
            for buf in bufs:
                ps.append(
                    Dtb2DVDBufSerial.get_p_cond_dur_chFlat(
                        p_dim_cond_td_chDim,
                        unabs_dim_td_cond_chDim,
                        buf.long(),
                        dur_stim_frs=dur_stim_frs,
                        p1st_dim0=p1st_dim0))
            ps = torch.stack(ps)
            p_cond_dur_chFlat = (ps * prop_buf[:, None, None, None]).sum(0)
            return p_cond_dur_chFlat

        p1st_dim = [p1st_dim0, torch.tensor(1.) - p1st_dim0]

        n_cond = p_dim_cond_td_chDim.shape[1]
        ndur = len(dur_stim_frs)
        p_cond_dur_chFlat = torch.zeros([n_cond, ndur, consts.N_CH_FLAT])

        cumP_dim_cond_td_chDim = p_dim_cond_td_chDim.cumsum(-2)

        for dim1st in range(consts.N_DIM):
            dim2nd = consts.get_odim(dim1st)
            for idur, dur_stim in enumerate(dur_stim_frs):
                p0 = torch.zeros([n_cond, consts.N_CH_FLAT])
                for ich, chs in enumerate(consts.CHS_TENSOR.T):
                    ch1st = chs[dim1st]
                    ch2nd = chs[dim2nd]

                    for td1st in torch.arange(dur_stim):
                        max_td2nd = dur_stim - max([td1st - dur_buffer_fr, 0])
                        # ==== When both dims are absorbed
                        p0[:, ich] = p0[:, ich] + (
                            p_dim_cond_td_chDim[dim1st, :, td1st, ch1st] *
                            cumP_dim_cond_td_chDim[dim2nd, :, max_td2nd,
                                                   ch2nd])

                        # ==== When only dim1st is absorbed
                        p0[:, ich] = p0[:, ich] + (
                            p_dim_cond_td_chDim[dim1st, :, td1st, ch1st] *
                            unabs_dim_td_cond_chDim[dim2nd, max_td2nd, :,
                                                    ch2nd])
                    # ==== When dim1st is not absorbed
                    t1st = dur_stim
                    t2nd = dur_stim - max([t1st - dur_buffer_fr, 0])

                    # ==== When only dim2nd is absorbed: this can happen when
                    #   dim2nd is absorbed within the buffer duration
                    p0[:, ich] = p0[:, ich] + (
                        unabs_dim_td_cond_chDim[dim1st, t1st, :, ch1st] *
                        cumP_dim_cond_td_chDim[dim2nd, :, t2nd, ch2nd])

                    # ==== When neither dim is absorbed
                    p0[:, ich] = p0[:, ich] + (
                        unabs_dim_td_cond_chDim[dim1st, t1st, :, ch1st] *
                        unabs_dim_td_cond_chDim[dim2nd, t2nd, :, ch2nd])

                p0 = p0 / p0.sum(1, keepdim=True)
                p_cond_dur_chFlat[:,
                                  idur, :] = (p_cond_dur_chFlat[:, idur, :] +
                                              p1st_dim[dim1st] * p0)
        return p_cond_dur_chFlat
Example #13
 def is_floating_point(self) -> bool:
     return torch.is_floating_point(self.options())
Example #14
 def filter_only_float_tensors(t):
     return isinstance(
         t, tensor_quantity.TensorQuantity) and (torch.is_floating_point(t)
                                                 or torch.is_complex(t))
Example #15
def contrast(img, factor):
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    mean = torch.mean(rgb2gray(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    return blend(mean, img, max(factor, 1e-6))
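A usage sketch, assuming `rgb2gray` and `blend` are the companion helpers from the same augmentation module (a `blend(a, b, factor)` that interpolates towards `b` as `factor` grows):

import torch

img = torch.rand(3, 64, 64)   # C x H x W, float in [0, 1]
low = contrast(img, 0.5)      # factor < 1: lower contrast (towards the mean gray level)
high = contrast(img, 1.5)     # factor > 1: higher contrast
print(low.shape, high.shape)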
Example #16
    def optimize_agent(self, itr, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.  Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid = self.process_returns(samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle
        # all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            # we apply different augmentations for each "epoch"
            aug_loss_inputs = self.augment_loss_inputs(loss_inputs)
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be
                # OK.
                if self.expert_batch_iter:
                    bc_batch_dict = next(self.expert_batch_iter)
                    bc_obs = self.augment_bc_obs(bc_batch_dict['obs'])
                    bc_acts = bc_batch_dict['acts']
                    assert not torch.is_floating_point(bc_acts), bc_acts
                    bc_acts = bc_acts.long()
                else:
                    bc_obs = None
                    bc_acts = None
                loss, entropy, perplexity = self.loss(*aug_loss_inputs[T_idxs,
                                                                       B_idxs],
                                                      bc_observations=bc_obs,
                                                      bc_actions=bc_acts,
                                                      init_rnn_state=rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) \
                / self.n_itr

        return opt_info
Example #17
def to_half(b):
    "Recursively map lists of tensors in `b ` to FP16."
    return apply(lambda x: x.half() if torch.is_floating_point(x) else x, b)
Example #18
File: min.py Project: tchaton/tsd
def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytsd/
            master/docs/source/_figures/min.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Minimizes all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`. If multiple indices reference the same location, their
    **contributions minimize** (`cf.` :meth:`~tsd.scatter_add`).
    The second return tensor contains index location in :attr:`src` of each
    minimum value (known as argmin).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \min(\mathrm{out}_i, \min_j(\mathrm{src}_j))

    where :math:`\min_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
            the output tensor is filled with the greatest possible value of
            :obj:`src.dtype`. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`LongTensor`)

    .. testsetup::

        import torch

    .. testcode::

        from tsd import scatter_min

        src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
        index = torch.tensor([[ 4, 5,  4,  2,  3], [0,  0,  2,  2,  1]])
        out = src.new_zeros((2, 6))

        out, argmin = scatter_min(src, index, out=out)

        print(out)
        print(argmin)

    .. testoutput::

       tensor([[ 0.,  0., -4., -3., -2.,  0.],
               [-2., -4., -3.,  0.,  0.,  0.]])
       tensor([[-1, -1,  3,  4,  0,  1],
               [ 1,  4,  3, -1, -1, -1]])
    """
    if fill_value is None:
        op = torch.finfo if torch.is_floating_point(src) else torch.iinfo
        fill_value = op(src.dtype).max
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out, index.new_full(out.size(), -1)
    return ScatterMin.apply(out, src, index, dim)
Example #19
 def _convert_float_tensor(t: Tensor) -> Tensor:
     return t.to(to_type) if torch.is_floating_point(t) else t
Example #20
 def _to_float(t):
     return t.float() if torch.is_floating_point(t) else t
Example #21
    def get_floating_inputs_with_gradients(inputs):
        """
        Extract inputs that have a gradient

        Args:
            inputs: a tensor or dictionary of tensors

        Returns:
            Returns a list of tuples (name, input) for the inputs that can have a gradient
        """
        if isinstance(inputs, collections.Mapping):
            # if `i` is not a floating point, we can't calculate the gradient anyway...
            i = [(input_name, i) for input_name, i in inputs.items() if hasattr(i, 'grad') and torch.is_floating_point(i)]
        else:
            i = [('input', inputs)]
        return i
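Example #25 below calls this helper through `GuidedBackprop.get_floating_inputs_with_gradients`. A hypothetical illustration of what it returns for a mixed batch (the integer tensor is dropped because it is not floating point):

import torch

inputs = {
    'image': torch.rand(1, 3, 8, 8, requires_grad=True),  # floating point: kept
    'label': torch.tensor([2]),                           # integer metadata: filtered out
}
named = GuidedBackprop.get_floating_inputs_with_gradients(inputs)
print([name for name, _ in named])  # ['image']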
Example #22
import torch

x = torch.rand(5, 3)
print(x)
print(torch.is_tensor(x))
print(torch.is_storage(x))
print(torch.is_floating_point(x))

print('')
print(torch.get_default_dtype())  # torch.float32
print(torch.tensor([1.2, 3]).dtype)  # default is torch.float32
torch.set_default_dtype(torch.float64)
print(torch.tensor([1.2, 3]).dtype)

print('')
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())
torch.set_default_tensor_type(torch.FloatTensor)
print(torch.get_default_dtype())

print('')
x = torch.tensor([5, 3])
print(x)
print(torch.numel(x))
x = torch.rand(5, 3)
print(x)
print(torch.numel(x))

print('')
print(torch.set_flush_denormal(True))
print(torch.set_flush_denormal(False))
Example #23
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(p.data,
                                                        dtype=dtype,
                                                        device="cpu")
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(p.data,
                                                           dtype=dtype,
                                                           device="cpu")
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                p_data_bak = p.data  # backup of the original data pointer

                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
Example #24
def _assert_floating(name, t):
    if not torch.is_floating_point(t):
        raise TypeError(
            '`{}` must be a floating point Tensor but is a {}'.format(
                name, t.type()))
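A usage sketch of the assertion helper above:

import torch

_assert_floating('weights', torch.rand(3))            # passes silently
try:
    _assert_floating('labels', torch.tensor([1, 2]))
except TypeError as err:
    print(err)  # `labels` must be a floating point Tensor but is a torch.LongTensor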
Example #25
    def __call__(self, inputs, target_class_name, target_class=None):
        """
            Generate the guided back-propagation gradient

            Args:
                inputs: a tensor or dictionary of tensors. Must have `requires_grad` set for the inputs to be explained
                target_class: the index of the class to explain the decision. If `None`, the predicted class (argmax of the output) is used
                target_class_name: the output node to be used. If `None`:
                    * if model output is a single tensor then use this as target output

                    * else it will use the first `OutputClassification` output

            Returns:
                a tuple (output_name, dictionary (input, integrated gradient))
            """
        logger.info('started integrated gradient ...')
        self.model.eval()  # make sure we are in eval mode
        input_names_with_gradient = dict(
            guided_back_propagation.GuidedBackprop.
            get_floating_inputs_with_gradients(inputs)).keys()
        if len(input_names_with_gradient) == 0:
            logger.error(
                'IntegratedGradients.__call__: failed. No inputs will collect gradient!'
            )
            return None
        else:
            logger.info('input_names_with_gradient={}'.format(
                input_names_with_gradient))

        outputs = self.model(inputs)
        model_output = outputs.get(target_class_name)
        if model_output is None:
            for output_name, output in outputs.items():
                if isinstance(output, outputs_trw.OutputClassification):
                    logger.info(
                        'IntegratedGradients.__call__: output found={}'.format(
                            output_name))
                    target_class_name = output_name
                    model_output = output
                    break
        if model_output is None:
            logger.error(
                'IntegratedGradients.__call__: failed. No suitable output could be found!'
            )
            return None
        model_output = self.post_process_output(model_output)

        if target_class is None:
            target_class = torch.argmax(model_output, dim=1)

        # construct our gradient target
        model_device = utilities.get_device(self.model, batch=inputs)
        nb_classes = model_output.shape[1]
        nb_samples = trw.utils.len_batch(inputs)

        if self.use_output_as_target:
            one_hot_output = model_output.clone()
        else:
            one_hot_output = torch.FloatTensor(
                nb_samples, nb_classes).to(device=model_device).zero_()
            one_hot_output[:, target_class] = 1.0

        # construct our reference inputs
        baseline_inputs = {}
        for feature_name, feature_value in inputs.items():
            if is_feature_metadata(feature_name, feature_value):
                # if metadata, we can't interpolate!
                continue
            baseline_inputs[feature_name] = torch.zeros_like(feature_value)

        # construct our integrated gradients
        integrated_gradients = {
            name: torch.zeros_like(inputs[name])
            for name in input_names_with_gradient
        }

        # start integration
        for n in range(self.steps):
            integrated_inputs = {}
            with torch.no_grad():
                # here do no propagate the gradient (mixture of input and baseline)
                # We just want the gradient for the `inputs`
                for feature_name, feature_value in inputs.items():
                    if is_feature_metadata(
                            feature_name, feature_value
                    ) or not torch.is_floating_point(feature_value):
                        # metadata or non floating point tensors: keep original value
                        integrated_inputs[feature_name] = feature_value
                    else:
                        baseline_value = baseline_inputs[feature_name]
                        integrated_inputs[
                            feature_name] = baseline_value + float(
                                n) / self.steps * (feature_value -
                                                   baseline_value)
                        integrated_inputs[feature_name].requires_grad = True

            integrated_outputs = self.model(integrated_inputs)
            integrated_output = self.post_process_output(
                integrated_outputs[target_class_name])

            self.model.zero_grad()
            integrated_output.backward(gradient=one_hot_output,
                                       retain_graph=True)

            for name in input_names_with_gradient:
                if integrated_inputs[name].grad is not None:
                    integrated_gradients[name] += integrated_inputs[name].grad

        # average the gradients and multiply by input
        for name in list(integrated_gradients.keys()):
            integrated_gradients[name] = trw.utils.to_value(
                (inputs[name] - baseline_inputs[name]) *
                integrated_gradients[name] / self.steps)

        logger.info('integrated gradient successful!')
        return target_class_name, integrated_gradients