Esempio n. 1
0
def mcxplotvol(data, surface_c=50, logplot=True):
    """Render a 3-D volume with plotly in the browser.

    Args:
        data: array-like volume. A 4-D array is reduced to index 0 of its last
            axis (the original behaviour); a 3-D array is plotted as is.
        surface_c: number of isosurfaces passed to ``go.Volume``.
        logplot: if True, plot ``log10(data)`` instead of ``data``.

    Raises:
        ValueError: if ``data`` is neither 3-D nor 4-D.

    Note:
        Sets ``plotly.io.renderers.default = "browser"`` as a side effect.
    """
    import numpy as np
    import plotly.graph_objects as go
    import plotly.io as pio

    pio.renderers.default = "browser"

    data = np.asarray(data)
    # Previously the 4th axis was always indexed, so plain 3-D input raised IndexError.
    if data.ndim == 4:
        data = data[:, :, :, 0]
    elif data.ndim != 3:
        raise ValueError(f"expected a 3-D or 4-D array, got {data.ndim}-D")
    if logplot:
        data = np.log10(data)

    nx, ny, nz = data.shape
    # Unit cube sampled with one grid point per voxel.
    X, Y, Z = np.mgrid[0:1:nx * 1j, 0:1:ny * 1j, 0:1:nz * 1j]

    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=data.flatten(),
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=surface_c,  # number of isosurfaces
    ))

    fig.show()
Esempio n. 2
0
def run():
    """Load the array stored at ``TEST_RESULT_PATH`` and render it as a volume.

    The coordinate grid spans [-8, 8] on each axis and has one point per voxel
    of the loaded array. Values are divided by 255 before plotting.

    Raises:
        ValueError: if the loaded array is neither 3-D nor 4-D.
    """
    space = numpy.load(TEST_RESULT_PATH, allow_pickle=True)
    print(space.flatten().shape)

    # NOTE(review): an earlier (commented-out) version indexed space[Z, X, Y, 0],
    # which suggests a trailing channel axis; plot channel 0 in that case so the
    # value array has exactly one entry per grid point. Confirm axis order.
    if space.ndim == 4:
        space = space[..., 0]
    elif space.ndim != 3:
        raise ValueError(f"expected a 3-D or 4-D array, got {space.ndim}-D")

    # The previous grid used 31,200,000 samples per axis (~3e22 points), which
    # cannot be allocated; size it from the data instead.
    nx, ny, nz = space.shape
    X, Y, Z = numpy.mgrid[-8:8:nx * 1j, -8:8:ny * 1j, -8:8:nz * 1j]

    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=space.flatten() / 255,
        isomin=0.1,
        isomax=0.8,
        opacity=0.5,
        surface_count=17,  # needs to be a large number for good volume rendering
    ))

    fig.show()
Esempio n. 3
0
def plot_ddm_dct_volume(ddm_dct: np.ndarray, wavenumber_factor,
                        wavenumber_unit, freq_factor):
    """Build a plotly volume figure of a DDM DCT array.

    Axis 0 is fft-shifted and drawn as qy, axis 1 as qx and axis 2 as frequency.
    Adjacent samples are ``wavenumber_factor`` apart on both q axes and
    ``freq_factor`` apart on the frequency axis.

    Args:
        ddm_dct: array indexed (qy, qx, freq).
        wavenumber_factor: spacing between adjacent q samples.
        wavenumber_unit: currently unused.
        freq_factor: spacing between adjacent frequency samples.

    Returns:
        plotly.graph_objects.Figure with a single Volume trace.
    """
    # The colour range ignores the (0, 0) q bin and the first frequency bin.
    origin_zero = ddm_dct.copy()
    origin_zero[0, 0, :] = 0
    max_signal = np.amax(origin_zero[:, :, 1:])

    qy_n, qx_len, freq_len = ddm_dct.shape[:3]
    qy_half = qy_n // 2
    # After fftshift along axis 0 the sample indices run from -qy_half to
    # qy_n - qy_half - 1 (e.g. -N/2 .. N/2-1 for even N). Ending the axis there
    # keeps the qy spacing equal to wavenumber_factor, like the qx axis.
    qy_axis, qx_axis, freq_axis = np.mgrid[
        -qy_half * wavenumber_factor:
        (qy_n - qy_half - 1) * wavenumber_factor:qy_n * 1j,
        0:(qx_len - 1) * wavenumber_factor:qx_len * 1j,
        0:(freq_len - 1) * freq_factor:freq_len * 1j,
    ]

    ddm_dct_shift = np.fft.fftshift(ddm_dct, axes=0)

    # plotly documents opacityscale positions as normalized values in [0, 1]
    # across [isomin, isomax]. Map the intended data-value breakpoints (opaque at
    # +/-max_signal, transparent within +/-0.05) onto that normalized range.
    if max_signal > 0:
        band = min(0.05 / (2 * max_signal), 0.5)
    else:
        band = 0.5
    opacityscale = [[0, 1], [0.5 - band, 0], [0.5 + band, 0], [1, 1]]

    vol = go.Volume(
        x=qy_axis.flatten(),
        y=qx_axis.flatten(),
        z=freq_axis.flatten(),
        value=ddm_dct_shift.flatten(),
        isomin=-max_signal,
        isomax=max_signal,
        opacity=0.2,
        surface_count=20,
        opacityscale=opacityscale,
        caps=dict(x_show=False, y_show=False, z_show=False),
    )
    fig = go.Figure(vol)
    return fig
def plot_decode_cop_voxel(base_cop, plot_file_name):
    """Render a voxel array as a min-max scaled plotly volume saved to HTML.

    Args:
        base_cop: array whose first three axes are the voxel grid (trailing
            singleton axes are flattened along with it, as before).
        plot_file_name: output HTML path passed to ``plotly.offline.plot``.
    """
    import plotly as py
    import plotly.graph_objects as go
    from sklearn.preprocessing import MinMaxScaler

    base_cop = np.asarray(base_cop)
    # Use each axis' own length; len(base_cop) for all three only fit cubes.
    nx, ny, nz = base_cop.shape[:3]
    X, Y, Z = np.mgrid[0:nx, 0:ny, 0:nz]
    values_cop = base_cop.flatten()

    # Rescale to [0, 1] so isomin/isomax below cover the full data range.
    scaler = MinMaxScaler()
    scaled_values = scaler.fit_transform(values_cop.reshape(-1, 1))
    trace1 = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=scaled_values[:, 0],
        isomin=0,
        isomax=1,
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=17,  # needs to be a large number for good volume rendering
        colorscale='Greens')

    layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
    fig = go.Figure(data=[trace1], layout=layout)
    py.offline.plot(fig, filename=plot_file_name)
Esempio n. 5
0
def soft_ellipse(fig, mu, rad, approx, rsize=3, resolution=50):
    """Add a soft-edged axis-aligned ellipsoid to ``fig`` as a Volume trace.

    The normalized radial distance
    ``sqrt(sum(((mu_i - X_i) / rad_i) ** 2))`` is passed through
    ``unitboxcar(d, 0.0, 2.0, approx)`` and plotted.

    Args:
        fig: plotly Figure; the trace is added in place.
        mu: (x, y, z) centre.
        rad: (x, y, z) semi-axis lengths (must be non-zero).
        approx: smoothing parameter forwarded to ``unitboxcar``.
        rsize: the grid spans [-rsize, rsize] on every axis.
        resolution: samples per axis (default 50, the previous fixed value).

    Returns:
        The same ``fig``, for convenience.
    """
    X, Y, Z = np.mgrid[-rsize:rsize:resolution * 1j,
                       -rsize:rsize:resolution * 1j,
                       -rsize:rsize:resolution * 1j]

    x = np.power(mu[0] - X, 2.0) / np.power(rad[0], 2.0)
    y = np.power(mu[1] - Y, 2.0) / np.power(rad[1], 2.0)
    z = np.power(mu[2] - Z, 2.0) / np.power(rad[2], 2.0)
    values = np.sqrt(x + y + z)

    values2 = unitboxcar(values, 0.0, 2.0, approx)

    fig.add_trace(
        go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=values2.flatten(),
            isomin=0.01,
            isomax=1.0,
            opacity=0.05,  # needs to be small to see through all surfaces
            surface_count=25,  # needs to be a large number for good volume rendering
        ))
    return fig
Esempio n. 6
0
def _generate_volume_data(points, field, **kwargs):
    """Generates a volume plot trace for plotly

    Args:
        points (Point_Array3): coordinates (``.x``, ``.y``, ``.z`` arrays)
        field (Field3): Magnetic field vector; ``field.n`` is plotted and
            ``field.unit`` appears in the colour-bar title.

    Kwargs:
        cmin (float): lower colour/iso limit, default 0.0
        cmax (float): upper colour/iso limit, default 0.5
        colorscale (str): plotly colour scale, default "viridis"
        opacityscale (str | list | None): "normal" / "invert" (case-insensitive)
            select a preset; any other value is passed to plotly unchanged.
        no_caps (bool): hide the caps on all three axes when True
        opacity (float): trace opacity, default 0.1
        num_levels (int): number of isosurfaces, default 10

    Returns:
        plotly ``Volume`` trace
    """
    cmin = kwargs.pop("cmin", 0.0)
    cmax = kwargs.pop("cmax", 0.5)
    colorscale = kwargs.pop("colorscale", "viridis")
    opacityscale = kwargs.pop("opacityscale", None)

    show_caps = not kwargs.pop("no_caps", False)
    caps = dict(x_show=show_caps, y_show=show_caps, z_show=show_caps)

    # NOTE(review): these preset breakpoints are in data units, while plotly
    # documents opacityscale positions as normalized to [0, 1]; with the default
    # cmin/cmax the "invert" positions are also non-monotonic (0.375 then 0.2).
    # Confirm the intended shapes before relying on the presets.
    if isinstance(opacityscale, str):
        preset = opacityscale.lower()
        if preset == "normal":
            opacityscale = [
                [cmin, 0],
                [(cmax - cmin) / 4, 0.5],
                [0.2, 0],
                [cmax, 1],
            ]
        elif preset == "invert":
            opacityscale = [
                [cmin, 1],
                [(cmax - cmin) * 3 / 4, 0.5],
                [0.2, 0],
                [cmax, 0],
            ]

    return _go.Volume(
        x=points.x.flatten(),
        y=points.y.flatten(),
        z=points.z.flatten(),
        value=field.n.flatten(),
        colorscale=colorscale,
        cmin=cmin,
        cmax=cmax,
        isomin=cmin,
        isomax=cmax,
        # opacity needs to be small to see through all surfaces
        opacity=kwargs.pop("opacity", 0.1),
        opacityscale=opacityscale,
        surface_count=kwargs.pop("num_levels", 10),
        showscale=True,
        caps=caps,
        colorbar=dict(title="|B| (" + field.unit + ")"),
    )
Esempio n. 7
0
def printevent(event, counter, outdir=outdir):
    """Render the summed Gaussian density of one event and save it.

    For every point ``pi`` in ``points`` and coordinate ``ci`` in ``coords``, the
    mean and variance columns of ``df`` are read at row ``event``. The density
    ``exp(-3 * (mean - mgrid) ** 2 / var)`` is multiplied over axis 3 and summed
    over the last axis (the original comments call these the coordinate and
    point axes), then divided by its maximum.

    A non-negative ``counter`` writes ``<outdir>/<counter padded to 10>.png``;
    a negative one writes ``<outdir>/last.html``.
    """
    mean_prefix = 'pre_selection_add_stage_0_att_gn1_coord_add_mean_'
    var_prefix = 'pre_selection_add_stage_0_att_gn1_coord_add_var_'

    mean_rows, var_rows = [], []
    for pi in points:
        point_means, point_vars = [], []
        for ci in coords:
            suffix = str(pi) + '_' + str(ci)
            point_means.append(df[mean_prefix + suffix][event])
            point_vars.append(df[var_prefix + suffix][event])
        mean_rows.append(point_means)
        var_rows.append(point_vars)

    def as_broadcastable(rows):
        # (points, coords) -> (1, 1, 1, coords, points)
        swapped = np.transpose(np.array(rows), [1, 0])
        return np.expand_dims(swapped, axis=(0, 1, 2))

    mean = as_broadcastable(mean_rows)
    var = as_broadcastable(var_rows)

    density = np.exp(-3. * (mean - mgrid)**2 / var)
    density = np.sum(np.prod(density, axis=3), axis=-1)
    density /= density.max()

    trace = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=density.flatten(),
        isomin=0.2,
        isomax=0.7,
        opacity=0.1,
        surface_count=25,
    )
    fig = go.Figure(data=trace)

    if counter < 0:
        fig.write_html(outdir + '/last.html')
    else:
        fig.write_image(outdir + '/' + str(counter).zfill(10) + '.png')
Esempio n. 8
0
    def _plot_velocity_volumetric(self, Xqs, yqs, fig, row, col, plot_args=None):
        """
        Add a volumetric velocity trace to ``fig`` in subplot (``row``, ``col``).

        Query values whose position is farther than 0.3 (Euclidean) from every
        point of a hard-coded reference CSV are replaced by -1000, which lies
        below the plotted isomin of -7.

        :param Xqs: filtered Nx3 positions
        :param yqs: filtered N values (indexed as ``yqs[:, 0]``, so presumably Nx1)
        :param fig: plotly figure created with subplots
        :param row: subplot row the trace is added to
        :param col: subplot column the trace is added to
        :param plot_args: (symbol, size, opacity, cbar_x_pos, cbar_min, cbar_max);
            only ``cbar_x_pos`` currently affects the trace.
        """

        print(" Plotting row {}, col {}".format(row, col))
        fname_in = "./datasets/kyle_ransalu/5_airsim/5_airsim1/5_airsim1_vel_train_normalized_infilled"
        prefilled_X = pd.read_csv(fname_in + '.csv', delimiter=',').to_numpy()[:2542, 1:4]
        mask = np.sum(euclidean_distances(Xqs, prefilled_X) <= 0.3, axis=1) >= 1
        yqs = torch.where((torch.ones_like(yqs) * mask[:, None]).to(dtype=torch.bool), yqs, torch.ones_like(yqs) * -1000)

        # marker and colorbar arguments
        if plot_args is None:
            symbol, size, opacity, cbar_x_pos, cbar_min, cbar_max = 'square', 8, 0.2, False, yqs[:, 0].min(), yqs[:, 0].max()
        else:
            symbol, size, opacity, cbar_x_pos, cbar_min, cbar_max = plot_args
        if cbar_x_pos is not False:
            colorbar = dict(x=cbar_x_pos, len=1, y=0.5)
        else:
            colorbar = dict()

        colorbar["tickfont"] = dict(size=18)

        fig.add_trace(
            go.Volume(
                x=Xqs[:, 0],
                y=Xqs[:, 1],
                z=Xqs[:, 2],
                isomin=-7,
                isomax=7,
                value=yqs,
                opacity=0.05,
                surface_count=40,
                colorscale="Jet",
                opacityscale=[[0, 0], [self.surface_threshold[0], 0], [1, 1]],
                colorbar=colorbar,
            ),
            # Previously hard-coded to row=1, col=2, ignoring the arguments.
            row=row,
            col=col
        )
Esempio n. 9
0
def _volumn_data(image, threshold, color, opacity):
    """Build a single-colour plotly Volume trace for a 3-D image.

    The image axes are reversed (2, 1, 0) before gridding, and two isosurfaces
    are drawn between ``threshold`` and the image maximum.
    """
    volume = image.transpose(2, 1, 0)
    nx, ny, nz = volume.shape
    grid_x, grid_y, grid_z = np.mgrid[:nx, :ny, :nz]
    peak = volume.max()
    flat_colour = [[0, color], [1, color]]

    return go.Volume(
        x=grid_x.flatten(),
        y=grid_y.flatten(),
        z=grid_z.flatten(),
        value=volume.flatten(),
        colorscale=flat_colour,
        isomin=threshold,
        isomax=peak,
        opacity=opacity,
        surface_count=2,
    )
Esempio n. 10
0
def main():
    """Read a raw little-endian int16 volume from ``INPUT`` and show it with plotly.

    File layout: a ``<HHH`` header (size_x, size_y, size_z) followed by
    size_x * size_y * size_z ``<i2`` samples. The samples are read as an array
    of shape (size_z, size_y, size_x), i.e. x varies fastest.

    Raises:
        struct.error: if the header is shorter than 6 bytes.
        ValueError: if the file holds fewer voxel bytes than the header declares.
    """
    header = struct.Struct('<HHH')
    # `fh` rather than `input`, which shadowed the builtin.
    with open(INPUT, 'rb') as fh:
        size_x, size_y, size_z = header.unpack(fh.read(header.size))
        expected = size_x * size_y * size_z * 2
        buffer = fh.read(expected)

    if len(buffer) < expected:
        raise ValueError(
            f"{INPUT}: expected {expected} bytes of voxel data, got {len(buffer)}")
    volume = np.frombuffer(buffer, dtype='<i2').reshape(size_z, size_y, size_x)

    Z, Y, X = np.mgrid[0:size_z, 0:size_y, 0:size_x]

    fig = go.Figure(data=go.Volume(
        x=X.flatten(), y=Y.flatten(), z=Z.flatten(),
        value=volume.flatten(),
        isomin=0.0,
        isomax=2**10,
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=4  # few isosurfaces; increase for smoother rendering
    ))
    fig.show()
Esempio n. 11
0
def plot_err_field(X, Y, Z, values, fig_name):
    """Show an error field sampled on the (X, Y, Z) grid as a plotly volume.

    Isosurfaces span ``1.2 * values.min()`` to ``0.95 * values.max()``.
    NOTE(review): ``fig_name`` is accepted but currently unused.
    """
    lowest = values.min()
    highest = values.max()

    trace = go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=values.flatten(),
        isomin=1.2 * lowest,
        isomax=0.95 * highest,
        opacity=0.1,  # small, so all surfaces remain visible
        surface_count=21,  # many surfaces give a smoother volume rendering
    )
    go.Figure(data=trace).show()
def act_map_3d(volume, target_shape=(55, 65, 40)):
    '''
    Plot a 3D activation map.

    The volume is resized (linear interpolation) to ``target_shape`` first,
    because plotting the full-resolution image is slow.

    Inputs:
        volume: 3D image
        target_shape: shape to resize to before plotting (default (55, 65, 40))
    Output: shows a figure with the activation map
    '''
    # Zoom factor per axis: desired size / current size.
    zoom_factors = tuple(t / s for t, s in zip(target_shape, volume.shape))

    # Resize along all three axes.
    resized_volume = ndimage.zoom(volume, zoom_factors, order=1)

    min_value = np.amin(resized_volume)
    max_value = np.amax(resized_volume)

    # Build the grid from the resized array itself so it always matches the
    # value array (zoom rounds its output shape).
    nx, ny, nz = resized_volume.shape
    X, Y, Z = np.mgrid[0:nx:nx * 1j, 0:ny:ny * 1j, 0:nz:nz * 1j]

    fig = go.Figure(data=go.Volume(x=X.flatten(),
                                   y=Y.flatten(),
                                   z=Z.flatten(),
                                   value=resized_volume.flatten(),
                                   isomin=min_value,
                                   isomax=max_value,
                                   colorscale="jet",
                                   opacity=0.1,
                                   surface_count=17))
    fig.show()
def plot3d(img_3d):
    """Display a demo volume sized like ``img_3d``.

    NOTE(review): only ``img_3d.shape[1]`` is used. The plotted values are 15
    random points on an l x l x l grid, blurred with a Gaussian (sigma 4) and
    scaled to a maximum of 1. Confirm whether the input itself should be plotted.
    """
    from scipy import ndimage

    side = img_3d.shape[1]
    vol = np.zeros((side, side, side))
    # np.int was removed in NumPy 1.24; the builtin int is the drop-in replacement.
    pts = (side * np.random.rand(3, 15)).astype(int)
    vol[tuple(pts)] = 1
    vol = ndimage.gaussian_filter(vol, 4)
    vol /= vol.max()

    X, Y, Z = np.mgrid[:side, :side, :side]
    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=vol.flatten(),
        isomin=0.2,
        isomax=0.7,
        opacity=0.1,
        surface_count=25,
    ))
    fig.show()
Esempio n. 14
0
    def generate_3D_plot(self,
                         col_index,
                         min_value,
                         max_value,
                         opacity=1.0,
                         surface_count=100):
        """Plot column ``col_index`` of the decoded monochrome image as a volume.

        The x/y/z coordinates come from the index levels ``x_values``,
        ``y_values`` and ``z_values`` of the decoded data.

        NOTE: per the original author, plotly (``go``/``plot``) is not imported
        in this module, so this method is not functional as-is.
        """
        decoded = self.__decode_image_index(active_data=self.monochrome_image)
        index_frame = decoded.index.to_frame()

        fig = go.Figure()
        trace = go.Volume(x=index_frame['x_values'],
                          y=index_frame['y_values'],
                          z=index_frame['z_values'],
                          value=decoded.loc[:, col_index],
                          isomin=min_value,
                          isomax=max_value,
                          opacity=opacity,
                          surface_count=surface_count)
        fig.add_trace(trace)
        plot(fig)
Esempio n. 15
0
#%% plot 3d grid and volume visualisation
# print(np.mgrid[-1:2:10j])  #np.mgrid used to generate vertical meshgrid

print("hello world started")

import plotly.graph_objects as go
import numpy as np
X, Y, Z = np.mgrid[-3:3:3j, 10:20:4j, 8:14:4j]
# The x axis samples -3, 0, 3, so sin(t)/t would be 0/0 = NaN on the x = 0
# plane. np.sinc(t / pi) equals sin(t)/t elsewhere and is 1 (the limit) at t = 0.
values = np.sinc(X * Y * Z / np.pi)

fig = go.Figure(data = go.Volume(
    x = X.flatten(),
    y = Y.flatten(),
    z = Z.flatten(),
    value = values.flatten(),
    isomin = .1,
    isomax = .8,
    opacity = 0.1, # need to be small to see all those surfaces
    surface_count = 17, # needs to be large to render better
))
fig.show()
x = X.flatten()
y = Y.flatten()
z = Z.flatten()
value = values.flatten()

print(X)
print(Y)
print(Z)
print("hello world finished")
Esempio n. 16
0
# NOTE(review): tail of an earlier example; `fig`, `camera` and `name` are not
# defined in the visible code.
fig.update_layout(scene_camera=camera, title=name)
fig.show()

# In[ ]:

import plotly.graph_objects as go
import numpy as np

# sin(xyz)/(xyz) on a 40x40x40 grid over [-8, 8]^3. With 40 samples per axis
# no sample lies exactly on 0, so the division is safe here.
X, Y, Z = np.mgrid[-8:8:40j, -8:8:40j, -8:8:40j]
xyz = X * Y * Z
values = np.sin(xyz) / xyz

volume_trace = go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=values.flatten(),
    isomin=0.1,
    isomax=0.8,
    opacity=0.1,  # low opacity keeps inner surfaces visible
    surface_count=17,  # more surfaces give a smoother rendering
)
fig = go.Figure(data=volume_trace)
fig.show()

# In[ ]:

# NOTE(review): this cell is not runnable as-is. The go.Volume(...) call opened
# below is cut off after `x=X.flatten(),` and is followed by an unrelated
# fragment; `X = []`, L1, L2, L3, arrPrecision and precisionAngles are not
# defined in the visible code.
import plotly.graph_objects as go
import numpy as np
X, Y, Z = np.mgrid[-1:1:30j, -1:1:30j, -1:1:30j]
values = np.sin(np.pi * X) * np.cos(np.pi * Z) * np.sin(np.pi * Y)

fig = go.Figure(data=go.Volume(
    x=X.flatten(),
# NOTE(review): a different fragment starts here. It builds the full
# L1 x L2 x L3 grid as flat coordinate lists and evaluates
# precisionAngles([l1, l2, l3]) at every grid point.
Y = []
Z = []

for l1 in L1:
    for l2 in L2:
        for l3 in L3:
            X.append(l1)
            Y.append(l2)
            Z.append(l3)

            arrPrecision.append(precisionAngles([l1, l2, l3]))

fig = go.Figure(data=go.Volume(
    x=X,
    y=Y,
    z=Z,
    value=arrPrecision,
    opacity=0.3,
    surface_count=18,
))

# fig.update_layout(
#     title="Evolution of max distance Y",
#     xaxis_title="L1 in mm",
#     yaxis_title="L2 in mm",
#     zaxis_title="L3 in mm",
#     font=dict(
#         family="Courier New, monospace"
#     ),
#     height=600,
#     width=600
# )
Esempio n. 18
0
def visualize_distributions(environment_path, distributions_path, name):
    """
    Function to visualize a set of distributions
    :param environment_path: Path to the JSON file containing information on the environment,
    in which the distributions will be placed
    :param distributions_path: Path to the JSON file containing the specifications of the distribution set
    :param name: Name of the distribution set
    :return: None; the figure is displayed with ``fig.show()``
    :raises ValueError: if the distribution file contains no distributions
    """
    # Both files are only read; "r+" needlessly required write permission.
    with open(distributions_path, "r", encoding="utf-8") as dist_file:
        distributions_loaded = json.load(dist_file)

    with open(environment_path, "r", encoding="utf-8") as env_file:
        environment = json.load(env_file)

    length = environment["params"]["length"]
    width = environment["params"]["width"]
    height = environment["params"]["height"]

    # Each parameter appears to be stored as a [low, high] pair; use its midpoint.
    distributions = [[((bounds[0] + bounds[1]) / 2) for bounds in val]
                     for val in distributions_loaded.values()]

    x, y, z = np.mgrid[0:length:2, 0:width:2, 0:height:2]
    pos = np.empty((x.shape[0] * x.shape[1] * x.shape[2], 3))
    pos[:, 0] = x.flatten()
    pos[:, 1] = y.flatten()
    pos[:, 2] = z.flatten()

    no_distributions = len(distributions)
    if no_distributions == 0:
        raise ValueError("No distributions found in {}".format(distributions_path))

    # At most four subplots per row.
    if no_distributions < 5:
        rows = 1
        cols = no_distributions
    else:
        rows = math.ceil(no_distributions / 4)
        cols = 4

    specs = [[{'type': 'scatter3d'} for _ in range(cols)] for _ in range(rows)]

    fig = make_subplots(rows=rows,
                        cols=cols,
                        subplot_titles=[
                            "Class {}".format(k)
                            for k in range(1, no_distributions + 1)
                        ],
                        specs=specs,
                        vertical_spacing=0.05)

    for index, distribution in enumerate(distributions):
        col = (index % 4) + 1
        row = (index // 4) + 1
        # presumably [mu_x, mu_y, mu_z, var_x, var_y, var_z] — verify against the file format
        dist = multivariate_normal(distribution[:3], np.diag(distribution[3:]))
        # pdf returns a scalar for a single point; keep it an array.
        values = np.atleast_1d(dist.pdf(pos))
        lowest, highest = values.min(), values.max()
        if highest > lowest:
            norm_values = (values - lowest) / (highest - lowest)
        else:
            # Constant density: avoid 0/0 -> NaN.
            norm_values = np.zeros_like(values)
        fig.add_trace(go.Volume(x=x.flatten(),
                                y=y.flatten(),
                                z=z.flatten(),
                                value=norm_values,
                                opacity=0.1,
                                surface_count=21),
                      row=row,
                      col=col)
    fig.update_layout(
        height=500 * rows,
        title="<b>Distributions '{}' in Environment '{}'</b>".format(
            name,
            environment_path.split("/")[-1].split(".")[0]),
        font=dict(family="Courier New, monospace", size=14, color="#7f7f7f"))
    fig.show()
Esempio n. 19
0
def create_volume(data=None,
                  x_labels=None,
                  y_labels=None,
                  z_labels=False,
                  trace_kwargs=None,
                  return_trace_idx=False,
                  row=None,
                  col=None,
                  scene_name='scene',
                  fig=None,
                  **layout_kwargs):
    """Create a volume plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be a 3-dim array.
        x_labels (array_like): X-axis labels.
        y_labels (array_like): Y-axis labels.
        z_labels (array_like): Z-axis labels. `None` and the legacy default
            `False` both mean `range(data.shape[2])`.
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Volume`.
        return_trace_idx (bool): Whether to return trace index for `update_volume_data`.
        row (int): Row position.
        col (int): Column position.
        scene_name (str): Reference to the 3D scene.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    !!! note
        Figure widgets have currently problems displaying NaNs.
        Use `.show()` method for rendering.

    ## Example

    ```python-repl
    >>> import vectorbt as vbt
    >>> import numpy as np

    >>> vbt.plotting.create_volume(
    ...     data=np.random.randint(1, 10, size=(3, 3, 3)),
    ...     x_labels=['a', 'b', 'c'],
    ...     y_labels=['d', 'e', 'f'],
    ...     z_labels=['g', 'h', 'i']
    ... )
    ```

    ![](/vectorbt/docs/img/create_volume.png)
    """
    from vectorbt.settings import layout

    if trace_kwargs is None:
        trace_kwargs = {}
    if data is None:
        raise ValueError("Data must be passed")
    data = np.asarray(data)
    checks.assert_ndim(data, 3)
    if x_labels is None:
        x_labels = np.arange(data.shape[0])
    if y_labels is None:
        y_labels = np.arange(data.shape[1])
    # The signature's default is False (kept for compatibility); treat it like
    # None instead of turning it into a 0-d boolean "label" array.
    if z_labels is None or z_labels is False:
        z_labels = np.arange(data.shape[2])
    x_labels = np.asarray(x_labels)
    y_labels = np.asarray(y_labels)
    z_labels = np.asarray(z_labels)

    if fig is None:
        fig = CustomFigureWidget()
        if 'width' in layout:
            # Calculate nice width and height
            fig.update_layout(width=layout['width'],
                              height=0.7 * layout['width'])

    # Non-numeric data types are not supported by go.Volume, so use ticktext.
    # Collect all axes in one scene dict; assigning more_layout[scene_name]
    # per axis used to keep only the last non-numeric axis.
    # Note: Currently plotly displays the entire tick array, in future versions it will be more sensible
    scene_axes = dict()
    if not np.issubdtype(x_labels.dtype, np.number):
        x_ticktext = x_labels
        x_labels = np.arange(data.shape[0])
        scene_axes['xaxis'] = dict(
            ticktext=x_ticktext, tickvals=x_labels, tickmode='array')
    if not np.issubdtype(y_labels.dtype, np.number):
        y_ticktext = y_labels
        y_labels = np.arange(data.shape[1])
        scene_axes['yaxis'] = dict(
            ticktext=y_ticktext, tickvals=y_labels, tickmode='array')
    if not np.issubdtype(z_labels.dtype, np.number):
        z_ticktext = z_labels
        z_labels = np.arange(data.shape[2])
        scene_axes['zaxis'] = dict(
            ticktext=z_ticktext, tickvals=z_labels, tickmode='array')
    more_layout = {scene_name: scene_axes} if scene_axes else {}
    fig.update_layout(**more_layout)
    fig.update_layout(**layout_kwargs)

    # Arrays must have the same length as the flattened data array
    # (C order: x varies slowest, z fastest).
    x = np.repeat(x_labels, len(y_labels) * len(z_labels))
    y = np.tile(np.repeat(y_labels, len(z_labels)), len(x_labels))
    z = np.tile(z_labels, len(x_labels) * len(y_labels))

    volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=data.flatten(),
        opacity=0.2,
        surface_count=15,  # keep low for big data
        colorscale='Plasma')
    volume.update(**trace_kwargs)
    fig.add_trace(volume, row=row, col=col)
    trace_idx = len(fig.data) - 1
    if return_trace_idx:
        return fig, trace_idx
    return fig
Esempio n. 20
0
def main():
    """Plot the VELOCITY U field of the last frame of res3d.slf as a volume.

    Only nodes with x < 1, y < 1 and ELEVATION Z < 1 are kept. Several debug
    values are printed and the figure is written to plot/index.html.
    """
    slf = SLF("res3d.slf")
    xs = np.tile(slf.MESHX, slf.NPLAN)
    ys = np.tile(slf.MESHY, slf.NPLAN)

    var_indices = slf.getVarsIndexes(['ELEVATION Z', "VELOCITY U"])
    frame = slf.getVariablesAt(slf.NFRAME - 1, var_indices)
    zs, vals = frame[0], frame[1]

    keep = np.where((xs < 1) & (ys < 1) & (zs < 1))[0]
    xs, ys, zs, vals = xs[keep], ys[keep], zs[keep], vals[keep]

    scene_layout = go.Layout(scene=dict(aspectmode='data'))

    for arr in (xs, ys, zs, vals):
        print(arr.max())

    trace = go.Volume(
        x=xs,
        y=ys,
        z=zs,
        value=vals,
        isomin=0,
        isomax=100,
        surface_count=10,
        opacity=0.1,  # small, so all surfaces remain visible
    )
    fig = go.Figure(layout=scene_layout, data=trace)

    lo, hi = vals.min(), vals.max()
    print((vals - lo) / (hi - lo))
    fig.write_html('plot/index.html')
Esempio n. 21
0
def visualize_segmentation_over_distribution(distributions_path,
                                             segmentation_path,
                                             environment_path):
    """
    Function to visualize a distribution set together with one segmentation
    :param distributions_path: Path to the JSON file containing the specifications of the distribution set
    :param segmentation_path: Path to a segmentation NPY file (columns 3, 4, 5 are plotted as x, y, z)
    :param environment_path: Path to the JSON file containing the specifications of the environment
    :return: None; the figure is displayed with ``fig.show()``
    :raises ValueError: if the distribution file contains no distributions
    """
    # Both files are only read; "r+" needlessly required write permission.
    with open(distributions_path, "r", encoding="utf-8") as dist_file:
        distributions_loaded = json.load(dist_file)

    with open(environment_path, "r", encoding="utf-8") as env_file:
        environment = json.load(env_file)

    seg = np.load(segmentation_path)

    length = environment["params"]["length"]
    width = environment["params"]["width"]
    height = environment["params"]["height"]

    # Each parameter appears to be stored as a [low, high] pair; use its midpoint.
    distributions = [[((bounds[0] + bounds[1]) / 2) for bounds in val]
                     for val in distributions_loaded.values()]

    x, y, z = np.mgrid[0:length:2, 0:width:2, 0:height:2]
    pos = np.empty((x.shape[0] * x.shape[1] * x.shape[2], 3))
    pos[:, 0] = x.flatten()
    pos[:, 1] = y.flatten()
    pos[:, 2] = z.flatten()

    no_distributions = len(distributions)
    if no_distributions == 0:
        raise ValueError("No distributions found in {}".format(distributions_path))

    # At most four subplots per row.
    if no_distributions < 5:
        rows = 1
        cols = no_distributions
    else:
        rows = math.ceil(no_distributions / 4)
        cols = 4

    specs = [[{'type': 'scatter3d'} for _ in range(cols)] for _ in range(rows)]

    fig = make_subplots(rows=rows,
                        cols=cols,
                        subplot_titles=[
                            "Class {}".format(k)
                            for k in range(1, no_distributions + 1)
                        ],
                        specs=specs,
                        vertical_spacing=0.05)

    segmentation_trace = go.Scatter3d(x=seg[:, 3],
                                      y=seg[:, 4],
                                      z=seg[:, 5],
                                      mode="markers",
                                      marker=dict(size=5,
                                                  opacity=0.2,
                                                  colorscale='Viridis'))

    for index, distribution in enumerate(distributions):
        col = (index % 4) + 1
        row = (index // 4) + 1
        # presumably [mu_x, mu_y, mu_z, var_x, var_y, var_z] — verify against the file format
        dist = multivariate_normal(distribution[:3], np.diag(distribution[3:]))
        # pdf returns a scalar for a single point; keep it an array.
        values = np.atleast_1d(dist.pdf(pos))
        lowest, highest = values.min(), values.max()
        if highest > lowest:
            norm_values = (values - lowest) / (highest - lowest)
        else:
            # Constant density: avoid 0/0 -> NaN.
            norm_values = np.zeros_like(values)
        fig.add_trace(go.Volume(
            x=x.flatten(),
            y=y.flatten(),
            z=z.flatten(),
            value=norm_values,
            opacity=0.1,
            surface_count=21,
        ),
                      row=row,
                      col=col)
        # The same segmentation overlay is drawn in every subplot.
        fig.add_trace(segmentation_trace, row=row, col=col)

    fig.update_layout(title="Segmentation over Distributions",
                      font=dict(family="Courier New, monospace",
                                size=14,
                                color="#7f7f7f"))
    fig.show()
Esempio n. 22
0
        # For each f sample, search good_data for an entry matching (f, A, k_5)
        # within 0.01 on every coordinate (entries are [f, A, k_5, value]). If
        # none matches, store min(chaos) - 1, which lies below the plot's isomin.
        # NOTE(review): linear scan of good_data per sample; the enclosing loops
        # over A_val / k_5_val are not visible here.
        for f_val in np.linspace(0, 2.5, 50):
            ok = 0
            for j in good_data:
                if abs(f_val - j[0]) < 0.01 and abs(
                        k_5_val - j[2]) < 0.01 and abs(A_val - j[1]) < 0.01:
                    data_set.append([f_val, A_val, k_5_val, j[3]])
                    ok = 1  # first match wins
                    break
            if ok == 0:
                data_set.append([f_val, A_val, k_5_val, min(chaos) - 1])
# data_set rows are [f, A, k_5, chaos-test value].
data_set = np.array(data_set)
chaos_lo = min(chaos)
chaos_hi = max(chaos)

chaos_volume = go.Volume(
    x=data_set[:, 0],
    y=data_set[:, 1],
    z=data_set[:, 2],
    value=data_set[:, 3],
    isomin=chaos_lo,
    isomax=chaos_hi,
    opacity=0.1,
    surface=dict(count=40),
)
chaos_scene = dict(
    xaxis=dict(range=[0.2, 1.5]),
    xaxis_title='f',
    yaxis_title='[A]',
    zaxis_title='k_5',
    yaxis=dict(range=[0, 1.2]),
    zaxis=dict(range=[0, 12]),
)
fig = go.Figure(data=chaos_volume)
fig.update_layout(title='0-1 chaos test', scene=chaos_scene)
fig.show()
Esempio n. 23
0
data[7:15, 7:15, 7:15] = 1

shape = [1, *data.shape]
std = 15
distance = 20
stepsize = 1
flow = random_deformation_momentum(shape, std, distance, stepsize)
warped = dense_image_warp(data[None, ..., None], flow)

# One integer grid point per voxel, taken from data.shape instead of a
# hard-coded 20x20x20 (identical when data is 20x20x20).
X, Y, Z = np.mgrid[tuple(slice(0, n) for n in data.shape)]

fig = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=data.flatten(),
    isomin=0.1,
    isomax=1.0,
    opacity=0.1,  # needs to be small to see through all surfaces
    surface_count=100,  # needs to be a large number for good volume rendering
))
fig.show()

# NOTE(review): this go.Volume(...) call is cut off — its closing parentheses
# and fig.show() are missing, and an indented fragment from another function
# (header not visible) is spliced in below. Not runnable as-is.
fig = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=warped[0, :, :, :, 0].numpy().flatten(),
    isomin=0.1,
    isomax=1.0,
    opacity=0.1,  # needs to be small to see through all surfaces
    surface_count=100,  # needs to be a large number for good volume rendering
        # Average fmap_eval over axis 4 (presumably the channel axis — TODO
        # confirm) and drop singleton axes.
        fmap_eval_base = fmap_eval.mean(axis=4).squeeze()

        #fmap_eval_base[abs(fmap_eval_base)<0.025]=0
        print(fmap_eval_base.shape)
        # All three grid axes use len(fmap_eval_base), so this assumes a cube.
        X, Y, Z = np.mgrid[0:len(fmap_eval_base), 0:len(fmap_eval_base),
                           0:len(fmap_eval_base)]
        #input_conv_data[0,:,:,:,0]=0.2
        #values_cop = base_cop

        trace1 = go.Volume(
            x=X.flatten(),
            y=Y.flatten(),
            z=Z.flatten(),
            value=fmap_eval_base.flatten(),
            #isomin=np.amin(fmap_eval_base.flatten()),
            isomax=np.amax(fmap_eval_base.flatten()),
            isomin=0,
            #isomax=0.5,
            opacity=0.3,  # needs to be small to see through all surfaces
            surface_count=
            27,  # needs to be a large number for good volume rendering
            colorscale='matter')

        data = [trace1]

        layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

        fig = go.Figure(data=data, layout=layout)
        # Saved as an HTML file under deploy_path.
        plot_file_name = deploy_path + 'feature_map.html'
        py.offline.plot(fig, filename=plot_file_name)
Esempio n. 25
0
def create_volume(data=None, x_labels=None, y_labels=None, z_labels=None,
                  trace_kwargs=None, fig=None, **layout_kwargs):
    """Create a volume plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be a 3-dim array.
        x_labels (array_like): X-axis labels. Defaults to `np.arange(data.shape[0])`.
        y_labels (array_like): Y-axis labels. Defaults to `np.arange(data.shape[1])`.
        z_labels (array_like): Z-axis labels. Defaults to `np.arange(data.shape[2])`.
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Volume`.
        fig (plotly.graph_objects.Figure): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Returns:
        The figure the volume trace was added to (a new 700x450
        `CustomFigureWidget` if `fig` was not passed).

    Raises:
        ValueError: If `data` is not passed.

    !!! note
        Figure widgets have currently problems displaying NaNs.
        Use `.show()` method for rendering.

    Example:
        ```py
        import vectorbt as vbt
        import numpy as np

        vbt.plotting.create_volume(
            data=np.random.randint(1, 10, size=(3, 3, 3)),
            x_labels=['a', 'b', 'c'],
            y_labels=['d', 'e', 'f'],
            z_labels=['g', 'h', 'i']
        )
        ```
        ![](/vectorbt/docs/img/create_volume.png)
    """
    if data is None:
        raise ValueError("Data must be passed")
    if trace_kwargs is None:
        trace_kwargs = {}
    data = np.asarray(data)
    checks.assert_ndim(data, 3)

    # Fall back to positional indices for any axis without labels. All three
    # defaults must be None for this fallback to apply.
    labels = [
        np.arange(size) if axis_labels is None else np.asarray(axis_labels)
        for axis_labels, size in zip((x_labels, y_labels, z_labels), data.shape)
    ]

    if fig is None:
        fig = CustomFigureWidget()
        fig.update_layout(
            width=700,
            height=450
        )

    # Non-numeric data types are not supported by go.Volume. Plot such axes at
    # integer positions and show the original labels as tick text.
    # Note: Currently plotly displays the entire tick array, in future versions it will be more sensible
    scene = {}
    for i, axis_name in enumerate(('xaxis', 'yaxis', 'zaxis')):
        if not np.issubdtype(labels[i].dtype, np.number):
            positions = np.arange(data.shape[i])
            scene[axis_name] = dict(ticktext=labels[i], tickvals=positions, tickmode='array')
            labels[i] = positions
    if scene:
        fig.update_layout(scene=scene)
    x_labels, y_labels, z_labels = labels

    fig.update_layout(**layout_kwargs)

    # One coordinate triple per cell, in the same C (row-major) order as
    # data.flatten(): x varies slowest, z fastest.
    x = np.repeat(x_labels, len(y_labels) * len(z_labels))
    y = np.tile(np.repeat(y_labels, len(z_labels)), len(x_labels))
    z = np.tile(z_labels, len(x_labels) * len(y_labels))

    volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=data.flatten(),
        opacity=0.15,
        surface_count=15,  # keep low for big data
        colorscale='Plasma'
    )
    volume.update(**trace_kwargs)
    fig.add_trace(volume)
    return fig
Esempio n. 26
0
def run_many(PGD_attack,
             data_loader,
             model,
             subplot_grid=[2, 2],
             num_adv_directions=1,
             lens=[[-1, 1], [-1, 1], [-1, 1]],
             resolution="high",
             height=1000,
             width=1000,
             show_figure=False,
             save_figure=False,
             file_path='./temp.html',
             specific_class=-1,
             title="",
             if_back_to_cpu=False,
             num_classes=10):
    """Plot a grid of 3-D volume subplots, one per image drawn from a loader.

    For each batch from ``data_loader`` one image is selected. Three
    directions in input space are built, either random or from adversarial
    examples produced by ``PGD_attack``, and orthonormalised with
    Gram-Schmidt. The values returned by ``Compute_grid_outputs`` along those
    directions are then added as a ``go.Volume`` subplot. Iteration stops
    once the grid is full.

    Args:
        PGD_attack: Callable invoked as ``PGD_attack(images, target_labels)``.
            It must return a batch indexable like ``images``.
        data_loader: Iterable of ``(images, labels)`` batches.
        model: Forwarded to ``Compute_grid_outputs``.
        subplot_grid: ``[rows, cols]`` of the subplot grid.
        num_adv_directions: 0 means all three directions are random. 1 means
            the first direction points to an adversarial example for a random
            wrong class and the others are random. 3 means direction ``k``
            points to an adversarial example for class
            ``(label + k + 1) % num_classes``.
        lens: Forwarded to ``Compute_grid_outputs``.
        resolution: Forwarded to ``Compute_grid_outputs``.
        height: Figure height in pixels.
        width: Figure width in pixels.
        show_figure: Call ``fig.show()`` at the end.
        save_figure: Write the figure to ``file_path`` via ``plotly.offline.plot``.
        file_path: Output HTML path used when ``save_figure`` is true.
        specific_class: -1 uses the first image of each batch. Otherwise
            ``find_specific_class`` picks an image with this label, and
            batches without one are skipped.
        title: Appended to the figure title after " Exp name: ".
        if_back_to_cpu: Move each batch to the CPU before use.
        num_classes: Number of classes used to build wrong target labels
            (must exceed 3 when ``num_adv_directions == 3``).

    Returns:
        The plotly figure.

    Raises:
        NameError: If ``num_adv_directions`` is not 0, 1 or 3.
    """
    # Validate the mode (and build the title) before any attack is run.
    if num_adv_directions == 0:
        title_text = "All three directions are random."
    elif num_adv_directions == 1:
        title_text = "X direction is adversarial."
    elif num_adv_directions == 3:
        title_text = "All three directions are adversarial (with different classes)."
    else:
        raise NameError(
            'The number of adversarial directions has to be either 0, 1, or 3.'
        )
    title_text += " Exp name: "
    title_text += title

    rows, cols = subplot_grid[0], subplot_grid[1]
    num_figures_3D = rows * cols

    # Create a figure grid with one 3-D volume subplot per cell
    fig = make_subplots(rows=rows,
                        cols=cols,
                        specs=[[{'type': 'volume'} for _ in range(cols)]
                               for _ in range(rows)])

    num_sub_figures_plotted = 0

    for images, labels in data_loader:
        if num_sub_figures_plotted >= num_figures_3D:
            break

        if if_back_to_cpu:
            images = images.cpu()
            labels = labels.cpu()

        print(
            f"Plotting figure {num_sub_figures_plotted+1}/{num_figures_3D}."
        )

        if specific_class == -1:
            # This means that we do not need to find a specific class
            img_ind = 0
        else:
            img_ind = find_specific_class(specific_class, labels)
            if img_ind == -1:
                # This batch does not contain any image of this class, so go to the next batch
                print("No img of label {0}! Go to the next batch.".format(
                    specific_class))
                continue

        x = images[img_ind]
        y = labels[img_ind]

        dirs = [0, 0, 0]
        if num_adv_directions == 0:
            print("The number of adversarial directions is 0")

            dirs[0] = torch.rand(x.shape) - 0.5
            dirs[1] = torch.rand(x.shape) - 0.5
            dirs[2] = torch.rand(x.shape) - 0.5

        elif num_adv_directions == 1:
            print("The number of adversarial directions is 1")

            # Shift each label by 1..num_classes-1 so the target is always wrong.
            labels_change = torch.randint(1, num_classes, (labels.shape[0], ))
            wrong_labels = torch.remainder(labels_change + labels, num_classes)
            adv_images = PGD_attack(images, wrong_labels)
            dirs[0] = adv_images[img_ind].cpu() - x

            dirs[1] = torch.rand(x.shape) - 0.5
            dirs[2] = torch.rand(x.shape) - 0.5

        else:  # num_adv_directions == 3 (validated above)
            print("The number of adversarial directions is 3")

            for dir_ind in range(3):
                # Target class label + (dir_ind + 1), a different one per direction.
                labels_change = torch.full((labels.shape[0], ), dir_ind + 1,
                                           dtype=torch.long)
                wrong_labels = torch.remainder(labels_change + labels, num_classes)
                adv_images = PGD_attack(images, wrong_labels)
                dirs[dir_ind] = adv_images[img_ind].cpu() - x

        # Gram-Schmidt: normalize the first direction
        dirs[0] = dirs[0] / torch.norm(dirs[0], p=2)

        # Remove the first-direction component from the second, then normalize
        dirs[1] = dirs[1] / torch.norm(dirs[1], p=2)
        dirs[1] = dirs[1] - torch.dot(dirs[1].view(-1),
                                      dirs[0].view(-1)) * dirs[0]
        dirs[1] = dirs[1] / torch.norm(dirs[1], p=2)

        # Remove both previous components from the third, then normalize
        dirs[2] = dirs[2] / torch.norm(dirs[2], p=2)
        proj1 = torch.dot(dirs[2].view(-1), dirs[0].view(-1))
        proj2 = torch.dot(dirs[2].view(-1), dirs[1].view(-1))
        dirs[2] = dirs[2] - proj1 * dirs[0] - proj2 * dirs[1]
        dirs[2] = dirs[2] / torch.norm(dirs[2], p=2)

        # Check if the three directions are orthogonal
        Assert_three_orthogonal(dirs)

        # Compute the grid outputs
        grid_x, grid_y, grid_z, value = Compute_grid_outputs(model,
                                                             x,
                                                             y,
                                                             dirs,
                                                             lens=lens,
                                                             resolution=resolution)

        # Subplots are filled row by row; plotly rows/cols are 1-based
        row_ind, col_ind = divmod(num_sub_figures_plotted, cols)

        fig.add_trace(
            go.Volume(
                x=grid_x,
                y=grid_y,
                z=grid_z,
                value=value,
                isomin=0,
                isomax=1,
                opacity=0.1,  # needs to be small to see through all surfaces
                surface_count=17,  # needs to be a large number for good volume rendering
            ),
            row=row_ind + 1,
            col=col_ind + 1)

        num_sub_figures_plotted += 1

    fig.update_layout(height=height, width=width, title_text=title_text)

    if show_figure:
        fig.show()

    if save_figure:
        plotly.offline.plot(fig, filename=file_path)

    return fig
import plotly.graph_objects as go
import numpy as np


filename = 'dump/2020-12-14T01:28:14-14.csv'
values = np.genfromtxt(filename, delimiter=',')

# Columns 0-2 of the CSV are used as the x/y/z coordinates and column 3 as
# the scalar value. Isosurfaces are drawn between 0 and 0.01.
fig = go.Figure(
    data=go.Volume(
        x=values[:, 0], y=values[:, 1], z=values[:, 2],
        value=values[:, 3],
        isomin=0, isomax=.01,
        surface_count=10,  # pick larger for good volume rendering
        opacity=.3,  # needs to be small to see through all surfaces
    )
)
fig.show()
Esempio n. 28
0
#sz     = 1200*au  # See the full disk
#floor  = 1e-20

# Zoomed-in setting: half-width of the cube, and a floor added to the density
# so that log10 stays finite on zero-valued cells.
sz, floor = 150 * au, 1e-15

# The cube [-sz, sz] along each of the three axes.
box = np.array([-sz, sz] * 3)

s = subBox()
s.makeSubbox('rhodust', box, nxyz, phi1=0., theta=0., phi2=0.)
s.readSubbox()

values = np.log10(s.data + floor)
# 'ij' indexing keeps the grid axes in (x, y, z) order. This assumes s.data
# is laid out the same way, so the flattened coordinates line up with the
# flattened values. TODO confirm.
X, Y, Z = np.meshgrid(s.x, s.y, s.z, indexing='ij')

# Volume rendering of the disk with plotly,
# see https://plotly.com/python/3d-volume-plots/
fig = go.Figure(data=go.Volume(
    x=X.flatten(), y=Y.flatten(), z=Z.flatten(),
    value=values.flatten(),
    isomin=values.min(), isomax=values.max(),  # span the whole data range
    surface_count=17,  # needs to be a large number for good volume rendering
    opacity=0.1,  # needs to be small to see through all surfaces
))
fig.show()
Esempio n. 29
0
    def __init__(self,
                 data=None,
                 x_labels=None,
                 y_labels=None,
                 z_labels=None,
                 trace_kwargs=None,
                 add_trace_kwargs=None,
                 scene_name='scene',
                 fig=None,
                 **layout_kwargs):
        """Create a volume plot.

        Args:
            data (array_like): Data in any format that can be converted to NumPy.

                Must be a 3-dim array. If omitted, all three label arrays must
                be passed and the trace is created without values.
            x_labels (array_like): X-axis labels. Defaults to `0..n-1` along the
                first dimension of `data`.
            y_labels (array_like): Y-axis labels. Defaults to `0..n-1` along the
                second dimension of `data`.
            z_labels (array_like): Z-axis labels. Defaults to `0..n-1` along the
                third dimension of `data`.
            trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Volume`.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            scene_name (str): Reference to the 3D scene.
            fig (plotly.graph_objects.Figure): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Raises:
            ValueError: If `data` is None and any of the label arrays is missing.

        !!! note
            Figure widgets have currently problems displaying NaNs.
            Use `.show()` method for rendering.

        ## Example

        ```python-repl
        >>> import vectorbt as vbt
        >>> import numpy as np

        >>> volume = vbt.plotting.Volume(
        ...     data=np.random.randint(1, 10, size=(3, 3, 3)),
        ...     x_labels=['a', 'b', 'c'],
        ...     y_labels=['d', 'e', 'f'],
        ...     z_labels=['g', 'h', 'i']
        ... )
        >>> volume.fig
        ```

        ![](/vectorbt/docs/img/Volume.png)
        """
        Configured.__init__(self,
                            data=data,
                            x_labels=x_labels,
                            y_labels=y_labels,
                            z_labels=z_labels,
                            trace_kwargs=trace_kwargs,
                            add_trace_kwargs=add_trace_kwargs,
                            scene_name=scene_name,
                            fig=fig,
                            **layout_kwargs)

        from vectorbt.settings import layout

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is None:
            if x_labels is None or y_labels is None or z_labels is None:
                raise ValueError(
                    "At least x_labels, y_labels and z_labels must be passed")
            x_len = len(x_labels)
            y_len = len(y_labels)
            z_len = len(z_labels)
        else:
            data = np.asarray(data)
            checks.assert_ndim(data, 3)
            x_len, y_len, z_len = data.shape

        # Missing labels fall back to positional indices.
        x_labels = np.asarray(np.arange(x_len) if x_labels is None else clean_labels(x_labels))
        y_labels = np.asarray(np.arange(y_len) if y_labels is None else clean_labels(y_labels))
        z_labels = np.asarray(np.arange(z_len) if z_labels is None else clean_labels(z_labels))

        if fig is None:
            fig = FigureWidget()
            if 'width' in layout:
                # Calculate nice width and height
                fig.update_layout(width=layout['width'],
                                  height=0.7 * layout['width'])

        # Non-numeric data types are not supported by go.Volume, so use ticktext
        # Note: Currently plotly displays the entire tick array, in future versions it will be more sensible
        # All axes are collected into one scene dict. Assigning each axis to
        # more_layout[scene_name] in turn would keep only the last one.
        scene_axes = dict()
        if not np.issubdtype(x_labels.dtype, np.number):
            x_ticktext = x_labels
            x_labels = np.arange(x_len)
            scene_axes['xaxis'] = dict(
                ticktext=x_ticktext, tickvals=x_labels, tickmode='array')
        if not np.issubdtype(y_labels.dtype, np.number):
            y_ticktext = y_labels
            y_labels = np.arange(y_len)
            scene_axes['yaxis'] = dict(
                ticktext=y_ticktext, tickvals=y_labels, tickmode='array')
        if not np.issubdtype(z_labels.dtype, np.number):
            z_ticktext = z_labels
            z_labels = np.arange(z_len)
            scene_axes['zaxis'] = dict(
                ticktext=z_ticktext, tickvals=z_labels, tickmode='array')
        more_layout = dict()
        if scene_axes:
            more_layout[scene_name] = scene_axes
        fig.update_layout(**more_layout)
        fig.update_layout(**layout_kwargs)

        # Arrays must have the same length as the flattened data array:
        # x varies slowest and z fastest, matching C-order flattening.
        x = np.repeat(x_labels, len(y_labels) * len(z_labels))
        y = np.tile(np.repeat(y_labels, len(z_labels)), len(x_labels))
        z = np.tile(z_labels, len(x_labels) * len(y_labels))

        volume = go.Volume(
            x=x,
            y=y,
            z=z,
            opacity=0.2,
            surface_count=15,  # keep low for big data
            colorscale='Plasma')
        volume.update(**trace_kwargs)
        fig.add_trace(volume, **add_trace_kwargs)

        TraceUpdater.__init__(self, fig, [fig.data[-1]])

        # Values are filled in through self.update only when data was passed.
        if data is not None:
            self.update(data)
Esempio n. 30
0
def visualize3D(vis_net,
                x,
                y,
                dir1,
                dir2,
                dir3,
                len1=1,
                len2=1,
                len3=1,
                show_figure=True,
                save_figure=False,
                file_path='./temp.html'):
    """Plot the class-``y`` softmax probability of ``vis_net`` around ``x`` as a volume.

    The three directions are normalised and checked for near-orthogonality.
    They then span a 50x50x50 grid of inputs ``x + a*dir1 + b*dir2 + c*dir3``,
    with ``a`` in ``[-len1, len1]``, ``b`` in ``[-len2, len2]`` and ``c`` in
    ``[-len3, len3]``. The network's softmax probability for class ``y`` at
    every grid point is rendered with ``go.Volume``.

    Note: each grid point is reshaped to ``(3, 32, 32)`` before ``x`` is
    added. ``x`` and the directions must therefore match that size.

    Args:
        vis_net: Network evaluated on the grid; it is put into eval mode.
        x: Centre input of the grid.
        y: Class index whose probability is plotted.
        dir1, dir2, dir3: Directions spanning the grid (normalised here).
        len1, len2, len3: Half-extent of the grid along each direction.
        show_figure: Call ``fig.show()``.
        save_figure: Write the figure to ``file_path`` via ``plotly.offline.plot``.
        file_path: Output HTML path used when ``save_figure`` is true.

    Returns:
        The plotly figure.

    Raises:
        AssertionError: If any pair of normalised directions has an absolute
            inner product of 0.01 or more.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Normalize the three directions
    print('Take three orthogonal directions')
    dir1 = dir1 / torch.norm(dir1, p=2)
    dir2 = dir2 / torch.norm(dir2, p=2)
    dir3 = dir3 / torch.norm(dir3, p=2)

    # Check that every pair of directions is (numerically) orthogonal
    for first, second in ((dir1, dir2), (dir1, dir3), (dir2, dir3)):
        inner_product = torch.abs(torch.dot(first.view(-1), second.view(-1)))
        assert (inner_product < 0.01).item(), "The three directions are not orthogonal"

    # Generate the visualization and data grid: 50 coefficients per direction
    xx, yy, zz = np.mgrid[-len1:len1:50j, -len2:len2:50j, -len3:len3:50j]

    t = np.c_[xx.ravel(), yy.ravel(), zz.ravel()]
    vis_grid = torch.from_numpy(t).float().to(device)
    dirs_mat = torch.cat(
        [dir1.reshape(1, -1),
         dir2.reshape(1, -1),
         dir3.reshape(1, -1)]).to(device)
    # Row i of vis_grid @ dirs_mat is a_i*dir1 + b_i*dir2 + c_i*dir3.
    x_grid = torch.mm(vis_grid, dirs_mat).reshape(len(vis_grid), 3, 32,
                                                  32).to('cpu') + x

    grid_loader = torch.utils.data.DataLoader(TensorDataset(x_grid),
                                              batch_size=64,
                                              shuffle=False,
                                              num_workers=2)

    vis_net.eval()

    # Explicit dim=1 (the class axis, given the [:, y] indexing below) replaces
    # the deprecated implicit-dim form, which picks dim=1 for 2-D input anyway.
    softmax1 = nn.Softmax(dim=1)

    grid_output = []
    # Pure inference: skip building an autograd graph for ~125k grid points.
    with torch.no_grad():
        for grid_points in tqdm(grid_loader):
            grid_points = grid_points[0].to(device)
            grid_ys = softmax1(vis_net(grid_points))
            grid_output.append(grid_ys[:, y].cpu().numpy())

    y_pred0 = np.concatenate(grid_output)

    # and plot everything
    fig = go.Figure(data=go.Volume(
        x=xx.flatten(),
        y=yy.flatten(),
        z=zz.flatten(),
        value=y_pred0.flatten(),
        isomin=0,
        isomax=1,
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=17,  # needs to be a large number for good volume rendering
    ))

    if show_figure:
        fig.show()

    if save_figure:
        plotly.offline.plot(fig, filename=file_path)

    return fig