# Example No. 1
    def create_spherical_dataset(
        self, num_samples_cluster, radius=1.0, offset=4.0, dtype=ht.float32, random_state=1
    ):
        """
        Create k=4 spherical clusters in 3D space along the space diagonal.

        All four clusters are the same set of randomly drawn points, translated so
        that each one is centered on a different point of the space diagonal.

        Parameters
        ----------
        num_samples_cluster: int
            Number of samples per cluster. The count is divided by ``MPI_WORLD.size``
            (``num_samples_cluster // MPI_WORLD.size``) before sampling.
        radius: float
            Radius of the sphere; sampled radii are scaled into ``[0, radius)``.
        offset: float
            Shift of the clusters along the axes. The 4 clusters are centered around
            c1=(offset, offset, offset), c2=(2*offset, 2*offset, 2*offset),
            c3=(-offset, -offset, -offset) and c4=(-2*offset, -2*offset, -2*offset).
        dtype: ht.datatype
            Data type the Cartesian coordinates are cast to.
        random_state: int or None
            Seed passed to ``ht.random.seed`` before sampling; ``None`` skips seeding
            and leaves the current generator state untouched.

        Returns
        -------
        data: ht.DNDarray
            Array of shape ``(4 * (num_samples_cluster // MPI_WORLD.size), 3)``: the four
            clusters concatenated one after another along axis 0 (not shuffled).
        """
        import math

        p = ht.MPI_WORLD.size
        # Each of the k=4 clusters gets num_ele points.
        num_ele = num_samples_cluster // p
        # NOTE(review): if ht.random.rand(num_ele, split=0) creates a distributed array of
        # *global* length num_ele, each cluster holds num_samples_cluster // p points in
        # total, not per process as the original comment claimed. Confirm the intent.

        # random_state was documented but never used; seed so the dataset is reproducible.
        if random_state is not None:
            ht.random.seed(random_state)

        # Spherical coordinates, drawn in the same order as before: r, theta, phi.
        # r between 0 and radius
        r = ht.random.rand(num_ele, split=0) * radius
        # polar angle theta between 0 and pi
        theta = ht.random.rand(num_ele, split=0) * math.pi
        # azimuth phi between 0 and 2*pi
        phi = ht.random.rand(num_ele, split=0) * 2 * math.pi

        # Cartesian coordinates; sin(theta) is shared by x and y.
        sin_theta = ht.sin(theta)
        x = r * sin_theta * ht.cos(phi)
        y = r * sin_theta * ht.sin(phi)
        z = r * ht.cos(theta)
        # Keep the returned array so the cast takes effect whether or not astype
        # with copy=False works in place.
        x = x.astype(dtype, copy=False)
        y = y.astype(dtype, copy=False)
        z = z.astype(dtype, copy=False)

        # Cluster centers on the space diagonal: +offset, +2*offset, -offset, -2*offset.
        shifts = (offset, 2 * offset, -offset, -2 * offset)
        clusters = tuple(ht.stack((x + s, y + s, z + s), axis=1) for s in shifts)

        data = ht.concatenate(clusters, axis=0)
        # Note: enhance when shuffle is available (clusters are currently contiguous blocks).
        return data
# Example No. 2
    def label_to_one_hot(a):
        """
        Convert an array of class labels into a one-hot encoded matrix.

        Parameters
        ----------
        a : ht.DNDarray
            Class labels. Presumably a 1-D array of non-negative integers; verify
            against callers.

        Returns
        -------
        ht.DNDarray
            Array of shape ``(a.shape[0], max(a) + 1)`` holding 1 where the column
            index equals the row's label and 0 elsewhere.
        """
        max_label = ht.max(a)
        # (n,) -> (n, 1) so the labels broadcast against the row of class indices.
        a = a.expand_dims(1)

        # Class indices 0..max_label, shape (k,).
        items = ht.arange(0, max_label.item() + 1)
        # Broadcasting (k,) == (n, 1) yields the (n, k) mask directly. There is no need
        # to stack n copies of `items` in a Python-level loop first.
        one_hot = ht.where(items == a, 1, 0)

        return one_hot