Пример #1
0
def plot(traj: SaturationTrajectory,
         outfile: str,
         exist_ok: bool = False,
         xlabel: str = 'Layers') -> None:
    """Render saturation plots for a trajectory and bundle them into a zip.

    A working folder (the output path without its extension) receives a
    boxplot, one histogram per entry of ``BUCKETING_TECHNIQUES`` and one
    histogram per entry of ``BUCKETING_SIZES``; the folder is then zipped.

    Args:
        traj (SaturationTrajectory): the trajectory to plot
        outfile (str): the zip file to save plots to
        exist_ok (bool, optional): Defaults to False. True to overwrite, False to
            error if the file already exists
        xlabel (str, optional): Defaults to 'Layers'. The x-axis label for plots
            that go through layers
    """
    outfile, folder = mutils.process_outfile(outfile, exist_ok)
    os.makedirs(folder)

    def _dest(name):
        # Every image lives directly inside the working folder.
        return os.path.join(folder, name)

    _plot_boxplot(traj, _dest('boxplot.png'), xlabel)

    for technique in BUCKETING_TECHNIQUES:
        _plot_hist(traj, _dest(f'hist_{technique}.png'), xlabel, technique)

    for bin_count in BUCKETING_SIZES:
        _plot_hist(traj, _dest(f'hist_fixed_nbins_{bin_count}.png'), xlabel, bin_count)

    # When overwriting is allowed, clear any archive already at the destination
    # before zipping the freshly written folder.
    overwrite_archive = exist_ok and os.path.exists(outfile)
    if overwrite_archive:
        os.remove(outfile)
    zipdir(folder)
Пример #2
0
def _binned2norm(induced: np.ndarray,
                 outpath: str,
                 title: str,
                 dpi=400,
                 transparent=False,
                 bins=10):
    """The target for Binned2Norm: histogram of the induced changes in the 2-norm.

    Writes ``histogram.png`` into the folder derived from ``outpath`` and then
    calls ``zipdir`` on that folder.

    Args:
        induced (np.ndarray): the induced changes in 2-norm (one float per entry)
        outpath (str): the folder or zip file to save to
        title (str): the title for the figure
        dpi (int, optional): resolution of the saved image. Defaults to 400.
        transparent (bool, optional): True for a transparent background.
            Defaults to False.
        bins (int or sequence, optional): forwarded to ``Axes.hist``.
            Defaults to 10.
    """
    _, outfolder = mutils.process_outfile(outpath, False)
    os.makedirs(outfolder, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.set_xlabel('Induced $\\Delta \\| W \\|_2$').set_fontsize(16)
        ax.set_ylabel('Count').set_fontsize(16)
        ax.set_title(title).set_fontsize(18)

        ax.hist(induced, bins=bins)

        fig.savefig(os.path.join(outfolder, 'histogram.png'),
                    dpi=dpi,
                    transparent=transparent)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open matplotlib figures.
        plt.close(fig)

    zipdir(outfolder)
def plot_avg_pr_trajectories(trajectories: typing.List[TrajectoryWithMeta],
                             savepath: str, title: str, exist_ok: bool = False):
    """Plots multiple participation ratio trajectories on a single figure,
    where each trajectory must be associated with a particular label, where
    each trajectory is actually the average of multiple trajectories

    Arguments:
        trajectories (list[TrajectoryWithMeta]): the trajectories to plot. All
            must share the same ``layers`` flag and the same length.
        savepath (str): the zip file to save the resulting figures in
        title (str): the title for the figure
        exist_ok (bool, default False): True to overwrite existing files, False not to

    Raises:
        ValueError: if the arguments have the wrong types, ``trajectories`` is
            empty, or the trajectories disagree on ``layers`` or length
    """
    if not isinstance(trajectories, (list, tuple)):
        raise ValueError(f'expected trajectories is list or tuple, got {trajectories} (type={type(trajectories)})')
    if not trajectories:
        raise ValueError(f'need at least one trajectory, got empty {type(trajectories)}')
    if not isinstance(trajectories[0], TrajectoryWithMeta):
        raise ValueError(f'expected trajectories[0] is TrajectoryWithMeta, got {trajectories[0]} (type={type(trajectories[0])})')
    layers = trajectories[0].trajectory.layers
    depth = trajectories[0].trajectory.overall.shape[0]
    if not isinstance(title, str):
        raise ValueError(f'expected title is str, got {title} (type={type(title)})')
    for i, traj in enumerate(trajectories):
        if not isinstance(traj, TrajectoryWithMeta):
            raise ValueError(f'expected trajectories[{i}] is TrajectoryWithMeta, got {traj} (type={type(traj)})')
        if traj.trajectory.layers != layers:
            raise ValueError(f'trajectories[0].trajectory.layers = {layers}, trajectories[{i}].trajectory.layers = {traj.trajectory.layers}')
        _depth = traj.trajectory.overall.shape[0]
        if depth != _depth:
            raise ValueError(f'trajectories[0].trajectory.overall.shape[0] = {depth}, trajectories[{i}].trajectory.overall.shape[0] = {_depth}')

    filename, folder = mutils.process_outfile(savepath, exist_ok)
    os.makedirs(folder, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.set_title(title).set_fontsize(18)
        ax.set_xlabel('Layer' if layers else 'Time').set_fontsize(16)
        ax.set_ylabel('Participation Ratio').set_fontsize(16)
        ax.set_xticks(list(range(depth)))

        # NOTE(review): 'Set1' has only 9 distinct colors; integer indices past
        # the end presumably map to the colormap's "over" color, so more than 9
        # trajectories may share a color.
        cols = plt.get_cmap('Set1')(np.arange(len(trajectories)))
        x_vals = np.arange(depth)
        for col, traj_meta in zip(cols, trajectories):
            traj = traj_meta.trajectory
            # 1.96 * SEM gives a normal-approximation 95% error bar.
            ax.errorbar(x_vals, traj.overall.numpy(),
                        yerr=traj.overall_sem.numpy() * 1.96,
                        color=col, label=traj_meta.label)
        ax.legend()

        fig.savefig(os.path.join(folder, 'out.png'))
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

    if os.path.exists(filename):
        os.remove(filename)
    zipdir(folder)
Пример #4
0
def save_using(samples: np.ndarray, labels: np.ndarray, *layer_acts: np.ndarray,
               num_labels: int, outpath: str, exist_ok: bool, meta: dict,
               **additional: np.ndarray):
    """Stores the activations of the network to the given file, optionally
    overwriting it if it already exists.

    Everything is written as both ``.mat`` (scipy) and ``.npz`` (numpy) into a
    working folder, which is then zipped.

    Args:
        samples (np.ndarray): the samples presented to the network of dimensions
            [num_samples, input_dim]
        labels (np.ndarray): the labels corresponding to the samples presented
            [num_samples]
        layer_acts (tuple[np.ndarray]): the activations of the network. each element
            corresponds to an array of activations with dimensions
            [num_samples, layer_size]
        num_labels (int): labels are split into masks for values 0..num_labels-1
        outpath (str): the file to save to, should be a zip file
        exist_ok (bool): True to overwrite existing files, False not to
        meta (dict): saved alongside the data in json-format
        additional (dict[str, ndarray]): any additional arrays to save
    """
    filepath, folderpath = mutils.process_outfile(outpath, exist_ok)

    os.makedirs(folderpath, exist_ok=True)

    # One boolean mask per label value, used to split each layer by label.
    label_masks = [labels == val for val in range(num_labels)]

    asdict = dict({'samples': samples, 'labels': labels}, **additional)
    layers_stacked = None
    for layer, act in enumerate(layer_acts):
        # Layers after index 0 are stacked along a new leading axis; a layer
        # whose first two dimensions differ from the stack's is left out.
        # NOTE(review): the last layer is included in the stack. If only hidden
        # layers were meant, this should be ``0 < layer < len(layer_acts) - 1``;
        # confirm with callers.
        if layer > 0:
            if layers_stacked is None:
                layers_stacked = np.expand_dims(act, 0)
            elif act.shape[0] == layers_stacked.shape[1] and act.shape[1] == layers_stacked.shape[2]:
                layers_stacked = np.concatenate((layers_stacked, np.expand_dims(act, 0)), axis=0)

        asdict[f'layer_{layer}'] = act
        for label, mask in enumerate(label_masks):
            asdict[f'layer_{layer}_label_{label}'] = act[mask]

    # scipy.io.savemat cannot serialize None, so the stack is only recorded
    # when at least one layer past index 0 was supplied.
    if layers_stacked is not None:
        asdict['layers_stacked'] = layers_stacked
    scipy.io.savemat(os.path.join(folderpath, 'all'), asdict) # pylint: disable=no-member
    np.savez(os.path.join(folderpath, 'all'), **asdict)

    if SAVE_SPLIT:
        for key, val in asdict.items():
            scipy.io.savemat(os.path.join(folderpath, key), {key: val}) # pylint: disable=no-member
            np.savez(os.path.join(folderpath, key), val)

    scipy.io.savemat(os.path.join(folderpath, 'meta'), meta) # pylint: disable=no-member
    with open(os.path.join(folderpath, 'meta.json'), 'w') as outfile:
        json.dump(meta, outfile)

    if os.path.exists(filepath):
        os.remove(filepath)
    filetools.zipdir(folderpath)
    def save(self, outfile: str, exist_ok=False):
        """Saves this trajectory to the given file.

        Writes ``meta.json`` (the ``layers`` flag), ``overall.pt`` and, when
        ``by_label`` is set, ``by_label.pt`` into a working folder, then zips it.

        Args:
            outfile (str): the filename to save to; should be a zip file
            exist_ok (bool): True to overwrite outfile if it exists, False not to
        """
        filename, folder = mutils.process_outfile(outfile, exist_ok=exist_ok)
        os.makedirs(folder, exist_ok=True)

        with open(os.path.join(folder, 'meta.json'), 'w') as metaout:
            json.dump({'layers': self.layers}, metaout)
        torch.save(self.overall, os.path.join(folder, 'overall.pt'))

        by_label_path = os.path.join(folder, 'by_label.pt')
        if self.by_label is not None:
            torch.save(self.by_label, by_label_path)
        elif os.path.exists(by_label_path):
            # A reused folder may hold a stale by_label.pt. Loading treats the
            # file's presence as "by_label is set", so it must not be zipped.
            os.remove(by_label_path)

        # When overwriting, remove the existing archive first so zipdir starts
        # from a clean destination.
        if exist_ok and os.path.exists(filename):
            os.remove(filename)
        zipdir(folder)
    def load(cls, infile: str):
        """Loads the PR trajectory saved to the given filepath.

        The archive is extracted, read, and the folder is passed back to
        ``zipdir``, even if reading fails part-way.

        Arguments:
            infile (str): the filename to load from; should be a zip file

        Raises:
            FileNotFoundError: if the archive does not exist
        """
        filename, folder = mutils.process_outfile(infile, exist_ok=True)
        if not os.path.exists(filename):
            raise FileNotFoundError(filename)
        unzip(filename)

        try:
            with open(os.path.join(folder, 'meta.json'), 'r') as meta_in:
                meta_dict = json.load(meta_in)
            # NOTE: torch.load unpickles its input; only load trusted archives.
            overall = torch.load(os.path.join(folder, 'overall.pt'))
            by_label_path = os.path.join(folder, 'by_label.pt')
            by_label = torch.load(by_label_path) if os.path.exists(by_label_path) else None
        finally:
            # Re-zip on both success and failure so the extracted folder is not
            # left behind.
            if os.path.exists(folder):
                zipdir(folder)
        return cls(overall=overall, layers=meta_dict['layers'], by_label=by_label)
Пример #7
0
    def load(cls, filepath: str, compress: bool = True):
        """Load clusters previously saved at ``filepath``.

        Any extension on ``filepath`` must be .zip and is ignored. The unpacked
        folder is preferred if it exists; otherwise the archive is extracted.

        Arguments:
            filepath (str): the folder or archive the clusters were saved in
            compress (bool): when True, the folder is zipped once loading finishes,
                whatever its prior state; when False, it is left uncompressed,
                whatever its prior state.

        Raises:
            FileNotFoundError: if neither the folder nor the archive exists, or a
                required file is missing from the folder
        """
        archive, folder = mutils.process_outfile(filepath, True, False)

        if not os.path.exists(folder):
            if not os.path.exists(archive):
                raise FileNotFoundError(filepath)
            filetools.unzip(archive)

        try:
            clusters_path = os.path.join(folder, 'clusters.npz')
            calc_params_path = os.path.join(folder, 'calculate_params.json')
            # Both files must be present; the first missing one is reported.
            for required in (clusters_path, calc_params_path):
                if not os.path.exists(required):
                    raise FileNotFoundError(required)

            with np.load(clusters_path) as stored:
                samples, centers, labels = stored['samples'], stored['centers'], stored['labels']

            with open(calc_params_path, 'r') as params_file:
                calculate_params = json.load(params_file)

            return Clusters(samples, centers, labels, calculate_params)
        finally:
            # Runs on success and on error alike.
            if compress and os.path.exists(folder):
                filetools.zipdir(folder)
Пример #8
0
    def __init__(self,
                 output_path: str,
                 layer_names: typing.List[str],
                 exist_ok: bool = False,
                 layer_indices=None):
        """
        Args:
            output_path (str): either the output folder or the output archive
            layer_names (list[str]): layer names starting with 'input' and
                ending with 'output'
            exist_ok (bool, optional): Defaults to False. If True, if the output
                archive already exists it will be deleted. Otherwise, if the output
                archive already exists an error will be raised
            layer_indices (list[int], optional): Default to None. If specified, only the
                given layers are rendered (where 0 is the input and -1 is output). Otherwise
                defaults to [1:]
        """
        # Only the folder half of process_outfile's result is kept.
        self.output_folder = mutils.process_outfile(output_path, exist_ok)[1]
        self.layer_names, self.layer_indices = layer_names, layer_indices

        # Placeholders: nothing in this constructor fills them in, so they are
        # presumably populated by later setup calls on this object.
        self.connections = self.skip_counter = None

        self.batch_size = None
        self.sample_labels = self.sample_points = None
        self.layers = None

        self.sample_labels_torch = self.sample_points_torch = None

        self.sample_labels_file = self.hid_acts_files = None

        # Rendering settings; FPS and FRAME_SIZE are defined outside this block.
        self.dpi = 100
        self.fps = FPS
        self.frame_size = FRAME_SIZE
Пример #9
0
def plot_trajectory(traj: pca_gen.PCTrajectoryGen,
                    filepath: str,
                    exist_ok: bool = False,
                    markers: typing.List[str] = ('<', '>', '^', 'v'),
                    cmap: typing.Union[mcolors.Colormap, str] = 'cividis',
                    norm: mcolors.Normalize = mcolors.Normalize(-1, 1),
                    transparent: bool = False):
    """Plots the given trajectory (from a deep2-style network) to the given
    folder.

    One subplot is drawn per layer, plus a final subplot showing the color
    scale. Within a layer, each sample uses the marker whose index equals the
    argmax of its projected label and is colored via ``pca_gen.MaxOTSMapping``.
    The figure is saved as ``local.png``, next to ``.npz`` dumps of each
    snapshot's principal vectors/values and projected samples/labels.

    Arguments:
        traj (PCTrajectoryGen): the trajectory to plot
        filepath (str): where to save the output, should be a folder
        exist_ok (bool): False to error if the filepath exists, True to delete it
            if it already exists
        markers (list[str]): the marker corresponding to each preferred action
        cmap (str or Colormap, optional): The color map to use. Defaults to 'cividis'.
        norm (mcolors.Normalize, optional): Normalizes the scalars that are passed to the color
            map to the range 0-1. Defaults to normalizing linearly from [-1, 1] to [0, 1]
        transparent (bool): True for a transparent background, False for a white one
    """
    tus.check(
        traj=(traj, pca_gen.PCTrajectoryGen),
        filepath=(filepath, str),
        exist_ok=(exist_ok, bool),
    )
    tus.check_listlike(markers=(markers, str))

    ots = pca_gen.MaxOTSMapping()
    s = 12
    alpha = 0.8

    outfile_wo_ext = mutils.process_outfile(filepath, exist_ok, False)[1]
    if exist_ok and os.path.exists(outfile_wo_ext):
        filetools.deldir(outfile_wo_ext)

    os.makedirs(outfile_wo_ext)

    # Lay the layers (+1 color-scale panel) out in a near-square grid.
    num_splots_req = traj.num_layers + 1
    closest_square: int = int(np.ceil(np.sqrt(num_splots_req)))
    num_cols: int = int(math.ceil(num_splots_req / closest_square))
    local_fig, local_axs = plt.subplots(num_cols,
                                        closest_square,
                                        squeeze=False,
                                        figsize=FRAME_SIZE)

    try:
        layer: int = 0
        for x in range(num_cols):
            for y in range(closest_square):
                axis = local_axs[x][y]
                if layer >= num_splots_req:
                    # Surplus grid cells are removed entirely.
                    axis.remove()
                    continue
                if layer >= traj.num_layers:
                    # Last panel: a vertical strip showing the color scale.
                    lspace = np.linspace(norm.vmin, norm.vmax, 100)
                    axis.tick_params(axis='both',
                                     which='both',
                                     bottom=False,
                                     left=False,
                                     top=False,
                                     labelbottom=False,
                                     labelleft=False)
                    axis.imshow(lspace[..., np.newaxis],
                                cmap=cmap,
                                norm=norm,
                                aspect=0.2)
                    layer += 1
                    continue
                snapshot: pca_gen.PCTrajectoryGenSnapshot = traj[layer]

                projected = snapshot.projected_samples
                projected_lbls = snapshot.projected_sample_labels

                min_x = torch.min(projected[:, 0]).item()
                min_y = torch.min(projected[:, 1]).item()
                max_x = torch.max(projected[:, 0]).item()
                max_y = torch.max(projected[:, 1]).item()

                # Avoid degenerate (zero-width) axis ranges.
                if max_x - min_x < 1e-3:
                    min_x -= 5e-4
                    max_x += 5e-4
                if max_y - min_y < 1e-3:
                    min_y -= 5e-4
                    max_y += 5e-4
                # Grow the shorter side so both axes span the same distance.
                extents_x = max_x - min_x
                extents_y = max_y - min_y
                if extents_x > extents_y:
                    upd = (extents_x - extents_y) / 2
                    min_y -= upd
                    max_y += upd
                else:
                    upd = (extents_y - extents_x) / 2
                    min_x -= upd
                    max_x += upd
                padding_x = (max_x - min_x) * .1
                padding_y = (max_y - min_y) * .1

                vis_min_x = min_x - padding_x
                vis_max_x = max_x + padding_x
                vis_min_y = min_y - padding_y
                vis_max_y = max_y + padding_y

                # Index of the largest label component selects the marker.
                markers_selected = projected_lbls.max(dim=1)[1]
                for marker_ind, marker in enumerate(markers):
                    selected = markers_selected == marker_ind
                    marker_projected = projected[selected]
                    marker_projected_lbls = projected_lbls[selected]
                    projected_colors = ots(marker_projected_lbls)
                    # Axes.scatter accepts a colormap name or a Colormap
                    # directly (matplotlib.cm.get_cmap was removed in 3.9).
                    axis.scatter(marker_projected[:, 0].numpy(),
                                 marker_projected[:, 1].numpy(),
                                 s=s,
                                 alpha=alpha,
                                 c=projected_colors.numpy(),
                                 cmap=cmap,
                                 norm=norm,
                                 marker=marker)

                axis.set_xlim([vis_min_x, vis_max_x])
                axis.set_ylim([vis_min_y, vis_max_y])
                axis.tick_params(axis='both',
                                 which='both',
                                 bottom=False,
                                 left=False,
                                 top=False,
                                 labelbottom=False,
                                 labelleft=False)
                layer += 1

        local_path = os.path.join(outfile_wo_ext, 'local.png')
        local_fig.tight_layout()
        # savefig's resolution keyword is lowercase ``dpi``; ``DPI=`` is not a
        # recognized savefig keyword.
        local_fig.savefig(local_path, transparent=transparent, dpi=DPI)
    finally:
        # Release the figure so repeated calls do not leak matplotlib figures.
        plt.close(local_fig)

    np.savez(os.path.join(outfile_wo_ext, 'principal_vectors.npz'),
             *[snapshot.principal_vectors for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'principal_values.npz'),
             *[snapshot.principal_values for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_samples.npz'),
             *[snapshot.projected_samples for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_sample_labels.npz'),
             *[snapshot.projected_sample_labels for snapshot in traj])
Пример #10
0
    def save(self,
             filepath: str,
             exist_ok: bool = False,
             compress: bool = True) -> None:
        """Saves these clusters along with a description about how to load them
        to the given filepath. If the filepath has an extension, it must be .zip
        and it will be ignored in favor of compress.

        Arguments:
            filepath (str): the folder or zip file where these clusters should be
                saved
            exist_ok (bool): affects the behavior if the folder or zip file already
                exists. If this is False, then an error is thrown. If this is True,
                the existing files are deleted
            compress (bool): if True, the folder is compressed to a zip file after
                saving and the folder is deleted. If False, the result is left as a
                folder

        Raises:
            FileExistsError: if the output folder already exists and exist_ok is False
        """

        outfile, outfile_wo_ext = mutils.process_outfile(
            filepath, exist_ok, compress)

        if os.path.exists(outfile_wo_ext):
            # Only discard an existing folder when overwriting was requested;
            # otherwise refuse, as documented above, instead of silently
            # deleting the caller's data.
            if not exist_ok:
                raise FileExistsError(outfile_wo_ext)
            filetools.deldir(outfile_wo_ext)

        os.makedirs(outfile_wo_ext)

        np.savez_compressed(os.path.join(outfile_wo_ext, 'clusters.npz'),
                            samples=self.samples,
                            centers=self.centers,
                            labels=self.labels)

        with open(os.path.join(outfile_wo_ext, 'calculate_params.json'),
                  'w') as out:
            json.dump(self.calculate_params, out)

        # Human-readable description of the saved files.
        readme_lines = (
            'Clusters',
            '  clusters.npz:',
            '    samples [n_samples, n_features] - the samples the clusters were calculated'
            + ' from',
            '    centers [n_clusters, n_features] - the centers of the clusters',
            '    labels [n_samples] - the index in centers for the closest cluster '
            + 'to each sample',
            '  calculate_params.json:',
            '    Varies. Gives information about how clusters were calculated',
        )
        with open(os.path.join(outfile_wo_ext, 'readme.md'), 'w') as out:
            for line in readme_lines:
                print(line, file=out)

        if compress:
            if os.path.exists(outfile):
                os.remove(outfile)
            filetools.zipdir(outfile_wo_ext)
Пример #11
0
def plot_trajectory(traj: PCTrajectoryGen, filepath: str, exist_ok: bool = False,
                    alpha: float = 0.5, square: bool = True, transparent: bool = True,
                    s: int = 1, ots: OutputToScalarMapping = SqueezeOTSMapping(),
                    cmap: typing.Union[mcolors.Colormap, str] = 'cividis',
                    norm: mcolors.Normalize = mcolors.Normalize(-1, 1),
                    compress: bool = False):
    """Plots the given trajectory by storing it in the given filepath. If the output of
    the trajectory is not itself a scalar, the output to scalar mapping must be set.
    The other arguments are related to display.

    One subplot is drawn per layer plus a final color-scale panel, saved as
    ``local.png``, alongside ``.npz`` dumps of each snapshot's principal
    vectors/values and projected samples/labels.

    Args:
        traj (PCTrajectoryGen): The trajectory to plot. Must have at least 2 pcs
        filepath (str): Where to store the given trajectory, either a folder or a zip file.
            The file zip extension will only be used if compress is true
        exist_ok (bool, optional): If the filepath already exists, then this determines if it
            should be overwritten (True) or an error should be raised (False). Defaults to False.
        alpha (float, optional): The transparency value for each vector. Defaults to 0.5.
        square (bool, optional): If the dimensions of the space should be equal for width and
            height (such that 1 inch width and height visually corresponds to the same amount of
            distance in pc-space). Since pc space is naturally rectangular, not setting this
            can easily lead to misinterpretations. Defaults to True.
        transparent (bool, optional): Determines the background color of the saved images, where
            True is transparency and False is near-white. Defaults to True.
        s (int, optional): The size of each projected sample. Defaults to 1.
        ots (OutputToScalarMapping, optional): Maps the labels of the trajectory to samples which
            are then converted to colors using the color map. Defaults to SqueezeOTSMapping().
        cmap (str or Colormap, optional): The color map to use. Defaults to 'cividis'.
        norm (mcolors.Normalize, optional): Normalizes the scalars that are passed to the color
            map to the range 0-1. Defaults to normalizing linearly from [-1, 1] to [0, 1]
        compress (bool): if the folder should be zipped
    """
    tus.check(
        traj=(traj, PCTrajectoryGen),
        filepath=(filepath, str),
        exist_ok=(exist_ok, bool),
        alpha=(alpha, float),
        square=(square, bool),
        transparent=(transparent, bool),
        s=(s, int),
        ots=(ots, OutputToScalarMapping),
        cmap=(cmap, (str, mcolors.Colormap))
    )

    outfile, outfile_wo_ext = mutils.process_outfile(filepath, exist_ok, compress)
    if not compress and exist_ok and os.path.exists(outfile_wo_ext):
        filetools.deldir(outfile_wo_ext)
    os.makedirs(outfile_wo_ext)

    # Lay the layers (+1 color-scale panel) out in a near-square grid.
    num_splots_req = traj.num_layers + 1
    closest_square: int = int(np.ceil(np.sqrt(num_splots_req)))
    num_cols: int = int(math.ceil(num_splots_req / closest_square))
    local_fig, local_axs = plt.subplots(num_cols, closest_square, squeeze=False, figsize=FRAME_SIZE)

    try:
        layer: int = 0
        for x in range(num_cols):
            for y in range(closest_square):
                axis = local_axs[x][y]
                if layer >= num_splots_req:
                    # Surplus grid cells are removed entirely.
                    axis.remove()
                    continue
                if layer >= traj.num_layers:
                    # Last panel: a vertical strip showing the color scale.
                    lspace = np.linspace(norm.vmin, norm.vmax, 100)
                    axis.tick_params(axis='both', which='both', bottom=False, left=False, top=False,
                                     labelbottom=False, labelleft=False)
                    axis.imshow(lspace[..., np.newaxis], cmap=cmap, norm=norm, aspect=0.2)
                    layer += 1
                    continue
                snapshot: PCTrajectoryGenSnapshot = traj[layer]

                projected = snapshot.projected_samples
                projected_lbls = snapshot.projected_sample_labels

                min_x = torch.min(projected[:, 0]).item()
                min_y = torch.min(projected[:, 1]).item()
                max_x = torch.max(projected[:, 0]).item()
                max_y = torch.max(projected[:, 1]).item()

                # Avoid degenerate (zero-width) axis ranges.
                if max_x - min_x < 1e-3:
                    min_x -= 5e-4
                    max_x += 5e-4
                if max_y - min_y < 1e-3:
                    min_y -= 5e-4
                    max_y += 5e-4
                if square:
                    # Grow the shorter side so both axes span the same distance.
                    extents_x = max_x - min_x
                    extents_y = max_y - min_y
                    if extents_x > extents_y:
                        upd = (extents_x - extents_y) / 2
                        min_y -= upd
                        max_y += upd
                    else:
                        upd = (extents_y - extents_x) / 2
                        min_x -= upd
                        max_x += upd
                padding_x = (max_x - min_x) * .1
                padding_y = (max_y - min_y) * .1

                vis_min_x = min_x - padding_x
                vis_max_x = max_x + padding_x
                vis_min_y = min_y - padding_y
                vis_max_y = max_y + padding_y

                projected_colors = ots(projected_lbls)
                # Axes.scatter accepts a colormap name or a Colormap directly
                # (matplotlib.cm.get_cmap was removed in 3.9).
                axis.scatter(projected[:, 0].numpy(), projected[:, 1].numpy(),
                             s=s, alpha=alpha, c=projected_colors.numpy(),
                             cmap=cmap, norm=norm)
                axis.set_xlim([vis_min_x, vis_max_x])
                axis.set_ylim([vis_min_y, vis_max_y])
                axis.tick_params(axis='both', which='both', bottom=False, left=False, top=False,
                                 labelbottom=False, labelleft=False)
                layer += 1

        local_path = os.path.join(outfile_wo_ext, 'local.png')
        local_fig.tight_layout()
        # savefig's resolution keyword is lowercase ``dpi``; ``DPI=`` is not a
        # recognized savefig keyword.
        local_fig.savefig(local_path, transparent=transparent, dpi=DPI)
    finally:
        # Release the figure so repeated calls do not leak matplotlib figures.
        plt.close(local_fig)

    np.savez(os.path.join(outfile_wo_ext, 'principal_vectors.npz'),
             *[snapshot.principal_vectors for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'principal_values.npz'),
             *[snapshot.principal_values for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_samples.npz'),
             *[snapshot.projected_samples for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_sample_labels.npz'),
             *[snapshot.projected_sample_labels for snapshot in traj])

    if compress:
        if os.path.exists(outfile):
            os.remove(outfile)

        filetools.zipdir(outfile_wo_ext)
Пример #12
0
def plot(traj: MyTrajectory, outfile: str, exist_ok: bool = False) -> None:
    """Placeholder plotter for ``MyTrajectory``; not implemented yet.

    The output path is still passed through ``mutils.process_outfile`` before
    ``NotImplementedError`` is raised.
    """
    mutils.process_outfile(outfile, exist_ok)
    raise NotImplementedError()