def outer_prod(q1, q2):
    X1 = vec2mat(q1)
    X2 = torch.movedim(q2, -1, 0)
    X1_flat = X1.reshape((-1, 4))
    X2_flat = X2.reshape((4, -1))
    X_out = torch.matmul(X1_flat, X2_flat)
    X_out = X_out.reshape(q1.shape + q2.shape[:-1])
    X_out = torch.movedim(X_out, len(q1.shape) - 1, -1)
    return X_out
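A shape-only sketch of the broadcasting above (not part of the original snippet): `vec2mat` is not shown, so a hypothetical stand-in that maps a quaternion of shape (..., 4) to a (..., 4, 4) matrix is used purely to illustrate the output shape.

import torch

def vec2mat_stub(q):
    # Hypothetical stand-in for the undefined vec2mat: (..., 4) -> (..., 4, 4).
    return q.unsqueeze(-1) * q.unsqueeze(-2)

q1 = torch.randn(5, 4)
q2 = torch.randn(3, 4)
X1 = vec2mat_stub(q1)                                 # (5, 4, 4)
X_out = torch.matmul(X1.reshape(-1, 4), torch.movedim(q2, -1, 0).reshape(4, -1))
X_out = X_out.reshape(q1.shape + q2.shape[:-1])       # (5, 4, 3)
X_out = torch.movedim(X_out, len(q1.shape) - 1, -1)   # (5, 3, 4)
print(X_out.shape)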
Example #2
def torch2np(x: torch.Tensor) -> np.ndarray:
    """Convert PyTorch tensor format (NCHW) to Numpy or TensorFlow format (NHWC)."""
    assert isinstance(x, torch.Tensor)
    x = x.detach().cpu()
    if len(x.shape) == 2:
        x = torch.movedim(x, 1, 0)
    elif len(x.shape) == 3:
        x = torch.movedim(x, 0, 2)
    elif len(x.shape) == 4:
        x = x.permute(0, 2, 3, 1)
    else:
        raise ValueError(f"Not supporting with shape of {len(x.shape)}.")
    return x.numpy()
Example #3
def np2torch(x: np.ndarray) -> torch.Tensor:
    """Convert Numpy tensor format (NHWC) to PyTorch format (NCHW)."""
    shape = x.shape
    x = torch.as_tensor(x)
    if len(shape) == 2:
        x = torch.movedim(x, -1, 0)
    elif len(shape) == 3:
        x = torch.movedim(x, -1, 0)
    elif len(shape) == 4:
        x = x.permute(0, 3, 1, 2)
    else:
        raise ValueError(f"Not supporting with shape of {len(shape)}.")
    return x
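For reference, a minimal shape check (not part of the original snippets), assuming the torch2np and np2torch helpers above are in scope:

import numpy as np
import torch

x_nhwc = np.zeros((2, 32, 32, 3), dtype=np.float32)   # NHWC batch
x_nchw = np2torch(x_nhwc)
print(x_nchw.shape)             # torch.Size([2, 3, 32, 32]) (NCHW)
print(torch2np(x_nchw).shape)   # (2, 32, 32, 3), back to NHWC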
Example #4
    def tensor_indexing_ops(self):
        x = torch.randn(2, 4)
        y = torch.randn(2, 4, 2)
        t = torch.tensor([[0, 0], [1, 0]])
        mask = x.ge(0.5)
        i = [0, 1]
        return (
            torch.cat((x, x, x), 0),
            torch.concat((x, x, x), 0),
            torch.conj(x),
            torch.chunk(x, 2),
            torch.dsplit(y, i),
            torch.column_stack((x, x)),
            torch.dstack((x, x)),
            torch.gather(x, 0, t),
            torch.hsplit(x, i),
            torch.hstack((x, x)),
            torch.index_select(x, 0, torch.tensor([0, 1])),
            torch.masked_select(x, mask),
            torch.movedim(x, 1, 0),
            torch.moveaxis(x, 1, 0),
            torch.narrow(x, 0, 0, 2),
            torch.nonzero(x),
            torch.permute(x, (0, 1)),
            torch.reshape(x, (-1, )),
        )
Example #5
    def __init__(self, dst, expt_logdir, rows, cols):

        self.dst = dst
        self.expt_logdir = expt_logdir
        self.rows = rows
        self.cols = cols
        self.images = []
        self.images_vis = []
        self.labels_vis = []
        image_ids = np.random.randint(len(dst), size=rows * cols)

        for image_id in image_ids:
            image, label = dst[image_id][0], dst[image_id][1]
            image = image[None, ...]
            self.images.append(image)

            image = torch.squeeze(image)
            image = image * std[:, None, None] + mean[:, None, None]
            image = torch.movedim(image, 0, -1)  # (3,H,W) to (H,W,3)
            image = image.cpu().numpy()
            self.images_vis.append(image)

            label = label.cpu().numpy()
            label = dst.decode_segmap(label)
            self.labels_vis.append(label)

        self.images = torch.cat(self.images, axis=0)
Example #6
    def __init__(self,
                 dimensionality,
                 coords,
                 flatten_order,
                 dropout=0.0,
                 persistent=False):
        """
        Create a Fourier feature based 2D positional encoding. The positional encoding will be flattened
        according to the `flatten_order`, hence `forward` assumes that the input `x` is flattened accordingly.

        :param dimensionality: of the encoding
        :param coords: 2D coordinates (y, x) for cartesian or (r, phi) for polar
        :param flatten_order: order used for flattening the input
        :param dropout: dropout probability applied to the positional encoding
        :param persistent: whether the positional encoding buffer is persistent (saved in the state dict)
        """
        super(PositionalEncoding2D, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.d_model = dimensionality

        pe = self._positional_encoding_2D(self.d_model, coords)
        pe = torch.movedim(pe, 0, -1).unsqueeze(0)
        pe = pe[:, flatten_order]

        self.register_buffer('pe', pe, persistent=persistent)
Example #7
def write_to_dir(model, load_dir, save_dir):
    # change to eval mode
    model.eval()

    # create output dir
    os.makedirs(save_dir, exist_ok=True)

    # loop over the given directory
    for item in sorted(os.listdir(load_dir)):
        if '.png' in item:
            img = cv2.imread(os.path.join(load_dir, item))
            img = cv2.resize(img, (480, 320))
            img = np.moveaxis(img, -1, 0)
            img = torch.from_numpy(img).type(torch.float32)
            out = model(img.unsqueeze(0).to(DEVICE))
            predicted = torch.argmax(out, dim=1)
            predicted = predicted.cpu().data.numpy()[0, :]

            img = torch.movedim(img, 0, -1).type(torch.uint8).numpy().astype(np.float32)
            img[:, :, 1] += (predicted * 255 * 0.5).astype(np.uint8)
            img = np.clip(img, 0, 255).astype(np.uint8)
            # img[:, :, 0] += (predicted * 255 * 0.7).astype(np.uint8)

            cv2.imwrite(os.path.join(save_dir, item), img)
Example #8
def movedim(x, *a, **k):
    if torch.is_tensor(x):
        return torch.movedim(x, *a, **k)
    elif isinstance(x, io.MappedArray):
        return x.movedim(*a, **k)
    else:
        return x
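A minimal illustration of the tensor branch above (the io.MappedArray and passthrough branches need nitorch's io module in scope, so they are not exercised here):

import torch

print(movedim(torch.zeros(2, 3, 4), 0, -1).shape)   # torch.Size([3, 4, 2])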
Example #9
    def aodm2dens(self, dm: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
        # xyz: (*BR, ndim)
        # dm: (*BD, nkpts, nao, nao)
        # returns: (*BRD)

        nao = dm.shape[-1]
        nkpts = self._kpts.shape[0]
        xyzshape = xyz.shape  # (*BR, ndim)

        # basis: (nkpts, nao, *BR)
        xyz1 = xyz.reshape(-1, xyzshape[-1])  # (BR=ngrid, ndim)
        # ao1: (nkpts, nao, ngrid)
        ao1 = intor.pbc_eval_gto(self._basiswrapper,
                                 xyz1,
                                 kpts=self._kpts,
                                 options=self._lattsum_opt)
        ao1 = torch.movedim(ao1, -1, 0).reshape(*xyzshape[:-1], nkpts,
                                                nao)  # (*BR, nkpts, nao)

        # dens = torch.einsum("...ka,...kb,...kab,k->...", ao1, ao1.conj(), dm, self._wkpts)
        densk = torch.matmul(dm,
                             ao1.conj().unsqueeze(-1))  # (*BRD, nkpts, nao, 1)
        densk = torch.matmul(ao1.unsqueeze(-2),
                             densk).squeeze(-1).squeeze(-1)  # (*BRD, nkpts)
        assert densk.imag.abs().max() < 1e-9, "The density should be real at this point"

        dens = torch.einsum("...k,k->...", densk.real, self._wkpts)  # (*BRD)
        return dens
Example #10
    def preproc_img(self, img_path, mask_spec=None):
        """
        Load an image and a mask.

        :param img_path: Path to target image
        :param mask_spec: Path to mask image or region to mask
        :return: The mask, and a 4x256x256 tensor containing the masked RGB image + the binary mask channel
        """
        if not self.mask_is_rect():
            # open mask and make values binary
            mask = np.array(Image.open(mask_spec).convert("L")) / 255
            mask[mask <= 0.5] = 0.0
            mask[mask > 0.5] = 1.0
        else:
            mask = np.full((256, 256), 1.0)
            mask[mask_spec[0]:mask_spec[0] + mask_spec[2],
                 mask_spec[1]:mask_spec[1] + mask_spec[3]] = 0.0

        # open image and apply mask by making masked pixels black
        im = Image.open(img_path).convert("RGB")
        im = np.array(CustomDataset.scale(im, 256)) / 255
        im[mask == 0.0, :] = 0.0

        sample = torch.cat(
            (torch.tensor(im), torch.tensor(mask.reshape(
                256, 256,
                1))),  # add dimension to match shape of 256x256x3 image
            dim=2)

        # move 'channels' dimension from last to first, as required by PyTorch
        # new shape is 4x256x256
        return torch.Tensor(mask), torch.movedim(sample, 2, 0).float()
Example #11
    def tensor_indexing_ops(self):
        x = torch.randn(2, 4)
        y = torch.randn(4, 4)
        t = torch.tensor([[0, 0], [1, 0]])
        mask = x.ge(0.5)
        i = [0, 1]
        return (
            torch.cat((x, x, x), 0),
            torch.concat((x, x, x), 0),
            torch.conj(x),
            torch.chunk(x, 2),
            torch.dsplit(torch.randn(2, 2, 4), i),
            torch.column_stack((x, x)),
            torch.dstack((x, x)),
            torch.gather(x, 0, t),
            torch.hsplit(x, i),
            torch.hstack((x, x)),
            torch.index_select(x, 0, torch.tensor([0, 1])),
            x.index(t),
            torch.masked_select(x, mask),
            torch.movedim(x, 1, 0),
            torch.moveaxis(x, 1, 0),
            torch.narrow(x, 0, 0, 2),
            torch.nonzero(x),
            torch.permute(x, (0, 1)),
            torch.reshape(x, (-1, )),
            torch.row_stack((x, x)),
            torch.select(x, 0, 0),
            torch.scatter(x, 0, t, x),
            x.scatter(0, t, x.clone()),
            torch.diagonal_scatter(y, torch.ones(4)),
            torch.select_scatter(y, torch.ones(4), 0, 0),
            torch.slice_scatter(x, x),
            torch.scatter_add(x, 0, t, x),
            x.scatter_(0, t, y),
            x.scatter_add_(0, t, y),
            # torch.scatter_reduce(x, 0, t, reduce="sum"),
            torch.split(x, 1),
            torch.squeeze(x, 0),
            torch.stack([x, x]),
            torch.swapaxes(x, 0, 1),
            torch.swapdims(x, 0, 1),
            torch.t(x),
            torch.take(x, t),
            torch.take_along_dim(x, torch.argmax(x)),
            torch.tensor_split(x, 1),
            torch.tensor_split(x, [0, 1]),
            torch.tile(x, (2, 2)),
            torch.transpose(x, 0, 1),
            torch.unbind(x),
            torch.unsqueeze(x, -1),
            torch.vsplit(x, i),
            torch.vstack((x, x)),
            torch.where(x),
            torch.where(t > 0, t, 0),
            torch.where(t > 0, t, t),
        )
Example #12
    def __next__(self) -> torch.Tensor:
        # Loading image
        success, image = self.capture.read()
        if not success:
            raise StopIteration

        image = torch.tensor(image)
        image = torch.movedim(image, -1, 0)
        image = rgb_to_grayscale(image).squeeze()
        return image
Example #13
def visualize_local_map(filename: str, observations: torch.Tensor):
    observations = observations[:, :3 * 64 * 64] + 0.5
    observations = observations.view(-1, 3, 64, 64)
    observations = (torch.movedim(observations, 1, 3).cpu().numpy() *
                    255).astype(np.uint8)
    _, H, W, _ = observations.shape
    writer = cv2.VideoWriter(filename, cv2.VideoWriter_fourcc(*'mp4v'), 30.,
                             (W, H), True)
    for observation in observations:
        writer.write(observation)
    writer.release()
    h264_converter(filename)
Example #14
    def step(self, batch, type='train'):
        frames, t, steering, throttle, force, bat = batch
        steering, throttle = steering.to(torch.float), throttle.to(torch.float)
        x = torch.movedim(frames, -1, 1) / 255
        steering_pred, throttle_pred = self(x).T
        steering_loss = self.loss_func(steering_pred, steering[:, -1])
        throttle_loss = self.loss_func(throttle_pred, throttle[:, -1])
        loss = steering_loss + throttle_loss
        wandb.log({
            type + '_throttle_loss': throttle_loss.item(),
            type + '_steering_loss': steering_loss.item(),
            type + '_loss': loss.item()
        })
        return loss
Example #15
    def __call__(self, *p_or_args):
        p = p_or_args if len(p_or_args) == 1 else torch.stack(
            torch.broadcast_tensors(*p_or_args), -1)
        assert p.shape[-1] == self.ndim

        p = 2 * (p - self.ranges[:, 0]) / self.extents - 1

        p_flat = p.reshape(*((1, ) * self.ndim), -1, self.ndim)
        data_flat = self.data.unsqueeze(0)

        res = torch.nn.functional.grid_sample(data_flat,
                                              p_flat,
                                              align_corners=True)
        return torch.movedim(res.reshape(self.channels, *p.shape[:-1]), 0, -1)
Example #16
    def step(self, batch, step_type='train'):
        frames, t, steering, throttle, force, bat = batch
        steering, throttle = steering.to(torch.float), throttle.to(torch.float)
        imgs = torch.movedim(frames, -1, 1) / 255
        ctrl = torch.cat([steering[:, None], throttle[:, None]], 1)
        steering_pred, throttle_pred = self(imgs, ctrl).T
        steering_loss = self.loss_func(steering_pred, steering[:, -1])
        throttle_loss = self.loss_func(throttle_pred, throttle[:, -1])
        loss = steering_loss + throttle_loss
        wandb.log({
            step_type + '_throttle_loss': throttle_loss.item(),
            step_type + '_steering_loss': steering_loss.item(),
            step_type + '_loss': loss.item()
        })
        return loss
Example #17
    def aodm2dens(self, dm: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
        # xyz: (*BR, ndim)
        # dm: (*BD, nao, nao)
        # returns: (*BRD)

        nao = dm.shape[-1]
        xyzshape = xyz.shape
        # basis: (nao, *BR)
        basis = intor.eval_gto(self.libcint_wrapper,
                               xyz.reshape(-1, xyzshape[-1])).reshape(
                                   (nao, *xyzshape[:-1]))
        basis = torch.movedim(basis, 0, -1)  # (*BR, nao)

        # torch.einsum("...ij,...i,...j->...", dm, basis, basis)
        dens = torch.matmul(dm, basis.unsqueeze(-1))  # (*BRD, nao, 1)
        dens = torch.matmul(basis.unsqueeze(-2),
                            dens).squeeze(-1).squeeze(-1)  # (*BRD)
        return dens
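A small numerical sanity check (illustrative only, with arbitrary random shapes) that the matmul form used above matches the einsum written in the comment:

import torch

dm = torch.randn(5, 7, 7)          # stands in for (*BRD, nao, nao)
basis = torch.randn(5, 7)          # stands in for (*BR, nao)

dens = torch.matmul(dm, basis.unsqueeze(-1))                            # (*, nao, 1)
dens = torch.matmul(basis.unsqueeze(-2), dens).squeeze(-1).squeeze(-1)  # (*,)

ref = torch.einsum("...ij,...i,...j->...", dm, basis, basis)
assert torch.allclose(dens, ref, atol=1e-5)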
Example #18
def inference(input_dict, *args):
    """
    Function for the actual inference
    """
    if input_dict['is_video']:
        source_image = input_dict['source_image']
        target_video = input_dict['target_pose']
        
        # do the final transfer
        source_image = Image.open(io.BytesIO(base64.b64decode(source_image)))
        target_video_path = 'webapp/static/videos/seq/' + target_video
        frames = [frame for frame, _ in INFERENCE_PIPELINE.render_video(source_image, target_video_path)]
        frames = torch.cat(frames)
        frames = frames.float()
        frames = torch.movedim(frames, 1, 3)
        frames = (frames + 1) / 2.0 * 255.0
        target_video_file = NamedTemporaryFile(dir='webapp/static/videos/generated/', suffix='.mp4', delete=True)
        torchvision.io.write_video(target_video_file, frames.byte(), fps=30)
        target_video_file.seek(0)
        target_video_file_name = target_video_file.name.split('/')[-1]
        return {'target_video': target_video_file_name}
    else:
        # unpack the input
        source_image = input_dict['source_image']
        target_pose = input_dict['target_pose']

        # TODO: bring the target pose list into the 
        # right input format for the torch model
        # transform the target pose into tensor
        target_pose = torch.Tensor(target_pose)
        
        # get the source pose
        # source_pose = KEYPOINT_MODEL(source_image)

        # get the source segmentation
        # source_segmentation = SEGMENTATION_MODEL(source_image)

        # do the final transfer
        source_image = Image.open(io.BytesIO(base64.b64decode(source_image)))
        target_image = INFERENCE_PIPELINE(source_image, target_pose)
        target_image_file = io.BytesIO()
        target_image.save(target_image_file, format="PNG")
        target_image = base64.b64encode(target_image_file.getvalue()).decode()
        return {'target_image': target_image}
Example #19
def _flatten_probas(probas, labels, ignore=None):
    """Flattens predictions in the batch
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)

    C = probas.size(1)
    probas = torch.movedim(probas, 1, -1)  # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
    probas = probas.contiguous().view(-1, C)  # [P, C]

    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
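An illustrative call with hypothetical shapes (4 classes, ignore label 255), assuming _flatten_probas above is in scope:

import torch

probas = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)   # [B, C, H, W]
labels = torch.randint(0, 4, (2, 8, 8))
labels[0, 0, 0] = 255                                     # one ignored pixel
vprobas, vlabels = _flatten_probas(probas, labels, ignore=255)
print(vprobas.shape, vlabels.shape)                       # (127, 4) and (127,)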
Example #20
def move_bdim_to_front(x, result_ndim=None):
    """
    Returns a tensor with a batch dimension at the front. If a batch
    dimension already exists, move it. Otherwise, create a new batch
    dimension at the front. If `result_ndim` is not None, ensure that the
    resulting tensor has rank equal to `result_ndim`.
    """
    x_dim = len(x.shape)
    x_bdim = x.bdim
    if x_bdim is None:
        x = torch.unsqueeze(x, 0)
    else:
        x = torch.movedim(x, x_bdim, 0)
    if result_ndim is None:
        return x
    diff = result_ndim - x_dim - (x_bdim is None)
    for _ in range(diff):
        x = torch.unsqueeze(x, 1)
    return x
Example #21
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_name = os.path.join(self.root_dir,
                                  self.labels.iloc[idx, 0].strip())
        image = torch.from_numpy(io.imread(image_name))
        image = torch.movedim(image, 2, 0).float()
        '''
        Generate the correct labels depending on the task.
          classifier = laser on/off
          regressor  = theta & distance
        '''
        if self.kind == 'classifier':
            laser = self.labels.iloc[idx, 4:5]
            laser = torch.from_numpy(np.array(laser, dtype=np.int16))
            label = laser.float()
        elif self.kind == 'regressor':
            reg = self.labels.iloc[idx, 2:4]
            reg = torch.from_numpy(np.array(reg, dtype=np.float32))
            label = reg.float()
        return image, label
Example #22
    def backward(
        ctx, grad_res: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:  # type: ignore
        # grad_res: (*, nao, ngrid)
        ao_to_atom, wrapper, shortname = ctx.other_info
        alphas, coeffs, pos, rgrid = ctx.saved_tensors

        # TODO: implement the gradient w.r.t. alphas and coeffs
        grad_alphas = None
        grad_coeffs = None

        # calculate the gradient w.r.t. basis' pos and rgrid
        grad_pos = None
        grad_rgrid = None
        if rgrid.requires_grad or pos.requires_grad:
            opsname = _get_evalgto_derivname(shortname, "r")
            dresdr = _EvalGTO.apply(*ctx.saved_tensors, ao_to_atom, wrapper,
                                    opsname)  # (ndim, *, nao, ngrid)
            grad_r = dresdr * grad_res  # (ndim, *, nao, ngrid)

            if rgrid.requires_grad:
                grad_rgrid = grad_r.reshape(dresdr.shape[0], -1,
                                            dresdr.shape[-1])
                grad_rgrid = grad_rgrid.sum(dim=1).transpose(
                    -2, -1)  # (ngrid, ndim)

            if pos.requires_grad:
                grad_rao = torch.movedim(grad_r, -2,
                                         0)  # (nao, ndim, *, ngrid)
                grad_rao = -grad_rao.reshape(*grad_rao.shape[:2], -1).sum(
                    dim=-1)  # (nao, ndim)
                grad_pos = torch.zeros_like(pos)  # (natom, ndim)
                grad_pos.scatter_add_(dim=0, index=ao_to_atom, src=grad_rao)

        return grad_alphas, grad_coeffs, grad_pos, grad_rgrid, \
            None, None, None, None, None
Example #23
    print(t)

    contiguous_example()

    print(torch.as_strided(t, (2, 2), (1, 2)))
    print(torch.diagonal(t, 0))
    print(torch.diagonal(t, 1))
    print(torch.diagonal(t, -1))

    x = torch.tensor([[1], [2], [3]])
    print(x.expand(3, 4))
    print(x.expand(-1, 4))  # -1 means not changing the size of that dimension
    """re-shape"""
    x = torch.randn(3, 2, 1)
    print(x.shape)
    print(torch.movedim(x, 1, 0).shape)
    print('unflatten',
          x.unflatten(1, (1, 2)).shape)  # torch.Size([3, 1, 2, 1]); dim 1 is unflattened into 1*2
    print(t.view(16).shape)  # torch.Size([16])
    print(t.view(-1, 8).shape)  # torch.Size([2, 8]); -1 means this dim is inferred
    print(torch.flatten(t))
    # print(torch.ravel(t))

    print(t)
    print(torch.narrow(t, 0, 0, 2))  # along rows: take the first two rows
    print(torch.narrow(t, 1, 0, 2))  # along columns: take the first two columns
    print(torch.select(t, 0, 1))  # along rows: take the second row
    print(torch.select(t, 1, 1))  # along columns: take the second column
    print(torch.index_select(t, 0, torch.tensor([0, 2])))  # select rows 0 and 2
    print(torch.index_select(t, 1, torch.tensor([0, 2])))  # select columns 0 and 2
    print('unbind')
Example #24
def movedim_batching_rule(x, from_dim, to_dim):
    x = move_bdim_to_front(x)
    return torch.movedim(x, from_dim + 1, to_dim + 1), 0
Example #25
def _get_integrals(int_names: List[str],
                   wrappers: List[LibcintWrapper],
                   int_type: str,
                   int_fcn: Callable[[List[LibcintWrapper], str], torch.Tensor],
                   new_axes_pos: Optional[List[int]] = None) \
                   -> List[torch.Tensor]:
    # Return the list of tensors of the integrals given by the list of integral names.
    # Int_fcn is the integral function that receives the name and returns the results.
    # If new_axes_pos is specified, then move the new axes to 0, otherwise, just leave
    # it as it is

    res: List[torch.Tensor] = []
    # indicating if the integral is available in the libcint-generated file
    int_avail: List[bool] = [False] * len(int_names)

    for i in range(len(int_names)):
        res_i: Optional[torch.Tensor] = None

        # check if the integral can be calculated from the previous results
        for j in range(i - 1, -1, -1):

            # check the integral names equivalence
            transpose_path = _intgl_shortname_equiv(int_names[j], int_names[i],
                                                    int_type)
            if transpose_path is not None:

                # if the swapped wrappers remain unchanged, then just use the
                # transposed version of the previous version
                # TODO: think more about this (do we need to use different
                # transpose path? e.g. transpose_path[::-1])
                twrappers = _swap_list(wrappers, transpose_path)
                if twrappers == wrappers:
                    res_i = _transpose(res[j], transpose_path)
                    break

                # otherwise, use the swapped integral with the swapped wrappers,
                # only if the integral is available in the libcint-generated
                # files
                elif int_avail[j]:
                    res_i = int_fcn(twrappers, int_names[j])
                    res_i = _transpose(res_i, transpose_path)
                    break

                # if the integral is not available, then continue the searching
                else:
                    continue

        if res_i is None:
            try:
                # successfully executing the line below indicates that the integral
                # is available in the libcint-generated files
                res_i = int_fcn(wrappers, int_names[i])
            except AttributeError:
                msg = "The integral %s is not available from libcint, please add it" % int_names[
                    i]
                raise AttributeError(msg)

            int_avail[i] = True

        res.append(res_i)

    # move the new axes to dimension 0
    if new_axes_pos is not None:
        res = [torch.movedim(r, ax, 0) for (r, ax) in zip(res, new_axes_pos)]
    return res
Example #26
def to_numpy(tensor):
    array = torch.movedim(tensor, 1, -1).cpu().detach().numpy()
    return array
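Illustrative shapes only, assuming to_numpy above is in scope: a channels-first batch becomes a channels-last NumPy array.

import torch

batch = torch.zeros(2, 3, 32, 32)      # (N, C, H, W)
print(to_numpy(batch).shape)           # (2, 32, 32, 3)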
Example #27
    def backward(
        ctx, grad_res: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], ...]:  # type: ignore
        # grad_res: (*, nao, ngrid)
        ao_to_atom, wrapper, shortname, to_transpose = ctx.other_info
        coeffs, alphas, pos, rgrid = ctx.saved_tensors

        if to_transpose:
            grad_res = grad_res.transpose(-2, -1)

        grad_alphas = None
        grad_coeffs = None
        if alphas.requires_grad or coeffs.requires_grad:
            u_wrapper, uao2ao = wrapper.get_uncontracted_wrapper()
            u_coeffs, u_alphas, u_pos = u_wrapper.params
            # (*, nu_ao, ngrid)
            u_grad_res = _gather_at_dims(grad_res, mapidxs=[uao2ao], dims=[-2])

            # get the scatter indices
            ao2shl = u_wrapper.ao_to_shell()

            # calculate the gradient w.r.t. coeffs
            if coeffs.requires_grad:
                grad_coeffs = torch.zeros_like(coeffs)  # (ngauss)

                # get the uncontracted version of the integral
                # (..., nu_ao, ngrid)
                dout_dcoeff = _EvalGTO.apply(*u_wrapper.params, rgrid,
                                             ao_to_atom, u_wrapper, shortname,
                                             False)

                # get the coefficients and spread it on the u_ao-length tensor
                coeffs_ao = torch.gather(coeffs, dim=-1,
                                         index=ao2shl)  # (nu_ao)
                dout_dcoeff = dout_dcoeff / coeffs_ao[:, None]
                grad_dcoeff = torch.einsum("...ur,...ur->u", u_grad_res,
                                           dout_dcoeff)  # (nu_ao)

                grad_coeffs.scatter_add_(dim=-1, index=ao2shl, src=grad_dcoeff)

            if alphas.requires_grad:
                grad_alphas = torch.zeros_like(alphas)

                new_sname = _get_evalgto_derivname(shortname, "a")
                # (..., nu_ao, ngrid)
                dout_dalpha = _EvalGTO.apply(*u_wrapper.params, rgrid,
                                             ao_to_atom, u_wrapper, new_sname,
                                             False)

                alphas_ao = torch.gather(alphas, dim=-1,
                                         index=ao2shl)  # (nu_ao)
                grad_dalpha = -torch.einsum("...ur,...ur->u", u_grad_res,
                                            dout_dalpha)

                grad_alphas.scatter_add_(dim=-1, index=ao2shl, src=grad_dalpha)

        # calculate the gradient w.r.t. basis' pos and rgrid
        grad_pos = None
        grad_rgrid = None
        if rgrid.requires_grad or pos.requires_grad:
            opsname = _get_evalgto_derivname(shortname, "r")
            dresdr = _EvalGTO.apply(*ctx.saved_tensors, ao_to_atom, wrapper,
                                    opsname, False)  # (ndim, *, nao, ngrid)
            grad_r = dresdr * grad_res  # (ndim, *, nao, ngrid)

            if rgrid.requires_grad:
                grad_rgrid = grad_r.reshape(dresdr.shape[0], -1,
                                            dresdr.shape[-1])
                grad_rgrid = grad_rgrid.sum(dim=1).transpose(
                    -2, -1)  # (ngrid, ndim)

            if pos.requires_grad:
                grad_rao = torch.movedim(grad_r, -2,
                                         0)  # (nao, ndim, *, ngrid)
                grad_rao = -grad_rao.reshape(*grad_rao.shape[:2], -1).sum(
                    dim=-1)  # (nao, ndim)
                grad_pos = torch.zeros_like(pos)  # (natom, ndim)
                grad_pos.scatter_add_(dim=0, index=ao_to_atom, src=grad_rao)

        return grad_coeffs, grad_alphas, grad_pos, grad_rgrid, \
            None, None, None, None, None, None
Example #28
def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor:
    """`moveaxis` for pytorch and numpy"""
    if isinstance(x, torch.Tensor):
        return torch.movedim(x, src, dst)  # type: ignore
    return np.moveaxis(x, src, dst)
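For illustration, the same call works for either backend, assuming the moveaxis helper above and its imports are in scope:

import numpy as np
import torch

print(moveaxis(torch.zeros(2, 3, 4), 0, -1).shape)   # torch.Size([3, 4, 2])
print(moveaxis(np.zeros((2, 3, 4)), 0, -1).shape)    # (3, 4, 2)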
Example #29
def to_tensor(array):
    tensor = torch.from_numpy(array).float()
    tensor = torch.movedim(tensor, -1, 1)
    return tensor
Example #30
def main(options):

    # find readout direction
    f = io.map(options.echoes[0])
    affine, shape = f.affine, f.shape
    readout = get_readout(options.direction, affine, shape, options.verbose)

    if not options.reversed:
        reversed_echoes = options.synth
    else:
        reversed_echoes = options.reversed

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # save volumes
    input, output = options.echoes, options.output
    if len(output) != len(input):
        if len(output) == 1:
            if '{base}' in output[0]:
                output = [output[0]] * len(input)
        elif len(output) != len(fit):
            raise ValueError(f'There should be either one output file, '
                             f'or as many output files as input files, '
                             f'or as many output files as echoes. Got '
                             f'{len(output)} output files, {len(input)} '
                             f'input files, and {len(fit)} echoes.')
    if len(output) == 1:
        dir, base, ext = py.fileparts(input[0])
        output = output[0]
        if '{n}' in output:
            for n, echo in enumerate(fit):
                out = output.format(dir=dir,
                                    sep=os.sep,
                                    base=base,
                                    ext=ext,
                                    n=n)
                io.savef(echo, out, like=input[0])
        else:
            output = output.format(dir=dir, sep=os.sep, base=base, ext=ext)
            io.savef(torch.movedim(fit, 0, -1), output, like=input[0])
    elif len(output) == len(input):
        for i, (inp, out) in enumerate(zip(input, output)):
            dir, base, ext = py.fileparts(inp)
            out = out.format(dir=dir, sep=os.sep, base=base, ext=ext, n=i)
            ne = [*io.map(inp).shape, 1][3]
            io.savef(fit[:ne].movedim(0, -1), out, like=inp)
            fit = fit[ne:]
    else:
        assert len(output) == len(fit)
        dir, base, ext = py.fileparts(input[0])
        for n, (echo, out) in enumerate(zip(fit, output)):
            out = out.format(dir=dir, sep=os.sep, base=base, ext=ext, n=n)
            io.savef(echo, out, like=input[0])