Ejemplo n.º 1
0
def _cleanup_session(settings: TrainSettings):
    """Disposes of the replay folder at the end of a session.

    When settings.current_session.holdover is <= 0 the replay folder is deleted,
    followed by a 0.5 second sleep. Otherwise the folder is moved to HOLDOVER_DIR
    and rb.ensure_max_length is applied to it with the holdover value.
    """
    holdover = settings.current_session.holdover
    if holdover <= 0:
        filetools.deldir(settings.replay_folder)
        # short pause after the delete; presumably lets the filesystem settle — TODO confirm
        time.sleep(0.5)
    else:
        os.rename(settings.replay_folder, HOLDOVER_DIR)
        rb.ensure_max_length(HOLDOVER_DIR, holdover)
Ejemplo n.º 2
0
def merge_many(outpath: str, *paths, auto_open_zips=True):
    """Merges the result of multiple runs. Copies select npz and mat files from
    the given folders, stacks them on dimension 0, and then outputs them to the
    specified output path.

    Example:
        merge_many('repeats/all/epoch_finished', *[f'repeats/repeat{i}/epoch_finished' for i in range(10)])

    Args:
        outpath (str): where the final merged files are stored
        paths (tuple[str]): the paths passed to save_using that will be merged
        auto_open_zips (bool, default True): if True then if we come across a directory that
            doesn't exist in paths we will check if the corresponding zip does exist. If so,
            we extract, fetch, and rezip
    """

    if not paths:
        raise ValueError(f'must have at least one path!')

    if os.path.exists(outpath):
        filetools.deldir(outpath)

    os.makedirs(outpath)

    cur_all = None
    for path in paths:
        to_rezip = []
        if auto_open_zips:
            to_rezip = filetools.recur_unzip(path)

        with np.load(os.path.join(path, 'all.npz')) as allnp:
            if cur_all is None:
                cur_all = dict()
                for k in allnp.keys():
                    cur_all[k] = np.expand_dims(allnp[k], 0)
            else:
                cur_all: dict
                for k in allnp.keys():
                    if k in cur_all:
                        if allnp[k].shape != cur_all[k].shape[1:]:
                            print(f'allnp[{k}].shape = {allnp[k].shape}, cur_all[k].shape = {cur_all[k].shape}; path={path}')
                        cur_all[k] = np.concatenate((cur_all[k], np.expand_dims(allnp[k], 0)), axis=0)

                for k in cur_all:
                    if k not in allnp.keys():
                        del cur_all[k]

        if auto_open_zips:
            filetools.zipmany(*to_rezip)

    scipy.io.savemat(os.path.join(outpath, 'all'), cur_all) # pylint: disable=no-member
    np.savez(os.path.join(outpath, 'all'), **cur_all)

    if SAVE_SPLIT:
        for key, val in cur_all.items():
            scipy.io.savemat(os.path.join(outpath, key), {key: val}) # pylint: disable=no-member
            np.savez(os.path.join(outpath, key), val)
Ejemplo n.º 3
0
    def save(self, outpath: str, exist_ok: bool = False):
        """Saves this model to the given outpath. The outpath should be a folder. This will
        store the following things:
            model.pt - an equivalent network which can be loaded with just pytorch. This is
                accomplished by replacing the more memory efficient EvaluatingAbsoluteNormLayer
                with a linear layer with lots of 0s
            layers.npz - stores the fully connected layers with lyr_weight_i and lyr_bias_i, storing
                the norms in norm_means_i and norm_inv_std_i, where norm_inv_std_i is 1/std(feature)
            layers.mat - stores the equivalent data as layers.npz in matlab format
            readme.txt - relevant documentation for loading the model
        """
        # NOTE: the docstring above is written verbatim into readme.txt (through
        # Deep1ModelEval.save.__doc__), so its text must not be edited casually.
        # Raises FileExistsError when outpath exists and exist_ok is False, and
        # ValueError when outpath exists but is not a directory; otherwise an
        # existing folder is deleted and recreated empty.
        target_exists = os.path.exists(outpath)
        if target_exists and not exist_ok:
            raise FileExistsError(outpath)
        if target_exists and not os.path.isdir(outpath):
            raise ValueError(
                f'expected outpath is dir, got {outpath} (not isdir)')
        if target_exists:
            filetools.deldir(outpath)
        os.makedirs(outpath)

        # model.pt: three (norm -> linear -> tanh) stages, each norm layer being
        # converted through its to_linear() method.
        stages = []
        for stage in range(3):
            stages.extend((self.anorms[stage].to_linear(),
                           cp_utils.copy_linear(self.fc_layers[stage]),
                           torch.nn.Tanh()))
        torch.save(torch.nn.Sequential(*stages), os.path.join(outpath, 'model.pt'))
        del stages

        # One dict feeds both layers.npz and layers.mat: norm statistics first,
        # then fully connected weights/biases, each suffixed with its index.
        arrays = {}
        for idx, anorm in enumerate(self.anorms):
            arrays.update({
                f'norm_means_{idx}': anorm.means.clone().numpy(),
                f'norm_inv_std_{idx}': anorm.inv_std.clone().numpy(),
            })
        for idx, fc in enumerate(self.fc_layers):
            arrays.update({
                f'lyr_weight_{idx}': fc.weight.data.clone().numpy(),
                f'lyr_bias_{idx}': fc.bias.data.clone().numpy(),
            })
        np.savez_compressed(os.path.join(outpath, 'layers.npz'), **arrays)
        scipy.io.savemat(os.path.join(outpath, 'layers.mat'), arrays)

        constants = (('ALPHA', ALPHA), ('CUTOFF', CUTOFF),
                     ('ENCODE_DIM', ENCODE_DIM), ('HIDDEN_DIM', HIDDEN_DIM))
        with open(os.path.join(outpath, 'readme.txt'), 'w') as readme:
            def emit(*parts):
                print(*parts, file=readme)

            emit('Model: Deep1ModelEval')
            emit(f'Date: {datetime.datetime.now()}')
            emit('Constants:')
            for const_name, const_value in constants:
                emit(f'  {const_name}: {const_value}')
            emit('Class Documentation:')
            emit(Deep1ModelEval.__doc__)
            emit()
            emit('Function Documentation: ')
            emit(Deep1ModelEval.save.__doc__)
Ejemplo n.º 4
0
def _get_experiences_async(settings: TrainSettings, executable: str, port_min: int, port_max: int,
                           create_flags: int, aggressive: bool, spec: bool, nthreads: int):
    """Gathers experiences with nthreads worker processes and merges them into
    settings.replay_folder.

    Experiences already stored in settings.replay_folder count towards the target
    settings.current_session.tar_ticks; only the remainder is requested, split
    (rounded up) evenly across the workers. Each worker gets its own equal slice of
    the ports between port_min and port_max. When all workers have joined, their
    replay folders, the pre-existing replay (if any) and HOLDOVER_DIR (if present)
    are merged into settings.replay_folder and the source folders are deleted.

    Raises:
        ValueError: if nthreads < 1, or if fewer than 3 ports are available per worker
    """
    num_ticks_to_do = settings.current_session.tar_ticks
    if os.path.exists(settings.replay_folder):
        replay = rb.FileReadableReplayBuffer(settings.replay_folder)
        num_ticks_to_do -= len(replay)
        replay.close()

        if num_ticks_to_do <= 0:
            print(f'get_experiences_async nothing to do (already at {settings.replay_folder})')
            return

    if nthreads < 1:
        raise ValueError(f'need at least one thread, got nthreads={nthreads}')

    replay_paths = [os.path.join(settings.bot_folder, f'replay_{i}') for i in range(nthreads)]
    setting_paths = [os.path.join(settings.bot_folder, f'settings_{i}.json')
                     for i in range(nthreads)]
    workers = []
    serd_settings = ser.serialize_embeddable(settings)
    ports_per = (port_max - port_min) // nthreads
    if ports_per < 3:
        raise ValueError('not enough ports assigned '
                         + f'({nthreads} threads, {port_max-port_min} ports)')
    ticks_per = int(math.ceil(num_ticks_to_do / nthreads))
    for worker in range(nthreads):
        proc = Process(target=_get_experiences_target,
                       args=(serd_settings, executable, port_min + worker*ports_per,
                             port_min + (worker+1)*ports_per, create_flags, aggressive, spec,
                             replay_paths[worker], setting_paths[worker], ticks_per))
        proc.start()
        workers.append(proc)
        # stagger worker start-up; the exact reason is not visible here — TODO confirm
        time.sleep(1)

    for proc in workers:
        proc.join()

    print(f'get_experiences_async finished, storing in {settings.replay_folder}')

    # Keep the pre-existing replay: its length was subtracted from the tick target
    # above, so discarding it would leave the merged buffer short. Move it aside and
    # merge it together with the newly generated buffers.
    if os.path.exists(settings.replay_folder):
        tmp_replay_folder = settings.replay_folder + '_tmp'
        os.rename(settings.replay_folder, tmp_replay_folder)
        replay_paths.append(tmp_replay_folder)

    if os.path.exists(HOLDOVER_DIR):
        replay_paths.append(HOLDOVER_DIR)

    rb.merge_buffers(replay_paths, settings.replay_folder)

    for path in replay_paths:
        filetools.deldir(path)
Ejemplo n.º 5
0
def _cache_markers(markers: typing.List[typing.Tuple[np.ndarray, str]]):
    """Stores the given mask and marker combination so that it will be loaded
    by _mark_cached_moves and returned"""
    if os.path.exists(STORED_MARKER_FP):
        filetools.deldir(STORED_MARKER_FP)
    os.makedirs(STORED_MARKER_FP)
    metafile = os.path.join(STORED_MARKER_FP, 'meta.json')
    with open(metafile, 'w') as outfile:
        json.dump({
            'markers': list(mark for _, mark in markers)
        }, outfile)

    np.savez_compressed(
        os.path.join(STORED_MARKER_FP, 'masks.npz'),
        **dict((f'mask_{i}', mask) for i, (mask, _) in enumerate(markers)))
Ejemplo n.º 6
0
    def on_move(self, game_state: GameState, move: Move) -> None:
        """Handles a move made from the given game state.

        The state is always put into self.gstate_cache. When self.save_activations
        is set, the qbot's hidden activations for (game_state, move) are appended to
        self.activations. Once self.activations.num_pts equals
        self.activations_per_block, they are saved to
        <activations_folder>/<activations_block>, replacing any existing folder.
        The block counter is then advanced and num_pts is reset to 0. Unless
        self.frozen, the (tick, move) pair is appended to self.history, and
        _teach() is called once the history holds at least qbot.cutoff + 1 entries.
        """
        self.gstate_cache.put(game_state)

        if self.save_activations:
            qbot: HiddenStateQBot = self.qbot
            acts = qbot.get_hidden(game_state, move)
            self.activations.append_acts(acts)
            if self.activations.num_pts == self.activations_per_block:
                # activations_block is a numeric counter (incremented below), but
                # os.path.join only accepts str/PathLike components.
                save_folder = os.path.join(self.activations_folder,
                                           str(self.activations_block))
                if os.path.exists(save_folder):
                    deldir(save_folder)

                self.activations.save(save_folder)
                self.activations_block += 1
                self.activations.num_pts = 0

        if not self.frozen:
            self.history.append((game_state.tick, move))
            if len(self.history) >= self.qbot.cutoff + 1:
                self._teach()
Ejemplo n.º 7
0
def main():
    """Main entry for tests

    Exercises the replay buffers under FILEPATH, which is deleted first:
      1. writes 5 experiences from make_exp() into each of two
         FileWritableReplayBuffers ('1' and '2') and merges them into '3' with
         rb.merge_buffers;
      2. checks that a FileReadableReplayBuffer over '3' returns every experience
         exactly once in each of 3 passes of 10 samples, and that mark()/reset()
         replays the same sample;
      3. checks that MemoryPrioritizedReplayBuffer's mark()/reset() replays the same
         15 samples in order, and that pop()/add() round-trips known experiences
         while their last_td_error is updated.

    Raises:
        ValueError: on the first check that fails
    """

    if os.path.exists(FILEPATH):
        filetools.deldir(FILEPATH)

    buf = rb.FileWritableReplayBuffer(os.path.join(FILEPATH, '1'), exist_ok=False)

    sbuf = []

    for _ in range(5):
        exp = make_exp()
        buf.add(exp)
        sbuf.append(exp)

    buf2 = rb.FileWritableReplayBuffer(os.path.join(FILEPATH, '2'), exist_ok=False)

    for _ in range(5):
        exp = make_exp()
        buf2.add(exp)
        sbuf.append(exp)

    buf.close()
    buf2.close()

    rb.merge_buffers([os.path.join(FILEPATH, '2'), os.path.join(FILEPATH, '1')],
                     os.path.join(FILEPATH, '3'))

    # both writers are already closed; reopen the merged buffer for reading
    buf = rb.FileReadableReplayBuffer(os.path.join(FILEPATH, '3'))

    for _ in range(3):
        missing = list(sbuf)
        for _ in range(10):
            got = buf.sample(1)[0]
            for i, candidate in enumerate(missing):
                if got == candidate:
                    missing.pop(i)  # safe: we break right after mutating
                    break
            else:
                raise ValueError(f'got bad value: {got} expected one of \n'
                                 + '\n'.join(repr(exp) for exp in missing))

    buf.mark()
    got = buf.sample(1)[0]
    buf.reset()
    got2 = buf.sample(1)[0]
    if got != got2:
        raise ValueError(f'mark did not retrieve same experience: {got} vs {got2}')

    buf.close()

    buf = rb.MemoryPrioritizedReplayBuffer(os.path.join(FILEPATH, '3'))

    saw = []
    buf.mark()
    for _ in range(15):
        got = buf.sample(1)[0]
        saw.append(got)
    buf.reset()
    for _ in range(15):
        got = buf.sample(1)[0]
        if got != saw[0]:
            raise ValueError(f'got bad value: {got}, expected {saw[0]}')
        saw.pop(0)

    for _ in range(15):
        got = buf.pop()[2]
        found = False
        for exp in sbuf:
            if got == exp:
                found = True
                got.last_td_error = random.random()
                exp.last_td_error = got.last_td_error
                buf.add(got)
                break
        if not found:
            raise ValueError(f'got {got}, expected one of '
                             + '\n'.join(repr(exp) for exp in sbuf))

    buf.close()
Ejemplo n.º 8
0
def plot_trajectory(traj: pca_gen.PCTrajectoryGen,
                    filepath: str,
                    exist_ok: bool = False,
                    markers: typing.List[str] = ('<', '>', '^', 'v'),
                    cmap: typing.Union[mcolors.Colormap, str] = 'cividis',
                    norm: mcolors.Normalize = mcolors.Normalize(-1, 1),
                    transparent: bool = False):
    """Plots the given trajectory (from a deep2-style network) to the given
    folder.

    One subplot is drawn per layer, scattering the projected samples on their first
    two projected coordinates. Each sample uses the marker at the index of its
    largest label component; samples whose argmax has no matching marker are not
    drawn. Colors come from MaxOTSMapping through cmap/norm. Axis limits are made
    square and padded by 10%. One extra subplot shows the color scale. The figure
    is saved as local.png. The principal vectors, principal values, projected
    samples and projected sample labels of every snapshot are saved as .npz files
    in the same folder.

    Arguments:
        traj (PCTrajectoryGen): the trajectory to plot
        filepath (str): where to save the output, should be a folder
        exist_ok (bool): False to error if the filepath exists, True to delete it
            if it already exists
        markers (list[str]): the marker corresponding to each preferred action
        cmap (str or Colormap, optional): The color map to use. Defaults to 'cividis'.
        norm (mcolors.Normalize, optional): Normalizes the scalars that are passed to the color
            map to the range 0-1. Defaults to normalizing linearly from [-1, 1] to [0, 1]
        transparent (bool): True for a transparent background, False for a white one
    """
    tus.check(
        traj=(traj, pca_gen.PCTrajectoryGen),
        filepath=(filepath, str),
        exist_ok=(exist_ok, bool),
    )
    tus.check_listlike(markers=(markers, str))

    ots = pca_gen.MaxOTSMapping()
    s = 12
    alpha = 0.8

    outfile_wo_ext = mutils.process_outfile(filepath, exist_ok, False)[1]
    if exist_ok and os.path.exists(outfile_wo_ext):
        filetools.deldir(outfile_wo_ext)

    os.makedirs(outfile_wo_ext)

    # one subplot per layer plus one for the color scale, on a near-square grid
    num_splots_req = traj.num_layers + 1
    closest_square: int = int(np.ceil(np.sqrt(num_splots_req)))
    num_cols: int = int(math.ceil(num_splots_req / closest_square))
    local_fig, local_axs = plt.subplots(num_cols,
                                        closest_square,
                                        squeeze=False,
                                        figsize=FRAME_SIZE)

    layer: int = 0
    for x in range(num_cols):
        for y in range(closest_square):
            if layer >= num_splots_req:
                # unused grid cell
                local_axs[x][y].remove()
                continue
            elif layer >= traj.num_layers:
                # color scale from norm.vmin to norm.vmax
                lspace = np.linspace(norm.vmin, norm.vmax, 100)
                axis = local_axs[x][y]
                axis.tick_params(axis='both',
                                 which='both',
                                 bottom=False,
                                 left=False,
                                 top=False,
                                 labelbottom=False,
                                 labelleft=False)
                axis.imshow(lspace[..., np.newaxis],
                            cmap=cmap,
                            norm=norm,
                            aspect=0.2)
                layer += 1
                continue
            snapshot: pca_gen.PCTrajectoryGenSnapshot = traj[layer]

            projected = snapshot.projected_samples
            projected_lbls = snapshot.projected_sample_labels

            min_x, min_y, max_x, max_y = (torch.min(projected[:, 0]),
                                          torch.min(projected[:, 1]),
                                          torch.max(projected[:, 0]),
                                          torch.max(projected[:, 1]))
            min_x, min_y, max_x, max_y = min_x.item(), min_y.item(
            ), max_x.item(), max_y.item()

            # widen near-degenerate ranges so the axis limits stay valid
            if max_x - min_x < 1e-3:
                min_x -= 5e-4
                max_x += 5e-4
            if max_y - min_y < 1e-3:
                min_y -= 5e-4
                max_y += 5e-4
            # equalize the x/y extents so distances read the same on both axes
            extents_x = max_x - min_x
            extents_y = max_y - min_y
            if extents_x > extents_y:
                upd = (extents_x - extents_y) / 2
                min_y -= upd
                max_y += upd
            else:
                upd = (extents_y - extents_x) / 2
                min_x -= upd
                max_x += upd
            padding_x = (max_x - min_x) * .1
            padding_y = (max_y - min_y) * .1

            vis_min_x = min_x - padding_x
            vis_max_x = max_x + padding_x
            vis_min_y = min_y - padding_y
            vis_max_y = max_y + padding_y

            # index of the largest label component picks each sample's marker
            markers_selected = projected_lbls.max(dim=1)[1]
            axis = local_axs[x][y]
            for marker_ind, marker in enumerate(markers):
                marker_projected = projected[markers_selected == marker_ind]
                marker_projected_lbls = projected_lbls[markers_selected ==
                                                       marker_ind]
                projected_colors = ots(marker_projected_lbls)
                # scatter accepts a colormap name or Colormap directly
                # (matplotlib.cm.get_cmap is deprecated and later removed)
                axis.scatter(marker_projected[:, 0].numpy(),
                             marker_projected[:, 1].numpy(),
                             s=s,
                             alpha=alpha,
                             c=projected_colors.numpy(),
                             cmap=cmap,
                             norm=norm,
                             marker=marker)

            axis.set_xlim([vis_min_x, vis_max_x])
            axis.set_ylim([vis_min_y, vis_max_y])
            axis.tick_params(axis='both',
                             which='both',
                             bottom=False,
                             left=False,
                             top=False,
                             labelbottom=False,
                             labelleft=False)
            layer += 1

    local_path = os.path.join(outfile_wo_ext, 'local.png')
    local_fig.tight_layout()
    # savefig's resolution keyword is lowercase `dpi`
    local_fig.savefig(local_path, transparent=transparent, dpi=DPI)
    # release the figure; pyplot keeps every created figure alive until closed
    plt.close(local_fig)

    np.savez(os.path.join(outfile_wo_ext, 'principal_vectors.npz'),
             *[snapshot.principal_vectors for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'principal_values.npz'),
             *[snapshot.principal_values for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_samples.npz'),
             *[snapshot.projected_samples for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_sample_labels.npz'),
             *[snapshot.projected_sample_labels for snapshot in traj])
Ejemplo n.º 9
0
    def save(self,
             filepath: str,
             exist_ok: bool = False,
             compress: bool = True) -> None:
        """Saves these clusters along with a description about how to load them
        to the given filepath. If the filepath has an extension, it must be .zip
        and it will be ignored in favor of compress.

        The target is resolved through mutils.process_outfile(filepath, exist_ok,
        compress). Any existing folder at the resolved location is removed before
        writing. The folder receives clusters.npz (samples, centers, labels),
        calculate_params.json (self.calculate_params as JSON) and readme.md.

        Arguments:
            filepath (str): the folder or zip file where these clusters should be
                saved
            exist_ok (bool): affects the behavior if the folder or zip file already
                exists. If this is False, then an error is thrown. If this is True,
                the existing files are deleted
            compress (bool): if True, the folder is compressed to a zip file after
                saving (removing any existing zip first) and the folder is deleted.
                If False, the result is left as a folder
        """
        outfile, folder = mutils.process_outfile(filepath, exist_ok, compress)

        if os.path.exists(folder):
            filetools.deldir(folder)
        os.makedirs(folder)

        def in_folder(name):
            return os.path.join(folder, name)

        np.savez_compressed(in_folder('clusters.npz'),
                            samples=self.samples,
                            centers=self.centers,
                            labels=self.labels)

        with open(in_folder('calculate_params.json'), 'w') as params_out:
            json.dump(self.calculate_params, params_out)

        readme_lines = (
            'Clusters',
            '  clusters.npz:',
            '    samples [n_samples, n_features] - the samples the clusters were calculated from',
            '    centers [n_clusters, n_features] - the centers of the clusters',
            '    labels [n_samples] - the index in centers for the closest cluster to each label',
            '  calculate_params.json:',
            '    Varies. Gives information about how clusters were calculated',
        )
        with open(in_folder('readme.md'), 'w') as readme:
            readme.write(''.join(f'{line}\n' for line in readme_lines))

        if not compress:
            return
        if os.path.exists(outfile):
            os.remove(outfile)
        filetools.zipdir(folder)
Ejemplo n.º 10
0
def plot_trajectory(traj: PCTrajectoryGen, filepath: str, exist_ok: bool = False,
                    alpha: float = 0.5, square: bool = True, transparent: bool = True,
                    s: int = 1, ots: OutputToScalarMapping = SqueezeOTSMapping(),
                    cmap: typing.Union[mcolors.Colormap, str] = 'cividis',
                    norm: mcolors.Normalize = mcolors.Normalize(-1, 1),
                    compress: bool = False):
    """Plots the given trajectory by storing it in the given filepath. If the output of
    the trajectory is not itself a scalar, the output to scalar mapping must be set.
    The other arguments are related to display.

    The output folder receives local.png, with one subplot per layer plus a color
    scale. It also receives principal_vectors.npz, principal_values.npz,
    projected_samples.npz and projected_sample_labels.npz, each holding one array
    per snapshot. When compress is set, the folder is then zipped.

    Args:
        traj (PCTrajectoryGen): The trajectory to plot. Must have at least 2 pcs
        filepath (str): Where to store the given trajectory, either a folder or a zip file.
            The file zip extension will only be used if compress is true
        exist_ok (bool, optional): If the filepath already exists, then this determines if it
            should be overwritten (True) or an error should be raised (False). Defaults to False.
        alpha (float, optional): The transparency value for each vector. Defaults to 0.5.
        square (bool, optional): If the dimensions of the space should be equal for width and
            height (such that 1 inch width and height visually corresponds to the same amount of
            distance in pc-space). Since pc space is naturally rectangular, not setting this
            can easily lead to misinterpretations. Defaults to True.
        transparent (bool, optional): Determines the background color of the saved images, where
            True is transparency and False is near-white. Defaults to True.
        s (int, optional): The size of each projected sample. Defaults to 1.
        ots (OutputToScalarMapping, optional): Maps the labels of the trajectory to samples which
            are then converted to colors using the color map. Defaults to SqueezeOTSMapping().
        cmap (str or Colormap, optional): The color map to use. Defaults to 'cividis'.
        norm (mcolors.Normalize, optional): Normalizes the scalars that are passed to the color
            map to the range 0-1. Defaults to normalizing linearly from [-1, 1] to [0, 1]
        compress (bool): if the folder should be zipped
    """
    tus.check(
        traj=(traj, PCTrajectoryGen),
        filepath=(filepath, str),
        exist_ok=(exist_ok, bool),
        alpha=(alpha, float),
        square=(square, bool),
        transparent=(transparent, bool),
        s=(s, int),
        ots=(ots, OutputToScalarMapping),
        cmap=(cmap, (str, mcolors.Colormap))
    )

    outfile, outfile_wo_ext = mutils.process_outfile(filepath, exist_ok, compress)
    if not compress and exist_ok and os.path.exists(outfile_wo_ext):
        filetools.deldir(outfile_wo_ext)
    os.makedirs(outfile_wo_ext)

    # one subplot per layer plus one for the color scale, on a near-square grid
    num_splots_req = traj.num_layers + 1
    closest_square: int = int(np.ceil(np.sqrt(num_splots_req)))
    num_cols: int = int(math.ceil(num_splots_req / closest_square))
    local_fig, local_axs = plt.subplots(num_cols, closest_square, squeeze=False, figsize=FRAME_SIZE)

    layer: int = 0
    for x in range(num_cols):
        for y in range(closest_square):
            if layer >= num_splots_req:
                # unused grid cell
                local_axs[x][y].remove()
                continue
            elif layer >= traj.num_layers:
                # color scale from norm.vmin to norm.vmax
                lspace = np.linspace(norm.vmin, norm.vmax, 100)
                axis = local_axs[x][y]
                axis.tick_params(axis='both', which='both', bottom=False, left=False, top=False,
                                 labelbottom=False, labelleft=False)
                axis.imshow(lspace[..., np.newaxis], cmap=cmap, norm=norm, aspect=0.2)
                layer += 1
                continue
            snapshot: PCTrajectoryGenSnapshot = traj[layer]

            projected = snapshot.projected_samples
            projected_lbls = snapshot.projected_sample_labels

            min_x, min_y, max_x, max_y = (torch.min(projected[:, 0]), torch.min(projected[:, 1]),
                                          torch.max(projected[:, 0]), torch.max(projected[:, 1]))
            min_x, min_y, max_x, max_y = min_x.item(), min_y.item(), max_x.item(), max_y.item()

            # widen near-degenerate ranges so the axis limits stay valid
            if max_x - min_x < 1e-3:
                min_x -= 5e-4
                max_x += 5e-4
            if max_y - min_y < 1e-3:
                min_y -= 5e-4
                max_y += 5e-4
            if square:
                extents_x = max_x - min_x
                extents_y = max_y - min_y
                if extents_x > extents_y:
                    upd = (extents_x - extents_y) / 2
                    min_y -= upd
                    max_y += upd
                else:
                    upd = (extents_y - extents_x) / 2
                    min_x -= upd
                    max_x += upd
            padding_x = (max_x - min_x) * .1
            padding_y = (max_y - min_y) * .1

            vis_min_x = min_x - padding_x
            vis_max_x = max_x + padding_x
            vis_min_y = min_y - padding_y
            vis_max_y = max_y + padding_y

            projected_colors = ots(projected_lbls)
            axis = local_axs[x][y]
            # scatter accepts a colormap name or Colormap directly
            # (matplotlib.cm.get_cmap is deprecated and later removed)
            axis.scatter(projected[:, 0].numpy(), projected[:, 1].numpy(),
                         s=s, alpha=alpha, c=projected_colors.numpy(),
                         cmap=cmap, norm=norm)
            axis.set_xlim([vis_min_x, vis_max_x])
            axis.set_ylim([vis_min_y, vis_max_y])
            axis.tick_params(axis='both', which='both', bottom=False, left=False, top=False,
                             labelbottom=False, labelleft=False)
            layer += 1

    local_path = os.path.join(outfile_wo_ext, 'local.png')
    local_fig.tight_layout()
    # savefig's resolution keyword is lowercase `dpi`
    local_fig.savefig(local_path, transparent=transparent, dpi=DPI)
    # release the figure; pyplot keeps every created figure alive until closed
    plt.close(local_fig)

    np.savez(os.path.join(outfile_wo_ext, 'principal_vectors.npz'),
             *[snapshot.principal_vectors for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'principal_values.npz'),
             *[snapshot.principal_values for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_samples.npz'),
             *[snapshot.projected_samples for snapshot in traj])
    np.savez(os.path.join(outfile_wo_ext, 'projected_sample_labels.npz'),
             *[snapshot.projected_sample_labels for snapshot in traj])

    if compress:
        if os.path.exists(outfile):
            os.remove(outfile)

        filetools.zipdir(outfile_wo_ext)
Ejemplo n.º 11
0
    def save(self, outpath: str, exist_ok: bool = False):
        """Saves this network to the given path. The path should be a folder, because inside
        the folder we will store:
            model.pt - an equivalent network which can be shared and loaded with just pytorch
            layers.npz - the fully connected layer weights and biases in numpy form
                Each fully connected layer i has its weights stored in lyr_weight_i and its biases
                stored in lyr_bias_i
            layers.mat - the fully connected layer weights and biases in matlab form
                Same variable names as layers.npz
            readme.txt - stores relevant documentation for loading this model
        """
        # NOTE: the docstring above is written verbatim into readme.txt (through
        # Deep1ModelTrain.save.__doc__), so its text must not be edited casually.
        # Raises FileExistsError when outpath exists and exist_ok is False, and
        # ValueError when outpath exists but is not a directory; otherwise an
        # existing folder is deleted and recreated empty.
        target_exists = os.path.exists(outpath)
        if target_exists and not exist_ok:
            raise FileExistsError(outpath)
        if target_exists and not os.path.isdir(outpath):
            raise ValueError(
                f'expected outpath is directory, got {outpath} (not isdir)'
            )
        if target_exists:
            filetools.deldir(outpath)
        os.makedirs(outpath)

        # model.pt: three (batch norm -> linear -> tanh) stages; the batch norms
        # have no affine parameters and track no running statistics.
        stages = []
        for stage, width in enumerate((ENCODE_DIM, HIDDEN_DIM, HIDDEN_DIM)):
            stages.extend((torch.nn.BatchNorm1d(width,
                                                affine=False,
                                                track_running_stats=False),
                           cp_utils.copy_linear(self.fc_layers[stage]),
                           torch.nn.Tanh()))
        torch.save(torch.nn.Sequential(*stages), os.path.join(outpath, 'model.pt'))
        del stages

        # One dict feeds both layers.npz and layers.mat.
        arrays = {}
        for idx, fc in enumerate(self.fc_layers):
            arrays.update({
                f'lyr_weight_{idx}': fc.weight.data.clone().numpy(),
                f'lyr_bias_{idx}': fc.bias.data.clone().numpy(),
            })
        np.savez_compressed(os.path.join(outpath, 'layers.npz'), **arrays)
        scipy.io.savemat(os.path.join(outpath, 'layers.mat'), arrays)

        constants = (('ALPHA', ALPHA), ('CUTOFF', CUTOFF),
                     ('ENCODE_DIM', ENCODE_DIM), ('HIDDEN_DIM', HIDDEN_DIM))
        with open(os.path.join(outpath, 'readme.txt'), 'w') as readme:
            def emit(*parts):
                print(*parts, file=readme)

            emit('Model: Deep1ModelTrain')
            emit(f'Date: {datetime.datetime.now()}')
            emit('Constants:')
            for const_name, const_value in constants:
                emit(f'  {const_name}: {const_value}')
            emit('Class Documentation:')
            emit(Deep1ModelTrain.__doc__)
            emit()
            emit('Function Documentation: ')
            emit(Deep1ModelTrain.save.__doc__)