Example 1
def train_svms(pwl_prod: PointWithLabelProducer,
               hidden_states: typing.Callable) -> SVMTrajectory:
    """Trains svms on each layer. This is invariant to feed-forward or
    recurrent layers; hidden_states should be a function which accepts
    a tensor [batch_size, input_dim] and returns an iterable of tensors
    [batch_size, layer_size]. Each layer size may be different, however
    each call to hidden_states must result in exactly the same number of
    layers in the same order.

    Args:
        pwl_prod (PointWithLabelProducer): the problem to show to the network
        hidden_states (typing.Callable): a function which gives an iterable of the
            hidden states when given a points tensor

    Returns:
        traj (SVMTrajectory): the svm accuracy through time/layers
    """

    num_points = min(pwl_prod.output_dim * 150, pwl_prod.epoch_size)

    sample_points = torch.zeros((num_points, pwl_prod.input_dim), dtype=torch.double)
    sample_labels = torch.zeros(num_points, dtype=torch.long)

    pwl_prod.fill(sample_points, sample_labels)

    hid_acts = tuple(hidden_states(sample_points))
    return train_svms_with(sample_points, sample_labels, *hid_acts)
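The hidden_states contract above (one points tensor in, an ordered iterable of per-layer activations out) can be met with a small wrapper around any stack of layers. The sketch below is illustrative only; make_hidden_states_fn and its layers argument are hypothetical names, not part of the codebase.

import typing

import torch
import torch.nn as nn


def make_hidden_states_fn(layers: nn.ModuleList) -> typing.Callable:
    """Illustrative sketch: wraps a stack of layers so each call returns the
    activations after every layer, always the same number and order."""
    def hidden_states(points: torch.Tensor) -> typing.List[torch.Tensor]:
        acts = []
        current = points
        for layer in layers:
            current = layer(current)
            acts.append(current.detach())
        return acts
    return hidden_states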
Example 2
def find_trajectory(model: FeedforwardNetwork,
                    pwl_prod: PointWithLabelProducer,
                    num_pcs: int) -> PCTrajectoryFF:
    """Finds the pc trajectory for the given feed-forward model. Gets only the top
    num_pcs principal components

    Args:
        model (FeedforwardNetwork): the model to get the pc trajectory of
        pwl_prod (PointWithLabelProducer): the points to pass through
        num_pcs (int): the number of pcs to get

    Returns:
        PCTrajectoryFF: the pc trajectory of the network for the sampled points
    """
    if not isinstance(model, FeedforwardNetwork):
        raise ValueError(
            f'expected model is FeedforwardNetwork, got {model} (type={type(model)})'
        )
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(
            f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})'
        )
    if not isinstance(num_pcs, int):
        raise ValueError(f'expected num_pcs is int, got {num_pcs}')
    if num_pcs <= 0:
        raise ValueError(f'expected num_pcs is positive, got {num_pcs}')

    num_samples = min(pwl_prod.epoch_size, 200 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, model.input_dim),
                                dtype=torch.double)
    sample_labels = torch.zeros((num_samples, ), dtype=torch.long)
    snapshots = []  # will be filled with PCTrajectoryFFSnapshot instances

    pwl_prod.fill(sample_points, sample_labels)

    def on_hidacts(acts_info: FFHiddenActivations):
        hid_acts = acts_info.hidden_acts.detach()

        pc_vals, pc_vecs = get_hidden_pcs(hid_acts, num_pcs)
        projected = project_to_pcs(hid_acts, pc_vecs, out=None)
        snapshots.append(
            PCTrajectoryFFSnapshot(pc_vecs, pc_vals, projected, sample_labels))

    model(sample_points, on_hidacts)

    return PCTrajectoryFF(snapshots)
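find_trajectory defers the PCA itself to get_hidden_pcs and project_to_pcs, which are defined elsewhere. Assuming they follow the standard recipe (top eigenpairs of the covariance of the mean-centered activations, then projection onto those directions), a minimal sketch of what they might look like:

import typing

import torch


def get_hidden_pcs(hid_acts: torch.Tensor,
                   num_pcs: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Sketch: top num_pcs eigenpairs of the activation covariance.
    Returns (values of shape [num_pcs], vectors of shape [num_pcs, layer_size])."""
    centered = hid_acts - hid_acts.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered / (hid_acts.shape[0] - 1)
    eigvals, eigvecs = torch.linalg.eigh(cov)  # ascending eigenvalues
    idx = torch.argsort(eigvals, descending=True)[:num_pcs]
    return eigvals[idx], eigvecs[:, idx].t()


def project_to_pcs(hid_acts: torch.Tensor, pc_vecs: torch.Tensor,
                   out: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sketch: projects [batch, layer_size] activations onto the rows of pc_vecs.
    This sketch does not center before projecting; the real helper may differ."""
    projected = hid_acts @ pc_vecs.t()
    if out is not None:
        out.copy_(projected)
        return out
    return projected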
def measure_pr_gen(network: typing.Union[FeedforwardNetwork, NaturalRNN],
                   pwl_prod: PointWithLabelProducer,
                   points_dtype: typing.Any = torch.float,
                   out_dtype: typing.Any = torch.float,
                   recur_times: int = 10,
                   squeeze_to_pwl: bool = False) -> PRTrajectory:
    """Measures the participation ratio trajectory for the given generalized feed
    forward network, not assuming that the output dimension is a label. The
    resulting trajectory does not have by_label set.
    """
    tus.check(
        network=(network, (FeedforwardNetwork, NaturalRNN)),
        pwl_prod=(pwl_prod, PointWithLabelProducer),
    )

    num_samples = min(pwl_prod.epoch_size, 2000)
    points = torch.zeros((num_samples, pwl_prod.input_dim), dtype=points_dtype)
    out = torch.zeros((num_samples, network.output_dim), dtype=out_dtype)
    pwl_prod.fill(points, out.squeeze() if squeeze_to_pwl else out)

    num_layers = recur_times if isinstance(network, NaturalRNN) else network.num_layers
    pr_overall = torch.zeros(num_layers + 1, dtype=points_dtype)

    def on_hidacts_raw(hid_acts: torch.Tensor, layer: int):
        pr_overall[layer] = measure_pr(hid_acts)

    if isinstance(network, NaturalRNN):
        def on_hidacts_rnn(hacts: RNNHiddenActivations):
            on_hidacts_raw(hacts.hidden_acts.detach(), hacts.recur_step)
        network(points, recur_times, on_hidacts_rnn, 1)
    else:
        def on_hidacts_ff(facts: FFHiddenActivations):
            on_hidacts_raw(facts.hidden_acts.detach(), facts.layer)
        network(points, on_hidacts_ff)

    return PRTrajectory(pr_overall, None, True)
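measure_pr itself is not shown in these examples. Assuming the usual participation-ratio definition, PR = (sum_i lambda_i)^2 / sum_i lambda_i^2 with lambda_i the eigenvalues of the activation covariance, a minimal sketch would be:

import torch


def measure_pr(hid_acts: torch.Tensor) -> torch.Tensor:
    """Sketch of the standard participation ratio of a [batch, dim] activation
    tensor: close to dim when variance is spread evenly across directions and
    close to 1 when a single direction dominates."""
    centered = hid_acts - hid_acts.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered / (hid_acts.shape[0] - 1)
    eigvals = torch.linalg.eigvalsh(cov)
    return eigvals.sum() ** 2 / (eigvals ** 2).sum()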
Example 4
def get_hidacts_rnn(network: NaturalRNN, pwl: PointWithLabelProducer, recur_times: int,
                    num_points: typing.Optional[int] = None) -> NetworkHiddenActivations:
    """Gets the hidden activations for the given recurrent network, acquiring at most
    num_points from the producer.

    Args:
        network (NaturalRNN): the network to forward through
        pwl (PointWithLabelProducer): the producer for points to run through
        recur_times (int): how many recurrent steps to run the network for
        num_points (typing.Optional[int], optional): Defaults to None. The maximum number
            of points to run through the network, clipped to pwl.epoch_size

    Returns:
        NetworkHiddenActivations: the internal activations of the network
    """
    if num_points is None:
        num_points = 50*network.output_dim

    if not isinstance(network, NaturalRNN):
        raise ValueError(f'expected network is NaturalRNN, got {network} (type={type(network)})')
    if not isinstance(pwl, PointWithLabelProducer):
        raise ValueError(f'expected pwl is PointWithLabelProducer, got {pwl} (type={type(pwl)})')
    if not isinstance(num_points, int):
        raise ValueError(f'expected num_points is int, got {num_points} (type={type(num_points)})')
    if not isinstance(recur_times, int):
        raise ValueError(f'expected recur_times is int, got {recur_times} (type={type(recur_times)})')

    num_points = min(num_points, pwl.epoch_size)

    sample_points = torch.zeros((num_points, pwl.input_dim), dtype=torch.double)
    sample_labels = torch.zeros(num_points, dtype=torch.int)

    pwl.mark()
    pwl.fill(sample_points, sample_labels)
    pwl.reset()
    return get_hidacts_rnn_with_sample(network, sample_points, sample_labels, recur_times)
Example 5
def get_hidacts_ff(network: FeedforwardNetwork, pwl: PointWithLabelProducer,
                   num_points: typing.Optional[int] = None) -> NetworkHiddenActivations:
    """Creates a sample of at most num_points from the given point with label producer without
    affecting its internal state. Then runs those through the network and returns the hidden
    activations that came out of that

    Arguments:
        network (FeedforwardNetwork): the network which the points should be run through
        pwl (PointWithLabelProducer): where to acquire the points
        num_points (int, optional): if specified, the number of points to fetch. if not specified
            this is min(constant*network.num_layers, pwl.epoch_size) where the constant is
            reasonable of the order of 100
    """

    if num_points is None:
        num_points = 50*network.output_dim

    if not isinstance(network, FeedforwardNetwork):
        raise ValueError(f'expected network is FeedforwardNetwork, got {network} (type={type(network)})')
    if not isinstance(pwl, PointWithLabelProducer):
        raise ValueError(f'expected pwl is PointWithLabelProducer, got {pwl} (type={type(pwl)})')
    if not isinstance(num_points, int):
        raise ValueError(f'expected num_points is int, got {num_points} (type={type(num_points)})')

    num_points = min(num_points, pwl.epoch_size)

    sample_points = torch.zeros((num_points, pwl.input_dim), dtype=torch.double)
    sample_labels = torch.zeros(num_points, dtype=torch.int)

    pwl.mark()
    pwl.fill(sample_points, sample_labels)
    pwl.reset()
    return get_hidacts_ff_with_sample(network, sample_points, sample_labels)
def measure_pr_rnn(network: NaturalRNN, teacher: RNNTeacher, pwl_prod: PointWithLabelProducer) -> PRTrajectory:
    """Measures the participation ratio through time for the given recurrent network

    Args:
        network (NaturalRNN): the network to measure
        teacher (RNNTeacher): the teacher for the network
        pwl_prod (PointWithLabelProducer): the point with label producer to generate points with

    Returns:
        PRTrajectory: the participation ratio trajectory for the network
    """
    if not isinstance(network, NaturalRNN):
        raise ValueError(f'expected network is a NaturalRNN, got {network} (type={type(network)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')

    num_samples = min(pwl_prod.epoch_size, 100 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, network.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)

    pr_overall = torch.zeros(teacher.recurrent_times + 1, dtype=torch.double)
    pr_by_label = torch.zeros((pwl_prod.output_dim, teacher.recurrent_times + 1), dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    masks_by_lbl = [sample_labels == lbl for lbl in range(pwl_prod.output_dim)]

    def on_hidacts(acts_info: RNNHiddenActivations):
        hid_acts = acts_info.hidden_acts.detach()
        recur_step = acts_info.recur_step

        pr_overall[recur_step] = measure_pr(hid_acts)
        for lbl in range(pwl_prod.output_dim):
            pr_by_label[lbl, recur_step] = measure_pr(hid_acts[masks_by_lbl[lbl]])

    network(sample_points, teacher.recurrent_times, on_hidacts, 1)

    return PRTrajectory(overall=pr_overall, by_label=pr_by_label, layers=False)
def measure_pr_ff(network: FeedforwardNetwork, pwl_prod: PointWithLabelProducer) -> PRTrajectory:
    """Measures the participation ratio through layers for the given feedforward network

    Args:
        network (FeedforwardNetwork): The feedforward network to measure pr through layers of
        pwl_prod (PointWithLabelProducer): The pointproducer to sample points from

    Returns:
        traj (PRTrajectory): the trajectory of the network's participation ratio
    """
    if not isinstance(network, FeedforwardNetwork):
        raise ValueError(f'expected network is FeedforwardNetwork, got {network} (type={type(network)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')

    num_samples = min(pwl_prod.epoch_size, 100 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, network.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)

    pr_overall = torch.zeros(network.num_layers+1, dtype=torch.double)
    pr_by_label = torch.zeros((pwl_prod.output_dim, network.num_layers+1), dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    masks_by_lbl = [sample_labels == lbl for lbl in range(pwl_prod.output_dim)]

    def on_hidacts(acts_info: FFHiddenActivations):
        hid_acts = acts_info.hidden_acts.detach()
        layer = acts_info.layer

        pr_overall[layer] = measure_pr(hid_acts)
        for lbl in range(pwl_prod.output_dim):
            pr_by_label[lbl, layer] = measure_pr(hid_acts[masks_by_lbl[lbl]])

    network(sample_points, on_hidacts)

    return PRTrajectory(overall=pr_overall, by_label=pr_by_label, layers=True)
Example 8
def find_trajectory(model: Network, pwl_prod: PointWithLabelProducer,
                    num_pcs: int, recur_times: int = 10,
                    points_dtype: typing.Any = torch.float,
                    out_dtype: typing.Any = torch.float,
                    squeeze_to_pwl: bool = True) -> PCTrajectoryGen:
    """Finds the PC trajectory for the given network, sampling a reasonable number of
    points from the dataset. The resulting labels will simply be the shape of the output
    layer for the network.

    Args:
        model (Network): the network whose activations will be fetched
        pwl_prod (PointWithLabelProducer): the dataset to sample inputs from
        num_pcs (int): the number of principal components to project onto
        recur_times (int, optional): if the network is recurrent, this is how many timesteps
            to get the trajectory on. Default 10.
        points_dtype (torch dtype, optional): the data type for points. Default torch.float.
        out_dtype (torch dtype, optional): the data type for the sampled outputs.
            Default torch.float.
        squeeze_to_pwl (bool, optional): whether the output tensor is squeezed before being
            filled by the point with label producer. Helpful if the output is a single scalar
            and the pwl expects a tensor with shape (num_samples). Default True
    """
    tus.check(
        model=(model, Network),
        pwl_prod=(pwl_prod, PointWithLabelProducer),
        num_pcs=(num_pcs, int),
        recur_times=(recur_times, int)
    )

    num_samples = min(pwl_prod.epoch_size, 2000)
    points = torch.zeros((num_samples, pwl_prod.input_dim), dtype=points_dtype)
    out = torch.zeros((num_samples, model.output_dim), dtype=out_dtype)
    pwl_prod.fill(points, out.squeeze() if squeeze_to_pwl else out)

    snapshots = []

    def on_hidacts_raw(hid_acts: torch.Tensor):
        if hid_acts.shape[1] >= num_pcs:
            pc_vals, pc_vecs = pca.get_hidden_pcs(hid_acts, num_pcs)
            projected = pca.project_to_pcs(hid_acts, pc_vecs, out=None)
            snapshots.append(PCTrajectoryGenSnapshot(pc_vecs, pc_vals, projected, out))
        elif hid_acts.shape[1] > 1:
            pc_vals, pc_vecs = pca.get_hidden_pcs(hid_acts, hid_acts.shape[1])
            projected = pca.project_to_pcs(hid_acts, pc_vecs, out=None)

            pc_vecs_app = torch.zeros((num_pcs, hid_acts.shape[1]), dtype=pc_vecs.dtype)
            pc_vecs_app[:pc_vecs.shape[0]] = pc_vecs

            pc_vals_app = torch.zeros((num_pcs,), dtype=pc_vals.dtype)
            pc_vals_app[:pc_vals.shape[0]] = pc_vals

            projected_app = torch.zeros((hid_acts.shape[0], num_pcs), dtype=projected.dtype)
            projected_app[:projected.shape[0]] = projected

            snapshots.append(PCTrajectoryGenSnapshot(pc_vecs_app, pc_vals_app, projected_app, out))
        else:
            pc_vecs = torch.zeros((num_pcs, hid_acts.shape[1]), dtype=hid_acts.dtype)
            pc_vals = torch.zeros((num_pcs,), dtype=hid_acts.dtype)

            pc_vals[0] = 1
            pc_vecs[0, 0] = 1

            projected = torch.zeros((hid_acts.shape[0], num_pcs), dtype=hid_acts.dtype)
            projected[:, 0] = hid_acts.squeeze()

            snapshots.append(PCTrajectoryGenSnapshot(pc_vecs, pc_vals, projected, out))


    if isinstance(model, NaturalRNN):
        def on_hidacts_rnn(hacts: RNNHiddenActivations):
            on_hidacts_raw(hacts.hidden_acts.detach())
        model(points, recur_times, on_hidacts_rnn, 1)
    else:
        def on_hidacts_ff(facts: FFHiddenActivations):
            on_hidacts_raw(facts.hidden_acts.detach())
        model(points, on_hidacts_ff)

    return PCTrajectoryGen(snapshots)
Example 9
def find_trajectory(model: NaturalRNN, pwl_prod: PointWithLabelProducer,
                    duration: int, num_pcs: int) -> PCTrajectory:
    """Finds the trajectory of the given model using the given point with label producer. Goes
    through the entire epoch for the point with label producer.

    Args:
        model (Natural): The underlying model whose trajectories are being considered
        pwl_prod (PointWithLabelProducer): The producer for the samples. The entire epoch
            is gone through
        duration (int): How many timesteps to go through
        num_pcs (int): The number of principal vectors to find
    """

    num_samples = min(pwl_prod.epoch_size, 100 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, model.input_dim),
                                dtype=torch.double)
    sample_labels = torch.zeros((num_samples, ), dtype=torch.long)
    hid_acts = torch.zeros((duration + 1, num_samples, model.hidden_dim),
                           dtype=torch.double)
    hid_pc_vals = torch.zeros((duration + 1, num_pcs), dtype=torch.double)
    hid_pc_vecs = torch.zeros((duration + 1, num_pcs, model.hidden_dim),
                              dtype=torch.double)
    proj_samples = torch.zeros((duration + 1, num_samples, num_pcs),
                               dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    def on_hidacts(acts_info: RNNHiddenActivations):
        hidden_acts = acts_info.hidden_acts
        recur_step = acts_info.recur_step

        hid_acts[recur_step, :, :] = hidden_acts.detach()
        pc_vals, pc_vecs = get_hidden_pcs(hid_acts[recur_step], num_pcs)
        hid_pc_vals[recur_step, :] = pc_vals
        hid_pc_vecs[recur_step, :, :] = pc_vecs

    model(sample_points, duration, on_hidacts, 1)

    for recur_step in range(duration + 1):
        project_to_pcs(hid_acts[recur_step],
                       hid_pc_vecs[recur_step],
                       out=proj_samples[recur_step])

    # We are free to flip the sign of each pc vector. The following flips
    # them so that the ordering of the label means along each pc matches
    # the ordering at the first timestep where possible. This is only fully
    # achievable for 2 labels, since flipping a pc flips it for ALL labels;
    # we flip a pc whenever at least half of the label-pair orderings have
    # reversed relative to the start.
    indices_by_label = dict(
        (lbl, sample_labels == lbl) for lbl in range(pwl_prod.output_dim))

    means_by_label_and_recur = dict()
    for lbl in range(pwl_prod.output_dim):
        for recur_step in range(duration + 1):
            means_by_label_and_recur[(lbl, recur_step)] = (
                proj_samples[recur_step, indices_by_label[lbl], :].mean(0))

    # we can flip any of our pcs
    for recur_step in range(1, duration + 1):
        for pc in range(num_pcs):  # pylint: disable=invalid-name
            badness = 0
            counter = 0
            for lbl1 in range(pwl_prod.output_dim):
                for lbl2 in range(lbl1 + 1, pwl_prod.output_dim):
                    used_to_be_lt = (means_by_label_and_recur[(lbl1, 0)][pc] <
                                     means_by_label_and_recur[(lbl2, 0)][pc])
                    curr_is_lt = (
                        means_by_label_and_recur[(lbl1, recur_step)][pc] <
                        means_by_label_and_recur[(lbl2, recur_step)][pc])
                    counter += 1
                    if used_to_be_lt != curr_is_lt:
                        badness += 1

            if badness >= (counter / 2):
                hid_pc_vecs[recur_step, pc, :] *= -1

    for recur_step in range(duration + 1):
        project_to_pcs(hid_acts[recur_step],
                       hid_pc_vecs[recur_step],
                       out=proj_samples[recur_step])

    return PCTrajectory(hid_pc_vecs, hid_pc_vals, proj_samples, sample_labels,
                        duration)
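To make the flip criterion above concrete: for each pc at each timestep, the loop compares every pair of label means against their ordering at timestep 0 and flips the pc when at least half of those orderings have reversed. A toy check with hypothetical numbers for two labels on one pc:

import torch

# Hypothetical label means on one pc at the first timestep and at step k.
means_t0 = torch.tensor([-1.0, 2.0])   # label 0 sits below label 1 initially
means_tk = torch.tensor([1.5, -0.5])   # by step k the ordering has reversed

used_to_be_lt = bool(means_t0[0] < means_t0[1])  # True
curr_is_lt = bool(means_tk[0] < means_tk[1])     # False
badness, counter = int(used_to_be_lt != curr_is_lt), 1

if badness >= counter / 2:
    # in find_trajectory this corresponds to hid_pc_vecs[k, pc, :] *= -1
    means_tk = -means_tk  # flipping the pc restores the original ordering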
Example 10
def measure_dtt(model: NaturalRNN, pwl_prod: PointWithLabelProducer,
                duration: int, outfile: str, exist_ok: bool = False,
                logger: typing.Optional[logging.Logger] = None, verbose: bool = False) -> None:
    """Measures the distance of points in hidden activation space from points which
    are sampled from different labels. For example, points from the same label might
    be close in hidden activation space even when they are far in input space.

    Args:
        model (NaturalRNN): the model
        pwl_prod (PointWithLabelProducer): the points to test
        duration (int): the number of recurrent times
        outfile (str): where to save the result. Should have extension '.zip' or
                       no extension at all
        exist_ok (bool, default False): true if we should overwrite the outfile if it already
            exists, false to check if it exists and error if it does
        logger (logging.Logger, optional): logger used for progress messages when verbose
        verbose (bool, default False): whether to log progress messages
    """

    if not isinstance(model, NaturalRNN):
        raise ValueError(f'expected model is NaturalRNN, got {model} (type={type(model)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')
    if not isinstance(outfile, str):
        raise ValueError(f'expected outfile is str, got {outfile} (type={type(outfile)})')

    outfile_wo_ext = os.path.splitext(outfile)[0]
    if outfile_wo_ext == outfile:
        outfile = outfile_wo_ext + '.zip'

    if os.path.exists(outfile_wo_ext):
        raise FileExistsError(f'for outfile={outfile}, need {outfile_wo_ext} as working space')
    if not exist_ok and os.path.exists(outfile):
        raise FileExistsError(f'outfile {outfile} already exists (use exist_ok=True to overwrite)')


    num_samples = min(pwl_prod.epoch_size, 50 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, model.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)
    hid_acts = torch.zeros((duration+1, num_samples, model.hidden_dim), dtype=torch.double)
    within_dists = [] # each value corresponds to a torch tensor of within dists
    within_means = torch.zeros(duration+1, dtype=torch.double)
    within_stds = torch.zeros(duration+1, dtype=torch.double)
    within_sems = torch.zeros(duration+1, dtype=torch.double)
    across_dists = [] # each value corresponds to a torch tensor of across dists
    across_means = torch.zeros(duration+1, dtype=torch.double)
    across_stds = torch.zeros(duration+1, dtype=torch.double)
    across_sems = torch.zeros(duration+1, dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    def on_hidacts(acts_info: RNNHiddenActivations):
        hidden_acts = acts_info.hidden_acts
        recur_step = acts_info.recur_step

        hid_acts[recur_step, :, :] = hidden_acts.detach()

        within, across = measure_instant(hid_acts[recur_step], sample_labels, pwl_prod.output_dim)
        within_dists.append(within)
        across_dists.append(across)

        within_means[recur_step] = within.mean()
        within_stds[recur_step] = within.std()
        within_sems[recur_step] = within_stds[recur_step] / np.sqrt(num_samples)

        across_means[recur_step] = across.mean()
        across_stds[recur_step] = across.std()
        across_sems[recur_step] = across_stds[recur_step] / np.sqrt(num_samples)

    _dbg(verbose, logger, 'measure_dtt getting raw data')
    model(sample_points, duration, on_hidacts, 1)


    within_col, across_col = 'tab:cyan', 'r'

    fig_mean_with_stddev, ax_mean_with_stddev = plt.subplots()
    fig_mean_with_sem, ax_mean_with_sem = plt.subplots()
    fig_mean_with_scatter, ax_mean_with_scatter = plt.subplots()

    ax_mean_with_stddev.set_title('Distances Through Time (error: 1.96 std dev)')
    ax_mean_with_sem.set_title('Distances Through Time (error: 1.96 sem)')
    ax_mean_with_scatter.set_title('Distances Through Time')

    for ax in (ax_mean_with_stddev, ax_mean_with_sem, ax_mean_with_scatter):
        ax.set_xlabel('Time (recurrent steps occurred)')
        ax.set_ylabel('Distance (euclidean)')

    recur_steps = np.arange(duration+1)
    _dbg(verbose, logger, 'measure_dtt plotting mean_with_stddev')
    ax_mean_with_stddev.errorbar(recur_steps, within_means.numpy(), within_stds.numpy() * 1.96, color=within_col, label='Within')
    ax_mean_with_stddev.errorbar(recur_steps, across_means.numpy(), across_stds.numpy() * 1.96, color=across_col, label='Across')
    _dbg(verbose, logger, 'measure_dtt plotting mean_with_sem')
    ax_mean_with_sem.errorbar(recur_steps, within_means.numpy(), within_sems.numpy() * 1.96, color=within_col, label='Within')
    ax_mean_with_sem.errorbar(recur_steps, across_means.numpy(), across_sems.numpy() * 1.96, color=across_col, label='Across')
    _dbg(verbose, logger, 'measure_dtt plotting mean_with_scatter')
    ax_mean_with_scatter.plot(recur_steps, within_means.numpy(), color=within_col, label='Within')
    ax_mean_with_scatter.plot(recur_steps, across_means.numpy(), color=across_col, label='Across')

    for recur_step in range(duration+1):
        xvals = np.zeros(within_dists[recur_step].shape, dtype='uint8') + recur_step
        ax_mean_with_scatter.scatter(xvals, within_dists[recur_step], 1, within_col, alpha=0.3)
        xvals = np.zeros(across_dists[recur_step].shape, dtype='uint8') + recur_step
        ax_mean_with_scatter.scatter(xvals, across_dists[recur_step], 1, across_col, alpha=0.3)

    _dbg(verbose, logger, 'measure_dtt saving and cleaning up')
    for ax in (ax_mean_with_stddev, ax_mean_with_sem, ax_mean_with_scatter):
        ax.legend()
        ax.set_xticks(recur_steps)

    for fig in (fig_mean_with_stddev, fig_mean_with_sem, fig_mean_with_scatter):
        fig.tight_layout()

    os.makedirs(outfile_wo_ext)

    fig_mean_with_stddev.savefig(os.path.join(outfile_wo_ext, 'mean_with_stddev.png'), transparent=True)
    fig_mean_with_sem.savefig(os.path.join(outfile_wo_ext, 'mean_with_sem.png'), transparent=True)
    fig_mean_with_scatter.savefig(os.path.join(outfile_wo_ext, 'mean_with_scatter.png'), transparent=True)

    plt.close(fig_mean_with_stddev)
    plt.close(fig_mean_with_sem)
    plt.close(fig_mean_with_scatter)

    np.savez(os.path.join(outfile_wo_ext, 'data.npz'),
        sample_points=sample_points.numpy(),
        sample_labels=sample_labels.numpy(),
        hid_acts=hid_acts.numpy()
        )
    np.savez(os.path.join(outfile_wo_ext, 'within.npz'), *tuple(wd.numpy() for wd in within_dists))
    np.savez(os.path.join(outfile_wo_ext, 'across.npz'), *tuple(ad.numpy() for ad in across_dists))
    np.savetxt(os.path.join(outfile_wo_ext, 'within_means.txt'), within_means.numpy())
    np.savetxt(os.path.join(outfile_wo_ext, 'across_means.txt'), across_means.numpy())


    if os.path.exists(outfile):
        os.remove(outfile)

    shutil.make_archive(outfile_wo_ext, 'zip', outfile_wo_ext)
    shutil.rmtree(outfile_wo_ext)
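measure_dtt relies on a measure_instant helper that is not shown here. Assuming it returns the pairwise euclidean distances between same-label points ("within") and different-label points ("across"), a minimal sketch is:

import typing

import torch


def measure_instant(hid_acts: torch.Tensor, labels: torch.Tensor,
                    num_labels: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Sketch: pairwise euclidean distances within and across labels, each
    pair counted once. num_labels matches the call sites but is unused here."""
    num = hid_acts.shape[0]
    dists = torch.cdist(hid_acts, hid_acts)            # [num, num] distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # [num, num] same-label mask
    upper = torch.triu(torch.ones(num, num), diagonal=1).bool()  # each pair once
    return dists[same & upper], dists[(~same) & upper]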
Example 11
def measure_dtt_ff(model: FeedforwardNetwork, pwl_prod: PointWithLabelProducer,
                   outfile: str, exist_ok: bool = False,
                   logger: typing.Optional[logging.Logger] = None,
                   verbose: bool = False) -> None:
    """Analogue to measure_dtt for feed-forward networks"""
    if not isinstance(model, FeedforwardNetwork):
        raise ValueError(f'expected model is FeedforwardNetwork, got {model} (type={type(model)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')
    if not isinstance(outfile, str):
        raise ValueError(f'expected outfile is str, got {outfile}')
    if not isinstance(exist_ok, bool):
        raise ValueError(f'expected exist_ok is bool, got {exist_ok}')
    if logger is not None and not isinstance(logger, logging.Logger):
        raise ValueError(f'expected logger is optional[logging.Logger], got {logger} (type={type(logger)})')
    if not isinstance(verbose, bool):
        raise ValueError(f'expected verbose is bool, got {verbose} (type={type(verbose)})')

    outfile_wo_ext = os.path.splitext(outfile)[0]
    if outfile_wo_ext == outfile:
        outfile = outfile_wo_ext + '.zip'

    if os.path.exists(outfile_wo_ext):
        raise FileExistsError(f'for outfile={outfile}, need {outfile_wo_ext} as working space')
    if not exist_ok and os.path.exists(outfile):
        raise FileExistsError(f'outfile {outfile} already exists (use exist_ok=True to overwrite)')

    num_samples = min(pwl_prod.epoch_size, 50 * pwl_prod.output_dim)

    sample_points = torch.zeros((num_samples, model.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)
    hid_acts = [] # each entry will be a 2d tensor of activations for one layer
    within_dists = [] # each value corresponds to a torch tensor of within dists
    within_means = torch.zeros(model.num_layers+1, dtype=torch.double)
    within_stds = torch.zeros(model.num_layers+1, dtype=torch.double)
    within_sems = torch.zeros(model.num_layers+1, dtype=torch.double)
    across_dists = [] # each value corresponds to a torch tensor of across dists
    across_means = torch.zeros(model.num_layers+1, dtype=torch.double)
    across_stds = torch.zeros(model.num_layers+1, dtype=torch.double)
    across_sems = torch.zeros(model.num_layers+1, dtype=torch.double)

    pwl_prod.mark()
    pwl_prod.fill(sample_points, sample_labels)
    pwl_prod.reset()

    def on_hidacts(acts_info: FFHiddenActivations):
        hidden_acts = acts_info.hidden_acts
        layer = acts_info.layer

        hid_acts.append(hidden_acts.detach())

        within, across = measure_instant(hid_acts[layer], sample_labels, pwl_prod.output_dim)
        within_dists.append(within)
        across_dists.append(across)

        within_means[layer] = within.mean()
        within_stds[layer] = within.std()
        within_sems[layer] = within_stds[layer] / np.sqrt(num_samples)

        across_means[layer] = across.mean()
        across_stds[layer] = across.std()
        across_sems[layer] = across_stds[layer] / np.sqrt(num_samples)

    _dbg(verbose, logger, 'measure_dtt_ff getting raw data')
    model(sample_points, on_hidacts)

    layers = np.arange(model.num_layers+1)

    _plot_dtt_ff(layers, within_means, within_stds, within_sems,
                 across_means, across_stds, across_sems,
                 within_dists, across_dists, outfile_wo_ext,
                 verbose, logger)
    _save_dtt_ff(sample_points, sample_labels, hid_acts,
                 within_dists, across_dists, outfile_wo_ext)

    if os.path.exists(outfile):
        os.remove(outfile)

    zipdir(outfile_wo_ext)
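measure_dtt_ff delegates archiving to a zipdir helper that is not shown. A minimal sketch consistent with the inline archive-and-clean-up step at the end of measure_dtt (the exact implementation is an assumption):

import shutil


def zipdir(dirpath: str) -> None:
    """Sketch: archives dirpath into dirpath + '.zip' and removes the working
    directory, mirroring the inline logic at the end of measure_dtt."""
    shutil.make_archive(dirpath, 'zip', dirpath)
    shutil.rmtree(dirpath)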