def train_svms(pwl_prod: PointWithLabelProducer, hidden_states: typing.Callable) -> SVMTrajectory:
    """Trains one svm per layer of hidden activations.

    The same routine serves feed-forward and recurrent networks alike:
    hidden_states is any callable that takes a points tensor of shape
    [batch_size, input_dim] and returns an iterable of tensors shaped
    [batch_size, layer_size]. Layer sizes may differ from one another, but
    every call to hidden_states must produce the same number of layers in the
    same order.

    Args:
        pwl_prod (PointWithLabelProducer): the problem to show to the network
        hidden_states (typing.Callable): produces an iterable of hidden states
            when given a points tensor

    Returns:
        traj (SVMTrajectory): the svm accuracy through time/layers
    """
    sample_count = min(pwl_prod.output_dim * 150, pwl_prod.epoch_size)
    points = torch.zeros((sample_count, pwl_prod.input_dim), dtype=torch.double)
    labels = torch.zeros(sample_count, dtype=torch.long)
    pwl_prod.fill(points, labels)

    # materialize every layer once so train_svms_with receives them positionally
    layer_acts = tuple(hidden_states(points))
    return train_svms_with(points, labels, *layer_acts)
def find_trajectory(model: FeedforwardNetwork, pwl_prod: PointWithLabelProducer,
                    num_pcs: int) -> PCTrajectoryFF:
    """Finds the pc trajectory for the given feed-forward model.

    Samples min(pwl_prod.epoch_size, 200 * pwl_prod.output_dim) points, runs
    them through the model and, for every layer reported to the callback,
    keeps only the top num_pcs principal components.

    Args:
        model (FeedforwardNetwork): the model to get the pc trajectory of
        pwl_prod (PointWithLabelProducer): the points to pass through
        num_pcs (int): the number of pcs to get

    Returns:
        PCTrajectoryFF: the pc trajectory of the network for the sampled points

    Raises:
        ValueError: if any argument has the wrong type or num_pcs is not positive
    """
    if not isinstance(model, FeedforwardNetwork):
        raise ValueError(
            f'expected model is FeedforwardNetwork, got {model} (type={type(model)})'
        )
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(
            f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})'
        )
    if not isinstance(num_pcs, int):
        raise ValueError(f'expected num_pcs is int, got {num_pcs}')
    if num_pcs <= 0:
        raise ValueError(f'expected num_pcs is positive, got {num_pcs}')

    n_samples = min(pwl_prod.epoch_size, 200 * pwl_prod.output_dim)
    points = torch.zeros((n_samples, model.input_dim), dtype=torch.double)
    labels = torch.zeros((n_samples,), dtype=torch.long)
    pwl_prod.fill(points, labels)

    collected = []  # one PCTrajectoryFFSnapshot per callback invocation, in order

    def _record_layer(acts_info: FFHiddenActivations):
        acts = acts_info.hidden_acts.detach()
        vals, vecs = get_hidden_pcs(acts, num_pcs)
        projection = project_to_pcs(acts, vecs, out=None)
        collected.append(PCTrajectoryFFSnapshot(vecs, vals, projection, labels))

    model(points, _record_layer)
    return PCTrajectoryFF(collected)
def measure_pr_gen(network: typing.Union[FeedforwardNetwork, NaturalRNN],
                   pwl_prod: PointWithLabelProducer,
                   points_dtype: typing.Any = torch.float,
                   out_dtype: typing.Any = torch.float,
                   recur_times: int = 10,
                   squeeze_to_pwl: bool = False) -> PRTrajectory:
    """Measures the participation ratio trajectory for the given generalized
    network, not assuming that the output dimension is a label. The resulting
    trajectory does not have by_label set.

    Args:
        network (FeedforwardNetwork or NaturalRNN): the network to measure
        pwl_prod (PointWithLabelProducer): where to sample at most 2000 points from
        points_dtype (torch dtype, optional): dtype of the sampled points; also
            used for the resulting participation ratio tensor. Default torch.float
        out_dtype (torch dtype, optional): dtype of the output tensor filled by
            pwl_prod. Default torch.float
        recur_times (int, optional): number of recurrent steps; only used when
            network is a NaturalRNN. Default 10
        squeeze_to_pwl (bool, optional): if True, the output dimension of the
            (num_samples, output_dim) output tensor is squeezed away before it
            is filled by pwl_prod (useful when output_dim is 1). Default False

    Returns:
        PRTrajectory: the participation ratio per layer (feed-forward) or per
            recurrent step (recurrent). layers is True only for feed-forward
            networks.
    """
    tus.check(
        network=(network, (FeedforwardNetwork, NaturalRNN)),
        pwl_prod=(pwl_prod, PointWithLabelProducer),
        recur_times=(recur_times, int),
    )
    is_recurrent = isinstance(network, NaturalRNN)

    num_samples = min(pwl_prod.epoch_size, 2000)
    points = torch.zeros((num_samples, pwl_prod.input_dim), dtype=points_dtype)
    out = torch.zeros((num_samples, network.output_dim), dtype=out_dtype)
    # squeeze only the output axis so a single-sample batch keeps its batch axis
    pwl_prod.fill(points, out.squeeze(1) if squeeze_to_pwl else out)

    num_layers = recur_times if is_recurrent else network.num_layers
    pr_overall = torch.zeros(num_layers + 1, dtype=points_dtype)

    def on_hidacts_raw(hid_acts: torch.tensor, layer: int):
        pr_overall[layer] = measure_pr(hid_acts)

    if is_recurrent:
        def on_hidacts_rnn(hacts: RNNHiddenActivations):
            on_hidacts_raw(hacts.hidden_acts.detach(), hacts.recur_step)
        network(points, recur_times, on_hidacts_rnn, 1)
    else:
        def on_hidacts_ff(facts: FFHiddenActivations):
            on_hidacts_raw(facts.hidden_acts.detach(), facts.layer)
        network(points, on_hidacts_ff)

    # A recurrent trajectory is indexed by time rather than by layer, which
    # measure_pr_rnn reports as layers=False; feed-forward uses layers=True.
    return PRTrajectory(overall=pr_overall, by_label=None, layers=not is_recurrent)
def get_hidacts_rnn(network: NaturalRNN, pwl: PointWithLabelProducer, recur_times: int,
                    num_points: typing.Optional[int] = None) -> NetworkHiddenActivations:
    """Gets the hidden activations for the given recurrent network, acquiring
    at most num_points from the producer. pwl.mark() / pwl.reset() surround
    the fill so the producer is rewound afterwards.

    Args:
        network (NaturalRNN): the network to forward through
        pwl (PointWithLabelProducer): the producer for points to run through
        recur_times (int): number of recurrent steps, forwarded to
            get_hidacts_rnn_with_sample
        num_points (typing.Optional[int], optional): the maximum number of
            points to run through the network. Defaults to None, meaning
            50 * network.output_dim. Clipped to pwl.epoch_size

    Returns:
        NetworkHiddenActivations: the internal activations of the network

    Raises:
        ValueError: if any argument has the wrong type
    """
    # validate types before reading attributes off the network
    if not isinstance(network, NaturalRNN):
        raise ValueError(f'expected network is NaturalRNN, got {network} (type={type(network)})')
    if not isinstance(pwl, PointWithLabelProducer):
        raise ValueError(f'expected pwl is PointWithLabelProducer, got {pwl} (type={type(pwl)})')
    if num_points is None:
        num_points = 50 * network.output_dim
    if not isinstance(num_points, int):
        raise ValueError(f'expected num_points is int, got {num_points} (type={type(num_points)})')
    if not isinstance(recur_times, int):
        raise ValueError(f'expected recur_times is int, got {recur_times} (type={type(recur_times)})')

    # never request more points than one epoch holds (matches get_hidacts_ff)
    num_points = min(num_points, pwl.epoch_size)

    sample_points = torch.zeros((num_points, pwl.input_dim), dtype=torch.double)
    sample_labels = torch.zeros(num_points, dtype=torch.int)
    pwl.mark()
    pwl.fill(sample_points, sample_labels)
    pwl.reset()
    return get_hidacts_rnn_with_sample(network, sample_points, sample_labels, recur_times)
def get_hidacts_ff(network: FeedforwardNetwork, pwl: PointWithLabelProducer,
                   num_points: typing.Optional[int] = None) -> NetworkHiddenActivations:
    """Creates a sample of at most num_points from the given point with label
    producer without affecting its internal state (pwl.mark() / pwl.reset()
    surround the fill), runs those through the network and returns the hidden
    activations that came out of that.

    Arguments:
        network (FeedforwardNetwork): the network which the points should be run through
        pwl (PointWithLabelProducer): where to acquire the points
        num_points (int, optional): the maximum number of points to fetch.
            Defaults to 50 * network.output_dim. Clipped to pwl.epoch_size

    Returns:
        NetworkHiddenActivations: the internal activations of the network

    Raises:
        ValueError: if any argument has the wrong type
    """
    # validate types before reading attributes off the network
    if not isinstance(network, FeedforwardNetwork):
        raise ValueError(f'expected network is FeedforwardNetwork, got {network} (type={type(network)})')
    if not isinstance(pwl, PointWithLabelProducer):
        raise ValueError(f'expected pwl is PointWithLabelProducer, got {pwl} (type={type(pwl)})')
    if num_points is None:
        num_points = 50 * network.output_dim
    if not isinstance(num_points, int):
        raise ValueError(f'expected num_points is int, got {num_points} (type={type(num_points)})')

    num_points = min(num_points, pwl.epoch_size)

    sample_points = torch.zeros((num_points, pwl.input_dim), dtype=torch.double)
    sample_labels = torch.zeros(num_points, dtype=torch.int)
    pwl.mark()
    pwl.fill(sample_points, sample_labels)
    pwl.reset()
    return get_hidacts_ff_with_sample(network, sample_points, sample_labels)
def measure_pr_rnn(network: NaturalRNN, teacher: RNNTeacher,
                   pwl_prod: PointWithLabelProducer) -> PRTrajectory:
    """Measures the participation ratio through time for the given recurrent network.

    Args:
        network (NaturalRNN): the network to measure
        teacher (RNNTeacher): the teacher for the network; its recurrent_times
            sets how many recurrent steps are measured
        pwl_prod (PointWithLabelProducer): the point with label producer to generate points with

    Returns:
        PRTrajectory: the participation ratio trajectory for the network,
            overall and per label, indexed by recurrent step (layers=False)

    Raises:
        ValueError: if network or pwl_prod has the wrong type
    """
    if not isinstance(network, NaturalRNN):
        raise ValueError(f'expected network is a NaturalRNN, got {network} (type={type(network)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')

    recur_times = teacher.recurrent_times
    num_labels = pwl_prod.output_dim
    num_samples = min(pwl_prod.epoch_size, 100 * num_labels)
    sample_points = torch.zeros((num_samples, network.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)
    pr_overall = torch.zeros(recur_times + 1, dtype=torch.double)
    pr_by_label = torch.zeros((num_labels, recur_times + 1), dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)
    masks_by_lbl = [sample_labels == lbl for lbl in range(num_labels)]

    def on_hidacts(acts_info: RNNHiddenActivations):
        hid_acts = acts_info.hidden_acts.detach()
        # recurrent activations are indexed by recur_step, not layer
        recur_step = acts_info.recur_step
        pr_overall[recur_step] = measure_pr(hid_acts)
        for lbl, mask in enumerate(masks_by_lbl):
            pr_by_label[lbl, recur_step] = measure_pr(hid_acts[mask])

    # NaturalRNN is called as (points, recur_times, callback, 1) elsewhere in
    # this file; previously the callback was passed where recur_times belongs.
    # NOTE(review): the trailing 1 presumably asks for a callback every step — confirm.
    network(sample_points, recur_times, on_hidacts, 1)
    return PRTrajectory(overall=pr_overall, by_label=pr_by_label, layers=False)
def measure_pr_ff(network: FeedforwardNetwork, pwl_prod: PointWithLabelProducer) -> PRTrajectory:
    """Measures the participation ratio through the layers of a feedforward network.

    Samples min(pwl_prod.epoch_size, 100 * pwl_prod.output_dim) points and
    records the participation ratio of every reported layer, both over all
    samples and restricted to each label.

    Args:
        network (FeedforwardNetwork): The feedforward network to measure pr through layers of
        pwl_prod (PointWithLabelProducer): The pointproducer to sample points from

    Returns:
        traj (PRTrajectory): the trajectory of the networks participation ratio

    Raises:
        ValueError: if network or pwl_prod has the wrong type
    """
    if not isinstance(network, FeedforwardNetwork):
        raise ValueError(f'expected network is FeedforwardNetwork, got {network} (type={type(network)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')

    n_labels = pwl_prod.output_dim
    n_points = min(pwl_prod.epoch_size, 100 * n_labels)
    n_entries = network.num_layers + 1  # input layer plus each hidden/output layer

    points = torch.zeros((n_points, network.input_dim), dtype=torch.double)
    labels = torch.zeros((n_points,), dtype=torch.long)
    overall = torch.zeros(n_entries, dtype=torch.double)
    per_label = torch.zeros((n_labels, n_entries), dtype=torch.double)

    pwl_prod.fill(points, labels)
    label_masks = [labels == lbl for lbl in range(n_labels)]

    def _on_layer(acts_info: FFHiddenActivations):
        acts = acts_info.hidden_acts.detach()
        idx = acts_info.layer
        overall[idx] = measure_pr(acts)
        for lbl, mask in enumerate(label_masks):
            per_label[lbl, idx] = measure_pr(acts[mask])

    network(points, _on_layer)
    return PRTrajectory(overall=overall, by_label=per_label, layers=True)
def find_trajectory(model: Network, pwl_prod: PointWithLabelProducer, num_pcs: int,
                    recur_times: int = 10, points_dtype: typing.Any = torch.float,
                    out_dtype: typing.Any = torch.float,
                    squeeze_to_pwl: bool = True) -> PCTrajectoryGen:
    """Finds the PC trajectory for the given network, sampling a reasonable
    number of points (at most 2000) from the dataset. The resulting labels
    will simply be the shape of the output layer for the network.

    Layers narrower than num_pcs are zero-padded so every snapshot has exactly
    num_pcs components.

    Args:
        model (Network): the network whose activations will be fetched
        pwl_prod (PointWithLabelProducer): the dataset to sample inputs from
        num_pcs (int): the number of principal components to project onto
        recur_times (int, optional): if the network is recurrent, this is how
            many timesteps to get the trajectory on. Default 10.
        points_dtype (torch dtype, optional): the data type for points.
            Default torch.float.
        out_dtype (torch dtype, optional): the data type to use for the
            output tensor. Default torch.float.
        squeeze_to_pwl (bool, optional): If the output dimension is squeezed
            away before the output is filled by the point with label producer.
            Helpful if the output is a single scalar and the pwl expects a
            tensor with shape (num_samples). Default True

    Returns:
        PCTrajectoryGen: one snapshot per layer / recurrent step

    Raises:
        ValueError: if num_pcs is not positive
    """
    tus.check(
        model=(model, Network),
        pwl_prod=(pwl_prod, PointWithLabelProducer),
        num_pcs=(num_pcs, int),
        recur_times=(recur_times, int)
    )
    if num_pcs <= 0:
        raise ValueError(f'expected num_pcs is positive, got {num_pcs}')

    num_samples = min(pwl_prod.epoch_size, 2000)
    points = torch.zeros((num_samples, pwl_prod.input_dim), dtype=points_dtype)
    out = torch.zeros((num_samples, model.output_dim), dtype=out_dtype)
    # squeeze only the output axis so a single-sample batch keeps its batch axis
    pwl_prod.fill(points, out.squeeze(1) if squeeze_to_pwl else out)

    snapshots = []

    def on_hidacts_raw(hid_acts: torch.tensor):
        num_points, layer_size = hid_acts.shape[0], hid_acts.shape[1]
        if layer_size >= num_pcs:
            pc_vals, pc_vecs = pca.get_hidden_pcs(hid_acts, num_pcs)
            projected = pca.project_to_pcs(hid_acts, pc_vecs, out=None)
        elif layer_size > 1:
            # Fewer dimensions than requested pcs: take all of them, then
            # zero-pad up to num_pcs.
            avail_vals, avail_vecs = pca.get_hidden_pcs(hid_acts, layer_size)
            avail_proj = pca.project_to_pcs(hid_acts, avail_vecs, out=None)
            pc_vecs = torch.zeros((num_pcs, layer_size), dtype=avail_vecs.dtype)
            pc_vecs[:avail_vecs.shape[0]] = avail_vecs
            pc_vals = torch.zeros((num_pcs,), dtype=avail_vals.dtype)
            pc_vals[:avail_vals.shape[0]] = avail_vals
            projected = torch.zeros((num_points, num_pcs), dtype=avail_proj.dtype)
            # pad along the pc (column) axis; the sample axis is already full
            projected[:, :avail_proj.shape[1]] = avail_proj
        else:
            # A single unit: its only principal direction is the unit itself.
            pc_vecs = torch.zeros((num_pcs, layer_size), dtype=hid_acts.dtype)
            pc_vals = torch.zeros((num_pcs,), dtype=hid_acts.dtype)
            pc_vals[0] = 1
            pc_vecs[0, 0] = 1
            projected = torch.zeros((num_points, num_pcs), dtype=hid_acts.dtype)
            projected[:, 0] = hid_acts[:, 0]
        snapshots.append(PCTrajectoryGenSnapshot(pc_vecs, pc_vals, projected, out))

    if isinstance(model, NaturalRNN):
        def on_hidacts_rnn(hacts: RNNHiddenActivations):
            on_hidacts_raw(hacts.hidden_acts.detach())
        model(points, recur_times, on_hidacts_rnn, 1)
    else:
        def on_hidacts_ff(facts: FFHiddenActivations):
            on_hidacts_raw(facts.hidden_acts.detach())
        model(points, on_hidacts_ff)
    return PCTrajectoryGen(snapshots)
def find_trajectory(model: NaturalRNN, pwl_prod: PointWithLabelProducer, duration: int,
                    num_pcs: int) -> PCTrajectory:
    """Finds the pc trajectory of the given recurrent model through time.

    Samples min(pwl_prod.epoch_size, 100 * pwl_prod.output_dim) points from the
    producer and runs them through the model for duration recurrent steps. At
    every step (including step 0) it finds the top num_pcs principal components
    of the hidden activations and projects the samples onto them.

    Args:
        model (NaturalRNN): The underlying model whose trajectories are being considered
        pwl_prod (PointWithLabelProducer): The producer for the samples
        duration (int): How many timesteps to go through
        num_pcs (int): The number of principal vectors to find

    Returns:
        PCTrajectory: the pc vectors and values for every step, the projected
            samples, the sample labels and the duration
    """
    num_labels = pwl_prod.output_dim
    num_samples = min(pwl_prod.epoch_size, 100 * num_labels)
    sample_points = torch.zeros((num_samples, model.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples, ), dtype=torch.long)
    hid_acts = torch.zeros((duration + 1, num_samples, model.hidden_dim), dtype=torch.double)
    hid_pc_vals = torch.zeros((duration + 1, num_pcs), dtype=torch.double)
    hid_pc_vecs = torch.zeros((duration + 1, num_pcs, model.hidden_dim), dtype=torch.double)
    proj_samples = torch.zeros((duration + 1, num_samples, num_pcs), dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    def on_hidacts(acts_info: RNNHiddenActivations):
        hidden_acts = acts_info.hidden_acts
        recur_step = acts_info.recur_step
        hid_acts[recur_step, :, :] = hidden_acts.detach()
        pc_vals, pc_vecs = get_hidden_pcs(hid_acts[recur_step], num_pcs)
        hid_pc_vals[recur_step, :] = pc_vals
        hid_pc_vecs[recur_step, :, :] = pc_vecs

    model(sample_points, duration, on_hidacts, 1)

    for recur_step in range(duration + 1):
        project_to_pcs(hid_acts[recur_step], hid_pc_vecs[recur_step], out=proj_samples[recur_step])

    # We are free to flip the sign of any pc vector. The loop below flips them
    # so that, for each pair of labels, the order of the label means along each
    # pc matches the order at step 0. This is only fully achievable with 2
    # labels, because flipping a pc reverses the order for ALL pairs at once.
    # With more labels we ensure that at least 50% of the label pairs keep
    # their step-0 order.
    indices_by_label = {lbl: sample_labels == lbl for lbl in range(num_labels)}
    means_by_label_and_recur = {
        (lbl, recur_step): proj_samples[recur_step, mask, :].mean(0)
        for lbl, mask in indices_by_label.items()
        for recur_step in range(duration + 1)
    }

    # step 0 is the reference, so it is never flipped
    for recur_step in range(1, duration + 1):
        for pc in range(num_pcs):  # pylint: disable=invalid-name
            badness = 0
            counter = 0
            for lbl1 in range(num_labels):
                for lbl2 in range(lbl1 + 1, num_labels):
                    used_to_be_lt = (means_by_label_and_recur[(lbl1, 0)][pc]
                                     < means_by_label_and_recur[(lbl2, 0)][pc])
                    curr_is_lt = (means_by_label_and_recur[(lbl1, recur_step)][pc]
                                  < means_by_label_and_recur[(lbl2, recur_step)][pc])
                    counter += 1
                    if used_to_be_lt != curr_is_lt:
                        badness += 1
            # With fewer than two labels there are no pairs to compare. Leave
            # the pc alone instead of flipping it unconditionally.
            if counter > 0 and badness >= counter / 2:
                hid_pc_vecs[recur_step, pc, :] *= -1

    # re-project with the (possibly) flipped pc vectors
    for recur_step in range(duration + 1):
        project_to_pcs(hid_acts[recur_step], hid_pc_vecs[recur_step], out=proj_samples[recur_step])

    return PCTrajectory(hid_pc_vecs, hid_pc_vals, proj_samples, sample_labels, duration)
def measure_dtt(model: NaturalRNN, pwl_prod: PointWithLabelProducer, duration: int,
                outfile: str, exist_ok: bool = False, logger: logging.Logger = None,
                verbose: bool = False) -> None:
    """Measures the distance of points in hidden activation space from points which are
    sampled from different labels. For example, points from the same label might be
    close in hidden activation space even when they are far in input space.

    Three plots are written (mean with 1.96 std dev, mean with 1.96 sem, mean
    with scatter), together with the raw data. Everything is zipped into outfile.

    Args:
        model (NaturalRNN): the model
        pwl_prod (PointWithLabelProducer): the points to test
        duration (int): the number of recurrent times
        outfile (str): where to save the result. Should have extension '.zip' or no
            extension at all
        exist_ok (bool, default false): true if we should overwrite the outfile if it
            already exists, false to check if it exists and error if it does
        logger (logging.Logger, optional): forwarded to _dbg for progress messages
        verbose (bool, default false): forwarded to _dbg alongside logger

    Raises:
        ValueError: if model, pwl_prod or outfile has the wrong type
        FileExistsError: if the working directory (outfile without extension)
            exists, or outfile exists and exist_ok is False
    """
    if not isinstance(model, NaturalRNN):
        raise ValueError(f'expected model is NaturalRNN, got {model} (type={type(model)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl is PointWithLabelProducer, got {pwl_prod} (type=({type(pwl_prod)})')
    if not isinstance(outfile, str):
        raise ValueError(f'expected outfile is str, got {outfile} (type={type(outfile)})')

    outfile_wo_ext = os.path.splitext(outfile)[0]
    if outfile_wo_ext == outfile:
        outfile = outfile_wo_ext + '.zip'
    if os.path.exists(outfile_wo_ext):
        raise FileExistsError(f'for outfile={outfile}, need {outfile_wo_ext} as working space')
    if not exist_ok and os.path.exists(outfile):
        raise FileExistsError(f'outfile {outfile} already exists (use exist_ok=True) to overwrite')

    num_samples = min(pwl_prod.epoch_size, 50 * pwl_prod.output_dim)
    sample_points = torch.zeros((num_samples, model.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)
    hid_acts = torch.zeros((duration + 1, num_samples, model.hidden_dim), dtype=torch.double)

    within_dists = []  # each value corresponds to a torch tensor of within dists
    within_means = torch.zeros(duration + 1, dtype=torch.double)
    within_stds = torch.zeros(duration + 1, dtype=torch.double)
    within_sems = torch.zeros(duration + 1, dtype=torch.double)
    across_dists = []  # each value corresponds to a torch tensor of across dists
    across_means = torch.zeros(duration + 1, dtype=torch.double)
    across_stds = torch.zeros(duration + 1, dtype=torch.double)
    across_sems = torch.zeros(duration + 1, dtype=torch.double)

    pwl_prod.fill(sample_points, sample_labels)

    def on_hidacts(acts_info: RNNHiddenActivations):
        hidden_acts = acts_info.hidden_acts
        recur_step = acts_info.recur_step
        hid_acts[recur_step, :, :] = hidden_acts.detach()
        within, across = measure_instant(hid_acts[recur_step], sample_labels, pwl_prod.output_dim)
        # within_dists / across_dists are later indexed by recur_step, which
        # assumes the callback fires in step order
        within_dists.append(within)
        across_dists.append(across)
        within_means[recur_step] = within.mean()
        within_stds[recur_step] = within.std()
        # NOTE(review): sem divides by sqrt(num_samples), not by the number of
        # distances in `within` — confirm this is intended
        within_sems[recur_step] = within_stds[recur_step] / np.sqrt(num_samples)
        across_means[recur_step] = across.mean()
        across_stds[recur_step] = across.std()
        across_sems[recur_step] = across_stds[recur_step] / np.sqrt(num_samples)

    _dbg(verbose, logger, 'measure_dtt getting raw data')
    model(sample_points, duration, on_hidacts, 1)

    within_col, across_col = 'tab:cyan', 'r'
    fig_mean_with_stddev, ax_mean_with_stddev = plt.subplots()
    fig_mean_with_sem, ax_mean_with_sem = plt.subplots()
    fig_mean_with_scatter, ax_mean_with_scatter = plt.subplots()

    ax_mean_with_stddev.set_title('Distances Through Time (error: 1.96 std dev)')
    ax_mean_with_sem.set_title('Distances Through Time (error: 1.96 sem)')
    ax_mean_with_scatter.set_title('Distances Through Time')

    for ax in (ax_mean_with_stddev, ax_mean_with_sem, ax_mean_with_scatter):
        ax.set_xlabel('Time (recurrent steps occurred)')
        ax.set_ylabel('Distance (euclidean)')

    recur_steps = np.arange(duration + 1)

    _dbg(verbose, logger, 'measure_dtt plotting mean_with_stddev')
    ax_mean_with_stddev.errorbar(recur_steps, within_means.numpy(), within_stds.numpy() * 1.96,
                                 color=within_col, label='Within')
    ax_mean_with_stddev.errorbar(recur_steps, across_means.numpy(), across_stds.numpy() * 1.96,
                                 color=across_col, label='Across')

    _dbg(verbose, logger, 'measure_dtt plotting mean_with_sem')
    ax_mean_with_sem.errorbar(recur_steps, within_means.numpy(), within_sems.numpy() * 1.96,
                              color=within_col, label='Within')
    ax_mean_with_sem.errorbar(recur_steps, across_means.numpy(), across_sems.numpy() * 1.96,
                              color=across_col, label='Across')

    _dbg(verbose, logger, 'measure_dtt plotting mean_with_scatter')
    ax_mean_with_scatter.plot(recur_steps, within_means.numpy(), color=within_col, label='Within')
    ax_mean_with_scatter.plot(recur_steps, across_means.numpy(), color=across_col, label='Across')
    for recur_step in range(duration + 1):
        # np.full uses a wide integer dtype; the previous uint8 array + step
        # overflowed (or raised under NumPy 2 promotion) once step > 255
        within_step = within_dists[recur_step]
        ax_mean_with_scatter.scatter(np.full(within_step.shape, recur_step), within_step,
                                     1, within_col, alpha=0.3)
        across_step = across_dists[recur_step]
        ax_mean_with_scatter.scatter(np.full(across_step.shape, recur_step), across_step,
                                     1, across_col, alpha=0.3)

    _dbg(verbose, logger, 'measure_dtt saving and cleaning up')
    for ax in (ax_mean_with_stddev, ax_mean_with_sem, ax_mean_with_scatter):
        ax.legend()
        ax.set_xticks(recur_steps)

    for fig in (fig_mean_with_stddev, fig_mean_with_sem, fig_mean_with_scatter):
        fig.tight_layout()

    os.makedirs(outfile_wo_ext)
    fig_mean_with_stddev.savefig(os.path.join(outfile_wo_ext, 'mean_with_stddev.png'), transparent=True)
    fig_mean_with_sem.savefig(os.path.join(outfile_wo_ext, 'mean_with_sem.png'), transparent=True)
    fig_mean_with_scatter.savefig(os.path.join(outfile_wo_ext, 'mean_with_scatter.png'), transparent=True)

    plt.close(fig_mean_with_stddev)
    plt.close(fig_mean_with_sem)
    plt.close(fig_mean_with_scatter)

    np.savez(os.path.join(outfile_wo_ext, 'data.npz'),
             sample_points=sample_points.numpy(),
             sample_labels=sample_labels.numpy(),
             hid_acts=hid_acts.numpy())
    np.savez(os.path.join(outfile_wo_ext, 'within.npz'), *tuple(wd.numpy() for wd in within_dists))
    np.savez(os.path.join(outfile_wo_ext, 'across.npz'), *tuple(ad.numpy() for ad in across_dists))
    np.savetxt(os.path.join(outfile_wo_ext, 'within_means.txt'), within_means.numpy())
    np.savetxt(os.path.join(outfile_wo_ext, 'across_means.txt'), across_means.numpy())

    if os.path.exists(outfile):
        os.remove(outfile)

    # restore the working directory in case make_archive changes it
    cwd = os.getcwd()
    shutil.make_archive(outfile_wo_ext, 'zip', outfile_wo_ext)
    os.chdir(cwd)
    shutil.rmtree(outfile_wo_ext)
def measure_dtt_ff(model: FeedforwardNetwork, pwl_prod: PointWithLabelProducer,
                   outfile: str, exist_ok: bool = False,
                   logger: typing.Optional[logging.Logger] = None,
                   verbose: bool = False) -> None:
    """Analogue to measure_dtt for feed-forward networks.

    Computes within-label and across-label distances in hidden activation
    space for every layer (num_layers + 1 entries). It plots them with
    _plot_dtt_ff, saves the raw data with _save_dtt_ff, and zips the working
    directory with zipdir. pwl_prod.mark() / pwl_prod.reset() surround the
    sampling so the producer is rewound afterwards.

    Args:
        model (FeedforwardNetwork): the model
        pwl_prod (PointWithLabelProducer): the points to test
        outfile (str): where to save the result; '.zip' is appended if it has
            no extension
        exist_ok (bool, default False): overwrite outfile if it already exists
        logger (logging.Logger, optional): forwarded to the debug/plot helpers
        verbose (bool, default False): forwarded to the debug/plot helpers

    Raises:
        ValueError: if any argument has the wrong type
        FileExistsError: if the working directory exists, or outfile exists
            and exist_ok is False
    """
    if not isinstance(model, FeedforwardNetwork):
        raise ValueError(f'expected model is FeedforwardNetwork, got {model} (type={type(model)})')
    if not isinstance(pwl_prod, PointWithLabelProducer):
        raise ValueError(f'expected pwl_prod is PointWithLabelProducer, got {pwl_prod} (type={type(pwl_prod)})')
    if not isinstance(outfile, str):
        raise ValueError(f'expected outfile is str, got {outfile}')
    if not isinstance(exist_ok, bool):
        raise ValueError(f'expected exist_ok is bool, got {exist_ok}')
    if logger is not None and not isinstance(logger, logging.Logger):
        raise ValueError(f'expected logger is optional[logging.Logger], got {logger} (type={type(logger)})')
    if not isinstance(verbose, bool):
        raise ValueError(f'expected verbose is bool, got {verbose} (type={type(verbose)})')

    outfile_wo_ext = os.path.splitext(outfile)[0]
    if outfile_wo_ext == outfile:
        outfile = outfile_wo_ext + '.zip'
    if os.path.exists(outfile_wo_ext):
        raise FileExistsError(f'for outfile={outfile}, need {outfile_wo_ext} as working space')
    if not exist_ok and os.path.exists(outfile):
        raise FileExistsError(f'outfile {outfile} already exists (use exist_ok=True) to overwrite')

    num_samples = min(pwl_prod.epoch_size, 50 * pwl_prod.output_dim)
    n_entries = model.num_layers + 1
    sample_points = torch.zeros((num_samples, model.input_dim), dtype=torch.double)
    sample_labels = torch.zeros((num_samples,), dtype=torch.long)

    hid_acts = []  # one detached 2d tensor per layer, in callback order
    within_dists, across_dists = [], []  # one distance tensor per layer each
    within_means, within_stds, within_sems = tuple(
        torch.zeros(n_entries, dtype=torch.double) for _ in range(3))
    across_means, across_stds, across_sems = tuple(
        torch.zeros(n_entries, dtype=torch.double) for _ in range(3))
    sem_divisor = np.sqrt(num_samples)

    pwl_prod.mark()
    pwl_prod.fill(sample_points, sample_labels)
    pwl_prod.reset()

    def on_hidacts(acts_info: FFHiddenActivations):
        layer = acts_info.layer
        hid_acts.append(acts_info.hidden_acts.detach())
        within, across = measure_instant(hid_acts[layer], sample_labels, pwl_prod.output_dim)
        within_dists.append(within)
        across_dists.append(across)
        summaries = ((within, within_means, within_stds, within_sems),
                     (across, across_means, across_stds, across_sems))
        for dists, means, stds, sems in summaries:
            means[layer] = dists.mean()
            stds[layer] = dists.std()
            sems[layer] = stds[layer] / sem_divisor

    _dbg(verbose, logger, 'measure_dtt_ff getting raw data')
    model(sample_points, on_hidacts)

    layers = np.arange(n_entries)
    _plot_dtt_ff(layers, within_means, within_stds, within_sems,
                 across_means, across_stds, across_sems,
                 within_dists, across_dists, outfile_wo_ext, verbose, logger)
    _save_dtt_ff(sample_points, sample_labels, hid_acts, within_dists, across_dists,
                 outfile_wo_ext)

    if os.path.exists(outfile):
        os.remove(outfile)
    zipdir(outfile_wo_ext)