Example #1
    def __init__(self, root, filename=None):
        super().__init__()

        # Load pre-downloaded dataset
        filename = "3dshapes.npz" if filename is None else filename
        path = pathlib.Path(root, filename)

        with np.load(path) as dataset:
            # Load data and permute dims NHWC -> NCHW
            self.data = torch.tensor(dataset["images"]).permute(0, 3, 1, 2)

        self.factor_sizes = [10, 10, 10, 8, 4, 15]
        self.targets = torch.cartesian_prod(
            *[torch.arange(v) for v in self.factor_sizes])
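Since `torch.cartesian_prod` enumerates factor combinations in row-major order, `targets` gets one row per image. A quick sanity check of the layout, assuming the factor sizes above (their product is 480,000):

import torch

factor_sizes = [10, 10, 10, 8, 4, 15]
targets = torch.cartesian_prod(*[torch.arange(v) for v in factor_sizes])
assert targets.shape == (480000, 6)  # prod(factor_sizes) rows, one label vector per image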
Example #2
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
                   shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
    """Compute pairs of atoms that are neighbors

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
            defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
        cutoff (float): the cutoff inside which atoms are considered pairs
    """
    coordinates = coordinates.detach()
    cell = cell.detach()
    num_atoms = padding_mask.shape[1]
    num_mols = padding_mask.shape[0]
    all_atoms = torch.arange(num_atoms, device=cell.device)

    # Step 2: center cell
    # torch.triu_indices is faster than combinations
    p12_center = torch.triu_indices(num_atoms, num_atoms, 1, device=cell.device)
    shifts_center = shifts.new_zeros((p12_center.shape[1], 3))

    # Step 3: cells with shifts
    # shape convention (shift index, molecule index, atom index, 3)
    num_shifts = shifts.shape[0]
    all_shifts = torch.arange(num_shifts, device=cell.device)
    prod = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).t()
    shift_index = prod[0]
    p12 = prod[1:]
    shifts_outside = shifts.index_select(0, shift_index)

    # Step 4: combine results for all cells
    shifts_all = torch.cat([shifts_center, shifts_outside])
    p12_all = torch.cat([p12_center, p12], dim=1)
    shift_values = shifts_all.to(cell.dtype) @ cell

    # Step 5: compute distances and find all pairs within cutoff
    selected_coordinates = coordinates.index_select(1, p12_all.view(-1)).view(num_mols, 2, -1, 3)
    distances = (selected_coordinates[:, 0, ...] - selected_coordinates[:, 1, ...] + shift_values).norm(2, -1)
    padding_mask = padding_mask.index_select(1, p12_all.view(-1)).view(2, -1).any(0)
    distances.masked_fill_(padding_mask, math.inf)
    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    molecule_index *= num_atoms
    atom_index12 = p12_all[:, pair_index]
    shifts = shifts_all.index_select(0, pair_index)
    return molecule_index + atom_index12, shifts
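On the `torch.triu_indices is faster than combinations` comment above: the two constructions yield the same unordered pairs in the same order, just transposed, as a small check confirms:

import torch

n = 5
pairs_a = torch.combinations(torch.arange(n)).t()  # shape (2, n*(n-1)//2)
pairs_b = torch.triu_indices(n, n, 1)              # same pairs, same order
assert torch.equal(pairs_a, pairs_b)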
Example #3
def grid_visualize(point_clouds,
                   encoder,
                   decoder,
                   grid_scale,
                   threshold,
                   knn_num,
                   save_dir,
                   batch_idx=0):
    B, C, N = point_clouds.shape
    device = point_clouds.device
    with torch.no_grad():
        scale = torch.linspace(-1.0, 1.0, grid_scale)
        point_grid = torch.stack(
            [torch.cartesian_prod(scale, scale, scale).transpose(1, 0)] * B,
            dim=0).to(device)
        partial_size = 100  # chunk size; the loop below assumes grid_scale**3 is divisible by this
        test_pred = torch.Tensor([]).to(device)
        for i in range(int((grid_scale**3) / partial_size)):
            partial_point_grid = point_grid[:, :, i * partial_size:(i + 1) *
                                            partial_size]
            temp_latent_vector = encoder(
                knn_point_sampling(point_clouds, partial_point_grid, knn_num))
            test_pred = torch.cat([
                test_pred,
                decoder(partial_point_grid, temp_latent_vector).squeeze(dim=-1)
            ],
                                  dim=2)
        for b in range(B):
            test_pred_sample = test_pred[b, :, :]
            masked_index = (test_pred_sample.squeeze() < threshold).nonzero()
            pred_pc = torch.gather(point_grid[b, :, :], 1, torch.stack([masked_index.squeeze()] * 3, dim=0)) \
                .unsqueeze(dim=0)
            if pred_pc.size(2) > N:
                pred_pc, _ = pcu.random_point_sample(pred_pc, N)
            elif pred_pc.size(2) < N:
                new_pred_pc = pred_pc
                while new_pred_pc.size(2) < N:
                    new_pred_pc = torch.cat(
                        [new_pred_pc, pcu.jitter(pred_pc)], dim=2)
                pred_pc, _ = pcu.random_point_sample(new_pred_pc, N)
            # pcu.visualize(point_clouds)
            # pcu.visualize(pred_pc)
            image_save(pred_pc.detach().cpu(),
                       save_dir,
                       'grid_visualize',
                       'prediction',
                       'predict_pc',
                       batch_idx=batch_idx * B + b,
                       folder_numbering=False)
Example #4
    def __init__(self, check_on_x, check_on_y, check_on_t, check_every):
        self.using_non_gui_backend = matplotlib.get_backend() == 'agg'

        xy_tensor = torch.cartesian_prod(check_on_x, check_on_y)
        self.xx_tensor = torch.squeeze(xy_tensor[:, 0])
        self.yy_tensor = torch.squeeze(xy_tensor[:, 1])
        self.tt_tensors = [torch.ones(len(xy_tensor)) * t for t in check_on_t]
        self.xx_array = self.xx_tensor.clone().detach().cpu().numpy()
        self.yy_array = self.yy_tensor.clone().detach().cpu().numpy()
        self.t_array = check_on_t.clone().detach().cpu().numpy()
        self.check_every = check_every

        self.fig = None
        self.axs = []  # subplots
        self.cbs = []  # color bars
Example #5
    def forward(self, x: torch.Tensor):
        # get params
        stride = int(x.size(-1) * ((self.max_stride_ratio - self.min_stride_ratio) * random() + self.min_stride_ratio))
        grid_size = randint(int(stride * 0.3), int(stride * 0.7))

        grid_hws = torch.cartesian_prod(
            torch.arange(randint(0, stride), x.size(1), stride),
            torch.arange(randint(0, stride), x.size(2), stride)
        )

        for h, w in grid_hws:
            if torch.rand(1) < self.p:
                erase = torch.empty_like(x[..., h:h+grid_size, w:w+grid_size], dtype=torch.float32).normal_()
                x[..., h:h+grid_size, w:w+grid_size] = erase
        return x
Example #6
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
    """Compute pairs of atoms that are neighbors

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
            defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
        cutoff (float): the cutoff inside which atoms are considered pairs
    """
    # type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

    coordinates = coordinates.detach()
    cell = cell.detach()
    num_atoms = padding_mask.shape[1]
    all_atoms = torch.arange(num_atoms, device=cell.device)

    # Step 2: center cell
    p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
    shifts_center = shifts.new_zeros(p1_center.shape[0], 3)

    # Step 3: cells with shifts
    # shape convention (shift index, molecule index, atom index, 3)
    num_shifts = shifts.shape[0]
    all_shifts = torch.arange(num_shifts, device=cell.device)
    shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
    shifts_outside = shifts.index_select(0, shift_index)

    # Step 4: combine results for all cells
    shifts_all = torch.cat([shifts_center, shifts_outside])
    p1_all = torch.cat([p1_center, p1])
    p2_all = torch.cat([p2_center, p2])
    shift_values = torch.mm(shifts_all.to(cell.dtype), cell)

    # Step 5: compute distances and find all pairs within cutoff
    distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1)
    padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
    distances.masked_fill_(padding_mask, math.inf)
    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    molecule_index *= num_atoms
    atom_index1 = p1_all[pair_index]
    atom_index2 = p2_all[pair_index]
    shifts = shifts_all.index_select(0, pair_index)
    return molecule_index + atom_index1, molecule_index + atom_index2, shifts
Example #7
def test_segment_cartesian_product():
    rng = np.random.RandomState(42)

    va = torch.tensor(rng.randn(10, 2))
    vb = torch.tensor(rng.randn(8, 2))

    sa = torch.tensor(_scopes_from_lengths([3, 7]))
    sb = torch.tensor(_scopes_from_lengths([2, 6]))

    result = index.segment_cartesian_product(va, vb, sa, sb)

    i1 = torch.cartesian_prod(torch.arange(sa[0, 1]), torch.arange(sb[0, 1]))
    r1 = torch.cat(
        (va.index_select(0, i1[:, 0]), vb.index_select(0, i1[:, 1])), dim=-1)

    assert np.allclose(r1.cpu().numpy(), result[:r1.shape[0], :])
Example #8
    def __init__(self, check_on_x, check_on_y, check_every):
        self.using_non_gui_backend = matplotlib.get_backend() == 'agg'

        xy_tensor = torch.cartesian_prod(check_on_x, check_on_y)
        self.xx_tensor = torch.squeeze(xy_tensor[:, 0])
        self.yy_tensor = torch.squeeze(xy_tensor[:, 1])

        self.xx_array = self.xx_tensor.clone().detach().cpu().numpy()
        self.yy_array = self.yy_tensor.clone().detach().cpu().numpy()

        self.check_every = check_every

        self.fig = plt.figure(figsize=(30, 8))
        self.ax1 = self.fig.add_subplot(131)
        self.cb1 = None
        self.ax2 = self.fig.add_subplot(132)
        self.ax3 = self.fig.add_subplot(133)
Example #9
    def _build_spherical_perspective(self, fov: float, width: int,
                                     height: int) -> Rays:
        """
        Build fish-eye array of rays
        """

        aspect_ratio = width / height

        # Field of view in radians
        horizontal_fov = fov * math.pi / 180.0
        vertical_fov = horizontal_fov / aspect_ratio

        # Spherical coordinates
        phi = torch.linspace(vertical_fov / 2,
                             -vertical_fov / 2,
                             height,
                             dtype=self.float_dtype)  # Declination
        theta = torch.linspace(-horizontal_fov / 2,
                               horizontal_fov / 2,
                               width,
                               dtype=self.float_dtype)  # Yaw
        phi_theta = torch.cartesian_prod(phi, theta)

        # Source: origin
        src = self._as_float_tensor([.0, .0, .0])

        # Destination
        dst = torch.cat(
            [(self.ray_len * torch.sin(phi_theta[:, 1]) *
              torch.cos(phi_theta[:, 0])).unsqueeze(-1),
             (self.ray_len * torch.sin(phi_theta[:, 0])).unsqueeze(-1),
             (-self.ray_len * torch.cos(phi_theta[:, 1]) *
              torch.cos(phi_theta[:, 0])).unsqueeze(-1)],
            dim=1)

        rays = self._allocate_rays(width * height)
        weight = 1.

        rays.src = src
        rays.dst = dst
        rays.wgt = weight
        rays.col = self._as_float_tensor(self.ray_col.asarray())

        return rays
Example #10
    def _get_pixelwise_anchors(
            conv_shape, conv_stride_prod, raw_anchors, *, pixel_center=0.5):
        assert(len(conv_shape) == raw_anchors.shape[1])
        ranges = (
            (
                torch.arange(
                    shape_dim,
                    device=raw_anchors.device,
                    dtype=raw_anchors.dtype)
                + pixel_center
            ) * stride_prod_dim
            for shape_dim, stride_prod_dim
            in zip(conv_shape, conv_stride_prod))

        shifts = torch.cartesian_prod(*ranges).unsqueeze(-2)
        broadcasted_tensors = torch.broadcast_tensors(shifts, raw_anchors)
        stacked_tensors = torch.stack(broadcasted_tensors, dim=-2)

        return stacked_tensors
Example #11
    def sample(self, x, edge_index, edge_weight, labels, idx_train, idx_val,
               idx_test):
        # pass
        sampled_edge_index: torch.Tensor = edge_index
        sampled_edge_weight: torch.Tensor = edge_weight

        embedding, _ = self.gcn_model(x, edge_index, edge_weight)
        embedding = F.softmax(embedding, dim=1)

        print(embedding.min(), embedding.max())

        n = x.shape[0]
        m = edge_index.shape[-1]
        size = int(np.sqrt(m // 2))
        entropy = torch.distributions.Categorical(probs=embedding).entropy()
        com_ent = torch.clamp_min(np.log(self.args.num_classes) - entropy, 0)
        print('com_ent.min(), com_ent.max()', com_ent.min(), com_ent.max())
        centroid = torch.multinomial(com_ent, size, replacement=False)
        borderline = torch.multinomial(entropy, size, replacement=False)

        non_exist_edge_index = torch.cartesian_prod(centroid, borderline).t()
        non_exist_edge_index = torch.cat(
            [non_exist_edge_index,
             torch.flip(non_exist_edge_index, dims=[0])],
            dim=1)

        sampled_edge_index = torch.cat([edge_index, non_exist_edge_index],
                                       dim=1)
        sampled_edge_weight = torch.cat([
            edge_weight,
            torch.zeros(non_exist_edge_index.shape[1],
                        dtype=torch.float,
                        device=self.args.device)
        ])

        sampled_edge_index, sampled_edge_weight = torch_sparse.coalesce(
            sampled_edge_index, sampled_edge_weight, n, n, 'add')

        mask = sampled_edge_index[0] != sampled_edge_index[1]
        sampled_edge_index = sampled_edge_index[:, mask]
        sampled_edge_weight = sampled_edge_weight[mask]

        return sampled_edge_index, sampled_edge_weight
Example #12
    def test_compute_metric(self):
        # Dataset
        factor_sizes = [5] * 5
        batch_size = torch.prod(torch.tensor(factor_sizes))

        dataset = BaseDataset()
        dataset.data = torch.rand(batch_size, 1, 64, 64)
        dataset.targets = torch.cartesian_prod(
            *[torch.arange(v) for v in factor_sizes])
        dataset.factor_sizes = factor_sizes
        self.evaluator.dataset = dataset

        # Model
        self.evaluator.model = TmpModel()

        # Compute metric
        scores_dict = self.evaluator.compute_metric("mig")

        self.assertIsInstance(scores_dict, dict)
        self.assertTrue(0 <= scores_dict["discrete_mig"] <= 1)
Example #13
def product(*inputs, r=1):
    """Cartesian product of a set.

    Parameters
    ----------
    inputs : iterable of tensor_like
        Input sets (tensors are flattened if not vectors)
        with `n[i]` elements each.
    r : int, default=1
        Repeats.
        .. warning:: keyword argument only.

    Returns
    -------
    output : (prod(n)**r, len(inputs)*r) tensor
        Cartesian product

    """
    inputs = [torch.as_tensor(input) for input in inputs]
    return torch.cartesian_prod(*(inputs * r))
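A minimal usage sketch of `product` as defined above: with a single input set and `r=2`, it reduces to `torch.cartesian_prod` applied to the set twice.

import torch

x = torch.tensor([0, 1])
out = product(x, r=2)  # rows: [0,0], [0,1], [1,0], [1,1]
assert torch.equal(out, torch.cartesian_prod(x, x))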
Example #14
    def test_perspective_n_points(self, print_stats=False):
        if print_stats:
            print("RUN ON A DENSE GRID")
        u = torch.linspace(-1.0, 1.0, 20)
        v = torch.linspace(-1.0, 1.0, 15)
        for skip_q in [False, True]:
            self._testcase_from_2d(
                torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q)

        for num_pts in range(6, 3, -1):
            for skip_q in [False, True]:
                if print_stats:
                    print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}")

                self.case_with_gaussian_points(
                    num_pts=num_pts,
                    print_stats=print_stats,
                    benchmark=False,
                    skip_q=skip_q,
                )
Example #15
def glass_blur(inp, blur_std, rad, trials):
    inp = gaussian_blur(inp, blur_std)

    batch_size, _, w, h = inp.shape
    coords = ch.cartesian_prod(ch.arange(batch_size),
                               rad + ch.arange(w - 2 * rad),
                               rad + ch.arange(h - 2 * rad)).T
    for _ in range(trials):
        # should probably be (-rad, rad+1) but this is how it's done in ImageNet-C
        xy_diffs = ch.randint(-rad, rad, (2, coords.shape[1]))
        new_coords = coords + ch.cat([ch.zeros(1, coords.shape[1]), xy_diffs
                                      ]).long()

        # Swap coords and new_coords
        (b1, x1, y1), (b2, x2, y2) = coords, new_coords
        cp1, cp2 = inp[b1, :, x1, y1].clone(), inp[b2, :, x2, y2].clone()
        inp[b2, :, x2, y2] = cp1
        inp[b1, :, x1, y1] = cp2

    return ch.clamp(gaussian_blur(inp, blur_std), 0, 1)
Example #16
    def eval_step(self, batch, batch_idx, tag):
        X, y, g, rows = batch

        y_hat = self.model(X, g)
        assert (y.size() == y_hat.size())

        out_dim = y.size(-1)

        index_ptr = torch.cartesian_prod(torch.arange(rows.size(0)),
                                         torch.arange(g['cent_n_id'].size(0)),
                                         torch.arange(out_dim))

        label = pd.DataFrame({
            'row_idx':
            rows[index_ptr[:, 0]].data.cpu().numpy(),
            'node_idx':
            g['cent_n_id'][index_ptr[:, 1]].data.cpu().numpy(),
            'feat_idx':
            index_ptr[:, 2].data.cpu().numpy(),
            'val':
            y[index_ptr.t().chunk(3)].squeeze(dim=0).data.cpu().numpy()
        })

        pred = pd.DataFrame({
            'row_idx':
            rows[index_ptr[:, 0]].data.cpu().numpy(),
            'node_idx':
            g['cent_n_id'][index_ptr[:, 1]].data.cpu().numpy(),
            'feat_idx':
            index_ptr[:, 2].data.cpu().numpy(),
            'val':
            y_hat[index_ptr.t().chunk(3)].squeeze(dim=0).data.cpu().numpy()
        })

        pred = pred.groupby(['row_idx', 'node_idx', 'feat_idx']).mean()
        label = label.groupby(['row_idx', 'node_idx', 'feat_idx']).mean()

        return {
            'label': label,
            'pred': pred,
        }
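The `index_ptr` trick works because `cartesian_prod` of `arange`s enumerates multi-indices in the same row-major order that `flatten` uses, so the generated columns line up with the flattened values. A small illustration:

import torch

a, b = 2, 3
idx = torch.cartesian_prod(torch.arange(a), torch.arange(b))
t = torch.arange(a * b).view(a, b)
assert torch.equal(t[idx[:, 0], idx[:, 1]], t.flatten())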
Example #17
    def plot2d_pdf(self, ax, bounds=((-6, 6), (-6, 6)), n_points=1000):
        bounds_x = bounds[0]
        bounds_y = bounds[1]
        x = torch.linspace(*bounds_x,
                           n_points,
                           device=self.device,
                           dtype=torch.float)
        y = torch.linspace(*bounds_y,
                           n_points,
                           device=self.device,
                           dtype=torch.float)
        if self.cached_grid is not None:
            levels = self.cached_grid
        else:
            levels = torch.exp(
                -self.E(torch.cartesian_prod(x, y).view(-1, 2))).view(
                    n_points, n_points).data.cpu().numpy()
            self.cached_grid = levels
        levels /= np.sum(levels)
        ax.contour(x, y, levels.T)
        ax.set_title(self.name, fontsize=16)
Example #18
    def post_epoch_visualize(self, epoch, split):
        print('* Visualizing', split)
        Z = torch.linspace(0.0, 1.0, steps=20)
        Z = torch.cartesian_prod(Z, Z).view(20, 20, 2)
        x_gens = []
        for row in range(20):
            if self.flags.z_size == 2:
                z = Z[row]
            else:
                z = torch.rand(20, self.flags.z_size)
            z = self.model.prepare_batch(z)
            x_gen = self.model.run_batch([z], visualize=True).detach().cpu()
            x_gens.append(x_gen)

        x_full = torch.cat(x_gens, dim=0).numpy()
        if split == 'test':
            fname = self.flags.log_dir + '/test.png'
        else:
            fname = self.flags.log_dir + '/vis_%03d.png' % self.model.get_train_steps()
        misc.save_comparison_grid(fname, x_full, border_width=0, retain_sequence=True)
        print('* Visualizations saved to', fname)
Example #19
def gen_vtx_points(base_lb: Tensor, base_ub: Tensor, extra: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
    """ Generate the vertices of a hyper-rectangle bounded by LB/UB.
    :param base_lb/base_ub: batched
    :return: Batch x N x State, where N is the number of vertices in each abstraction
    """
    # TODO a faster way might be using torch.where() to select from LB/UB points with [0, 2^n-1] indices
    all_vtxs = []
    for lb, ub in zip(base_lb, base_ub):
        # basically, a cartesian product of LB/UB on each dimension
        lbub = torch.stack((lb, ub), dim=-1)  # Dim x 2
        vtxs = torch.cartesian_prod(*list(lbub))
        all_vtxs.append(vtxs)
    all_vtxs = torch.stack(all_vtxs, dim=0)

    if extra is None:
        new_extra = None
    else:
        new_size = list(extra.size())
        new_size.insert(1, all_vtxs.shape[1])
        new_extra = extra.unsqueeze(dim=1).expand(*new_size)
    return all_vtxs, new_extra
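For a d-dimensional box the per-dimension Cartesian product yields the 2**d corner points; a small 2-D sketch with made-up bounds:

import torch

lb, ub = torch.tensor([0., 0.]), torch.tensor([1., 2.])
lbub = torch.stack((lb, ub), dim=-1)      # Dim x 2
vtxs = torch.cartesian_prod(*list(lbub))  # [[0,0],[0,2],[1,0],[1,2]]
assert vtxs.shape == (4, 2)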
Example #20
def generator_3dspatial_cube(size,
                             x_min,
                             x_max,
                             y_min,
                             y_max,
                             z_min,
                             z_max,
                             random=True):
    """Return a generator that generates 3D points in a cube.

    :param size:
        Number of points to generate in each dimension when `__next__` is
        invoked, as a tuple (x_size, y_size, z_size).
    :type size: tuple[int, int, int]
    :param x_min, x_max:
        The bounds of the cube along the x axis.
    :type x_min, x_max: float
    :param y_min, y_max:
        The bounds of the cube along the y axis.
    :type y_min, y_max: float
    :param z_min, z_max:
        The bounds of the cube along the z axis.
    :type z_min, z_max: float
    :param random:
        - If set to False, then return equally spaced points in the cube.
        - If set to True, then generate points randomly.

        Defaults to True.
    :type random: bool
    """
    x_size, y_size, z_size = size
    x_generator = generator_1dspatial(x_size, x_min, x_max, random)
    y_generator = generator_1dspatial(y_size, y_min, y_max, random)
    z_generator = generator_1dspatial(z_size, z_min, z_max, random)
    while True:
        x = next(x_generator)
        y = next(y_generator)
        z = next(z_generator)
        xyz = torch.cartesian_prod(x, y, z)
        xx = torch.squeeze(xyz[:, 0])
        yy = torch.squeeze(xyz[:, 1])
        zz = torch.squeeze(xyz[:, 2])
        yield xx, yy, zz
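The `cartesian_prod` plus column-squeeze pattern above is the flattened analogue of `torch.meshgrid` with `indexing='ij'` (the keyword exists in recent PyTorch versions):

import torch

x, y, z = torch.arange(2.), torch.arange(3.), torch.arange(4.)
xyz = torch.cartesian_prod(x, y, z)
xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
assert torch.equal(xyz[:, 0], xx.flatten())
assert torch.equal(xyz[:, 2], zz.flatten())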
Example #21
def main():
    # args
    parser = ArgumentParser()
    # trainer args
    parser.add_argument('cfg_path', type=str)
    parser.add_argument('ear_path', type=str)
    parser.add_argument('output_path', type=str)
    parser.add_argument('--device', default='cpu', type=str)
    parser.add_argument('--nfft', default=256, type=int)
    parser.add_argument('--sr', default=44100, type=int)
    parser.add_argument('--view', action='store_true')
    # parse
    args = parser.parse_args()

    # load configs
    with open(args.cfg_path, 'r') as fp:
        cfg = json.load(fp)
    img_size = cfg['ears']['img_size']
    img_channels = cfg['ears']['img_channels']

    # pick models
    EarsModelClass = {
        'vae_conv': VAECfg,
        'vae_resnet': ResNetVAECfg,
        'vae_incept': InceptionVAECfg
    }.get(cfg['ears']['model_type'])
    LatentModelClass = {
        'dnn': DNNCfg
    }.get(cfg['latent']['model_type'])
    HrtfModelClass = {
        'cvae_dense': CVAECfg,
    }.get(cfg['hrtf']['model_type'])

    # load models
    models = {}
    for ModelClass, model_type in zip([EarsModelClass, LatentModelClass, HrtfModelClass], ['ears', 'latent', 'hrtf']):
        model_ckpt_path = cfg[model_type]['model_ckpt_path']
        print(f'### Loading model {ModelClass.model_name} from {model_ckpt_path}...')
        model = ModelClass.load_from_checkpoint(model_ckpt_path)
        model.to(args.device)
        model.eval()
        models[model_type] = model
    print('### Models Loaded.')

    # load and process ear image
    print(f'### Loading and processing ear picture from {args.ear_path}...')
    img = Image.open(args.ear_path, 'r').convert('RGB')
    transforms = Compose([
        Resize(img_size),
        ToTensor(),
        Grayscale(img_channels)
    ])
    ear = transforms(img)
    ear = ear.unsqueeze(0)
    print('### Done loading and processing.')

    # calculate elevation range
    el_range = cfg['el_range']
    if el_range:
        el_range = create_range(el_range)
    # calculate azimuth range
    az_range = cfg['az_range']
    if az_range:
        az_range = create_range(az_range)
    # create c tensor
    if el_range is not None and az_range is not None:
        c = torch.cartesian_prod(el_range, az_range)
    elif el_range is not None:
        c = el_range.unsqueeze(-1)
    elif az_range is not None:
        c = az_range.unsqueeze(-1)
    else:
        raise ValueError('at least one of el_range and az_range must be set in the config')

    # predict datapoints
    print('### Predicting data...')
    ear, c = ear.to(args.device), c.to(args.device)
    with torch.no_grad():
        # ear to z_ear
        _, z_ear, *_ = models['ears'](ear)
        z_ears = z_ear.repeat(c.shape[0], 1)
        # z_ear + c to z_hrtf
        x = torch.cat((z_ears, c), dim=-1)
        z_hrtf = models['latent'](x)
        # z_hrtf to hrtf
        hrtf = models['hrtf'].cvae.dec(z_hrtf, c)
    hrtf = hrtf.cpu().numpy()
    c = c.cpu().numpy()
    print(f'### Done predicting. Data shape: {hrtf.shape}')

    # generate figure
    if args.view:
        print('### Generating figure...')
        output_path_resps = os.path.splitext(args.output_path)[0] + '_resps.png'
        output_path_surf = os.path.splitext(args.output_path)[0] + '_surf.png'
        f = rfftfreq(args.nfft, d=1. / args.sr)
        # make first figure (individual responses)
        n_cols = 6
        n_rows = math.ceil(len(c) / n_cols)
        figsize = n_cols * 4, n_rows * 2.4
        fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize)
        for i, ax in enumerate(axs.flatten()):
            if i < len(c):
                ax.plot(f, hrtf[i])
                ax.set_title(f'{c[i]}')
            else:
                ax.axis('off')
        fig.suptitle('PRTFs')
        fig.tight_layout(rect=[0, 0, 1, 0.98])
        fig.savefig(output_path_resps)
        # make second figure (surface plots)
        if c.shape[1] > 1:
            n_planes = len(az_range)
            az_range = az_range.numpy()
        else:
            n_planes = 1
            az_range = [0]
        fig, axs = plt.subplots(n_planes, 1, figsize=(12, 10 * n_planes), squeeze=False)
        extent = [f[0], f[-1], el_range[-1], el_range[0]]
        for i, az in enumerate(az_range):
            if c.shape[1] > 1:
                curr_hrtf = hrtf[c[:, 1] == az]
            else:
                curr_hrtf = hrtf
            plot_surface(fig, axs[i, 0], curr_hrtf, extent, az)
        fig.tight_layout()
        fig.savefig(output_path_surf)
        print(f'### Figure stored in {output_path_resps} and {output_path_surf}')

    # store
    print(f'### Storing data in {args.output_path}...')
    sio.savemat(args.output_path, {'synthesized_hrtf': hrtf, 'pos': c})
    print('### Done!')
Example #22
File: math_ops.py  Project: malfet/pytorch
    def other_ops(self):
        a = torch.randn(4)
        b = torch.randn(4)
        c = torch.randint(0, 8, (5, ), dtype=torch.int64)
        e = torch.randn(4, 3)
        f = torch.randn(4, 4, 4)
        size = [0, 1]
        dims = [0, 1]
        return (
            torch.atleast_1d(a),
            torch.atleast_2d(a),
            torch.atleast_3d(a),
            torch.bincount(c),
            torch.block_diag(a),
            torch.broadcast_tensors(a),
            torch.broadcast_to(a, (4)),
            # torch.broadcast_shapes(a),
            torch.bucketize(a, b),
            torch.cartesian_prod(a),
            torch.cdist(e, e),
            torch.clone(a),
            torch.combinations(a),
            torch.corrcoef(a),
            # torch.cov(a),
            torch.cross(e, e),
            torch.cummax(a, 0),
            torch.cummin(a, 0),
            torch.cumprod(a, 0),
            torch.cumsum(a, 0),
            torch.diag(a),
            torch.diag_embed(a),
            torch.diagflat(a),
            torch.diagonal(e),
            torch.diff(a),
            torch.einsum("iii", f),
            torch.flatten(a),
            torch.flip(e, dims),
            torch.fliplr(e),
            torch.flipud(e),
            torch.kron(a, b),
            torch.rot90(e),
            torch.gcd(c, c),
            torch.histc(a),
            torch.histogram(a),
            torch.meshgrid(a),
            torch.lcm(c, c),
            torch.logcumsumexp(a, 0),
            torch.ravel(a),
            torch.renorm(e, 1, 0, 5),
            torch.repeat_interleave(c),
            torch.roll(a, 1, 0),
            torch.searchsorted(a, b),
            torch.tensordot(e, e),
            torch.trace(e),
            torch.tril(e),
            torch.tril_indices(3, 3),
            torch.triu(e),
            torch.triu_indices(3, 3),
            torch.vander(a),
            torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
            torch.view_as_complex(torch.randn(4, 2)),
            torch.resolve_conj(a),
            torch.resolve_neg(a),
        )
Example #23
    def grid_data(self, positions: torch.Tensor, weights: torch.Tensor = None,
                  method: str = 'nearest') -> torch.Tensor:
        """Places data from positions onto grid using method='nearest'|'cic'
        where cic=cloud in cell. Returns gridded data and stores it as class attribute
        data.

        Note that for 'nearest' we place the data at the nearest grid point. This corresponds to grid edges at
        the property cell_edges. For 'cic', if the particle lies exactly on a grid point it is entirely placed in
        that cell; otherwise it is split over multiple cells.
        """
        dimensions = tuple(self.n)
        fi = self._float_idx(positions)

        if weights is None:
            weights = positions.new_ones(positions.shape[0])

        if method == 'nearest':
            i = (fi + 0.5).type(torch.int64)
            gd = ((i >= 0) & (i < self.n[None, :])).all(dim=1)
            if gd.sum() == 0:
                return positions.new_zeros(dimensions)
            data = torch.sparse.FloatTensor(i[gd].t(), weights[gd],
                                            size=dimensions).to_dense().reshape(
                dimensions).type(
                dtype=positions.dtype)
        elif method == 'cic':

            dimensions = tuple(self.n + 2)
            i = fi.floor()
            offset = fi - i
            i = i.type(torch.int64) + 1

            gd = ((i >= 1) & (i <= self.n[None, :])).all(dim=1)
            if gd.sum() == 0:
                return positions.new_zeros(dimensions)
            weights, i, offset = weights[gd], i[gd], offset[gd]

            data = weights.new_zeros(dimensions)
            if len(self.n) == 1:
                # 1d is easier to handle as a special case
                indexes = torch.tensor([[0], [1]], device=i.device)
            else:
                twidle = torch.tensor([0, 1], device=i.device)
                indexes = torch.cartesian_prod(
                    *torch.split(twidle.repeat(len(self.n)), 2))
            for offsetdims in indexes:
                thisweights = torch.ones_like(weights)
                for dimi, offsetdim in enumerate(offsetdims):
                    if offsetdim == 0:
                        thisweights *= (torch.tensor(1.0) - offset[..., dimi])
                    if offsetdim == 1:
                        thisweights *= offset[..., dimi]
                data += torch.sparse.FloatTensor((i + offsetdims).t(),
                                                 thisweights * weights,
                                                 size=dimensions).to_dense().type(
                    dtype=positions.dtype)
            for dim in range(len(self.n)):
                data = data.narrow(dim, 1, data.shape[dim] - 2)
        else:
            raise ValueError(
                f'Method {method} not recognised. Allowed values are nearest|cic')

        return data
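The `indexes` construction in the 'cic' branch enumerates all 2**d corner offsets of a grid cell; for d = 3 it reproduces the eight unit-cube corners:

import torch

twidle = torch.tensor([0, 1])
indexes = torch.cartesian_prod(*torch.split(twidle.repeat(3), 2))
assert indexes.shape == (8, 3)  # all 0/1 offset combinations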
Example #24
    def eval_step(self, batch, batch_idx, tag):
        inputs, g, rows = batch
        input_day, input_day_gov, y_gbm, y = inputs
        forecast_length = y.size()[-1]
        y_hat, state_gate, gov_gate = self.model(input_day, input_day_gov, g)
        if self.config.use_gbm:
            y_hat += y_gbm

        assert (y.size() == y_hat.size())

        if g['type'] == 'subgraph' and 'res_n_id' in g:  # if using SAINT sampler
            cent_n_id = g['cent_n_id']
            res_n_id = g['res_n_id']
            # Note: we only evaluate predictions on those initial nodes (per random walk)
            # to avoid duplicated computations
            y = y[:, res_n_id]
            y_hat = y_hat[:, res_n_id]
            cent_n_id = cent_n_id[res_n_id]
        else:
            cent_n_id = g['cent_n_id']
        if self.config.use_saintdataset:
            index_ptr = torch.cartesian_prod(torch.arange(rows.size(0)),
                                             torch.arange(cent_n_id.size(0)),
                                             torch.arange(forecast_length))

            label = pd.DataFrame({
                'row_idx':
                rows[index_ptr[:, 0]].data.cpu().numpy(),
                'node_idx':
                cent_n_id[index_ptr[:, 1]].data.cpu().numpy(),
                'forecast_idx':
                index_ptr[:, 2].data.cpu().numpy(),
                'val':
                y.flatten().data.cpu().numpy()
            })

            pred = pd.DataFrame({
                'row_idx':
                rows[index_ptr[:, 0]].data.cpu().numpy(),
                'node_idx':
                cent_n_id[index_ptr[:, 1]].data.cpu().numpy(),
                'forecast_idx':
                index_ptr[:, 2].data.cpu().numpy(),
                'val':
                y_hat.flatten().data.cpu().numpy()
            })
        else:
            index_ptr = torch.cartesian_prod(torch.arange(rows.size(0)),
                                             torch.arange(forecast_length))

            label = pd.DataFrame({
                'row_idx':
                rows[index_ptr[:, 0]].data.cpu().numpy(),
                'node_idx':
                cent_n_id[index_ptr[:, 0]].data.cpu().numpy(),
                'forecast_idx':
                index_ptr[:, 1].data.cpu().numpy(),
                'val':
                y.flatten().data.cpu().numpy()
            })

            pred = pd.DataFrame({
                'row_idx':
                rows[index_ptr[:, 0]].data.cpu().numpy(),
                'node_idx':
                cent_n_id[index_ptr[:, 0]].data.cpu().numpy(),
                'forecast_idx':
                index_ptr[:, 1].data.cpu().numpy(),
                'val':
                y_hat.flatten().data.cpu().numpy()
            })

        pred = pred.groupby(['row_idx', 'node_idx', 'forecast_idx']).mean()
        label = label.groupby(['row_idx', 'node_idx', 'forecast_idx']).mean()

        return {
            'label': label,
            'pred': pred,
            'info': [state_gate, gov_gate]
            # 'atten': atten_context
        }
Example #25
def train_and_eval(train_data,
                   val_data,
                   test_data,
                   device='cuda',
                   model_class=MF,
                   model_args: dict = {
                       'emb_dim': 64,
                       'learning_rate': 0.05,
                       'weight_decay': 0.05
                   },
                   training_args: dict = {
                       'batch_size': 1024,
                       'epochs': 100,
                       'patience': 20,
                       'block_batch': [1000, 100]
                   }):

    # build data_loader.
    train_loader = utils.data_loader.Block(
        train_data,
        u_batch_size=training_args['block_batch'][0],
        i_batch_size=training_args['block_batch'][1],
        device=device)
    val_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(val_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)
    test_loader = utils.data_loader.DataLoader(
        utils.data_loader.Interactions(test_data),
        batch_size=training_args['batch_size'],
        shuffle=False,
        num_workers=0)

    # data shape
    n_user, n_item = train_data.shape

    # model and its optimizer.
    model = MF(n_user, n_item, dim=model_args['emb_dim'], dropout=0).to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=model_args['learning_rate'],
                                weight_decay=0)

    # loss_criterion
    criterion = nn.MSELoss(reduction='sum')

    def complement(u, i, u_all, i_all):
        mask_u = np.isin(u_all.cpu().numpy(), u.cpu().numpy())
        mask_i = np.isin(i_all.cpu().numpy(), i.cpu().numpy())
        mask = torch.tensor(1 - mask_u * mask_i).to(device)  # honor the configured device
        return mask

    # begin training
    stopping_args = Stop_args(patience=training_args['patience'],
                              max_epochs=training_args['epochs'])
    early_stopping = EarlyStopping(model, **stopping_args)
    for epo in range(early_stopping.max_epochs):
        training_loss = 0
        for u_batch_idx, users in enumerate(train_loader.User_loader):
            for i_batch_idx, items in enumerate(train_loader.Item_loader):
                # loss of training set
                model.train()
                users_train, items_train, y_train = train_loader.get_batch(
                    users, items)
                y_hat_obs = model(users_train, items_train)
                loss_obs = criterion(y_hat_obs, y_train)

                all_pair = torch.cartesian_prod(users, items)
                users_all, items_all = all_pair[:, 0], all_pair[:, 1]
                y_hat_all = model(users_all, items_all)
                impu_all = torch.zeros((users_all.shape)).to(device) - 1
                # mask = complement(users_train, items_train, users_all, items_all)
                # loss_all = criterion(y_hat_all * mask, impu_all * mask)
                loss_all = criterion(y_hat_all, impu_all)

                loss = loss_obs + model_args[
                    'imputaion_lambda'] * loss_all + model_args[
                        'weight_decay'] * model.l2_norm(users_all, items_all)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                training_loss += loss.item()

        model.eval()
        with torch.no_grad():
            # train metrics
            train_pre_ratings = torch.empty(0).to(device)
            train_ratings = torch.empty(0).to(device)
            for u_batch_idx, users in enumerate(train_loader.User_loader):
                for i_batch_idx, items in enumerate(train_loader.Item_loader):
                    users_train, items_train, y_train = train_loader.get_batch(
                        users, items)
                    pre_ratings = model(users_train, items_train)
                    train_pre_ratings = torch.cat(
                        (train_pre_ratings, pre_ratings))
                    train_ratings = torch.cat((train_ratings, y_train))

            # validation metrics
            val_pre_ratings = torch.empty(0).to(device)
            val_ratings = torch.empty(0).to(device)
            for batch_idx, (users, items, ratings) in enumerate(val_loader):
                pre_ratings = model(users, items)
                val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
                val_ratings = torch.cat((val_ratings, ratings))

        train_results = utils.metrics.evaluate(train_pre_ratings,
                                               train_ratings, ['MSE', 'NLL'])
        val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                             ['MSE', 'NLL', 'AUC'])

        print('Epoch: {0:2d} / {1}, Training: {2}, Validation: {3}'.format(
            epo, training_args['epochs'], ' '.join([
                key + ':' + '%.3f' % train_results[key]
                for key in train_results
            ]), ' '.join([
                key + ':' + '%.3f' % val_results[key] for key in val_results
            ])))

        if early_stopping.check([val_results['AUC']], epo):
            break

    # testing loss
    print('Loading {}th epoch'.format(early_stopping.best_epoch))
    model.load_state_dict(early_stopping.best_state)

    # validation metrics
    val_pre_ratings = torch.empty(0).to(device)
    val_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(val_loader):
        pre_ratings = model(users, items)
        val_pre_ratings = torch.cat((val_pre_ratings, pre_ratings))
        val_ratings = torch.cat((val_ratings, ratings))

    # test metrics
    test_users = torch.empty(0, dtype=torch.int64).to(device)
    test_items = torch.empty(0, dtype=torch.int64).to(device)
    test_pre_ratings = torch.empty(0).to(device)
    test_ratings = torch.empty(0).to(device)
    for batch_idx, (users, items, ratings) in enumerate(test_loader):
        pre_ratings = model(users, items)
        test_users = torch.cat((test_users, users))
        test_items = torch.cat((test_items, items))
        test_pre_ratings = torch.cat((test_pre_ratings, pre_ratings))
        test_ratings = torch.cat((test_ratings, ratings))

    val_results = utils.metrics.evaluate(val_pre_ratings, val_ratings,
                                         ['MSE', 'NLL', 'AUC'])
    test_results = utils.metrics.evaluate(
        test_pre_ratings,
        test_ratings, ['MSE', 'NLL', 'AUC', 'Recall_Precision_NDCG@'],
        users=test_users,
        items=test_items)
    print('-' * 30)
    print('The performance of validation set: {}'.format(' '.join(
        [key + ':' + '%.3f' % val_results[key] for key in val_results])))
    print('The performance of testing set: {}'.format(' '.join(
        [key + ':' + '%.3f' % test_results[key] for key in test_results])))
    print('-' * 30)
    return val_results, test_results
Example #26
import torch

import models
import math
import matplotlib.pyplot as plt

if __name__ == '__main__':
    filepath = "MathNet_3hidden.pt"
    ann = models.MathNet(2, 15, 1).cuda()

    x = torch.cartesian_prod(torch.tensor([1.]), torch.linspace(-10, 10,
                                                                100)).cuda()
    y = torch.unsqueeze(torch.sin(x[:, 0] + torch.div(x[:, 1], math.pi)),
                        dim=1).cuda()

    ann.load_state_dict(torch.load(filepath))
    ann.eval()

    pred = ann(x)
    plt.plot(x[:, 1].tolist(), y.tolist())
    plt.plot(x[:, 1].tolist(), pred.tolist())
    plt.show()
Example #27
def get_contrastive_pairs(scores: Tensor, labels: Tensor) -> Tensor:
    positive_scores = scores[mask_to_index_1d(labels == 1)].view(-1)
    negative_scores = scores[mask_to_index_1d(labels == 0)].view(-1)
    score_pairs = torch.cartesian_prod(positive_scores, negative_scores)
    return score_pairs
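A minimal sketch of what this produces, using boolean masks directly in place of the project-specific `mask_to_index_1d` helper:

import torch

scores = torch.tensor([0.9, 0.2, 0.7, 0.1])
labels = torch.tensor([1, 0, 1, 0])
pos = scores[labels == 1]               # [0.9, 0.7]
neg = scores[labels == 0]               # [0.2, 0.1]
pairs = torch.cartesian_prod(pos, neg)  # every (positive, negative) score pair
assert pairs.shape == (4, 2)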
Example #28
def compute_shifts(cell, pbc, cutoff):
    """Compute the shifts of unit cell along the given cell vectors to make it
    large enough to contain all pairs of neighbor atoms with PBC under
    consideration

    Arguments:
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
        vectors defining unit cell:
            tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        cutoff (float): the cutoff inside which atoms are considered pairs
        pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
            if pbc is enabled for that direction.

    Returns:
        :class:`torch.Tensor`: long tensor of shifts. the center cell and
            symmetric cells are not included.
    """
    # type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
    reciprocal_cell = cell.inverse().t()
    inv_distances = reciprocal_cell.norm(2, -1)
    num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
    num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
    r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
    r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
    r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
    o = torch.zeros(1, dtype=torch.long, device=cell.device)
    return torch.cat([
        torch.cartesian_prod(r1, r2, r3),
        torch.cartesian_prod(r1, r2, o),
        torch.cartesian_prod(r1, r2, -r3),
        torch.cartesian_prod(r1, o, r3),
        torch.cartesian_prod(r1, o, o),
        torch.cartesian_prod(r1, o, -r3),
        torch.cartesian_prod(r1, -r2, r3),
        torch.cartesian_prod(r1, -r2, o),
        torch.cartesian_prod(r1, -r2, -r3),
        torch.cartesian_prod(o, r2, r3),
        torch.cartesian_prod(o, r2, o),
        torch.cartesian_prod(o, r2, -r3),
        torch.cartesian_prod(o, o, r3),
    ])
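A quick sanity check of the enumeration with toy values: for a unit cubic cell and a cutoff below the cell length, one repeat per direction yields the 13 half-space shifts (center and symmetric cells excluded, as the docstring states).

import torch

cell = torch.eye(3)
pbc = torch.tensor([True, True, True])
shifts = compute_shifts(cell, pbc, cutoff=0.5)
assert shifts.shape == (13, 3)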
Example #29
def analyze(point_file_path, max_resolution_file_path):
    (voxel_grid, voxel_size, offset) = load_voxel_grid(point_file_path)

    max_num_fitted_models = 10
    use_cuboids = True
    use_spheres = True
    use_capsules = True
    loss_type = LossType.BEST_MATCH
    visualize_intermediate = False
    use_cuda = False

    start_time = time.perf_counter()
    fitted_models = fit.fit_voxel_grid(
        voxel_grid,
        max_num_fitted_models=max_num_fitted_models,
        use_cuboid=use_cuboids,
        use_sphere=use_spheres,
        use_capsule=use_capsules,
        loss_type=loss_type,
        visualize_intermediate=visualize_intermediate,
        use_cuda=use_cuda)
    end_time = time.perf_counter()
    duration = end_time - start_time
    """
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    draw.draw_voxels(ax, voxel_grid)
    for m in fitted_models:
        m.draw(ax)
    plt.show()
    """

    (max_voxel_grid, max_voxel_size,
     max_offset) = load_voxel_grid(max_resolution_file_path)

    padding0 = math.ceil(max_voxel_grid.shape[0] / 2)
    padding1 = math.ceil(max_voxel_grid.shape[1] / 2)
    padding2 = math.ceil(max_voxel_grid.shape[2] / 2)

    # Add some padding so that any fitted primitives outside of the normal grid
    # will be taken into account
    max_voxel_grid = nn.functional.pad(
        max_voxel_grid,
        [padding2, padding2, padding1, padding1, padding0, padding0])

    for m in fitted_models:
        m.uniform_scale(voxel_size / max_voxel_size)
        m.translate(
            torch.tensor([padding0, padding1, padding2], dtype=torch.float))

    max_indices = torch.cartesian_prod(
        torch.arange(0, max_voxel_grid.shape[0]),
        torch.arange(0, max_voxel_grid.shape[1]),
        torch.arange(0, max_voxel_grid.shape[2]))
    max_indices_float = max_indices.float()

    covered_by_models = torch.zeros_like(max_voxel_grid, dtype=torch.bool)

    for m in fitted_models:
        covered = max_indices[m.exact_containment(max_indices_float)]
        voxel.batch_set(covered_by_models, covered, True)
    """
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    draw.draw_voxels(ax, covered_by_models)
    for m in fitted_models:
        m.draw(ax)
    plt.show()

    fig = plt.figure()
    ax = fig.gca(projection='3d')
    draw.draw_voxels(ax, max_voxel_grid)
    for m in fitted_models:
        m.draw(ax)
    plt.show()
    """

    print(fitted_models)

    overall_jaccard_index = float(
        (max_voxel_grid & covered_by_models).sum()) / float(
            (max_voxel_grid | covered_by_models).sum())
    print("Overall Jaccard Index: " + str(overall_jaccard_index))
    print("Done")
    return (overall_jaccard_index, duration, int(voxel_grid.sum()))
Example #30
def get_relative_distances(window_size):
    indices = torch.cartesian_prod(torch.arange(window_size),
                                   torch.arange(window_size))
    distances = indices[None, :, :] - indices[:, None, :]

    return distances
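A quick shape check of the pairwise offsets, using window_size = 7:

import torch

rel = get_relative_distances(7)
assert rel.shape == (49, 49, 2)  # pairwise (dh, dw) offsets within a 7x7 window
assert rel.min().item() == -6 and rel.max().item() == 6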