def __init__(self, activation_func, optimizer, lr: float, title: str,
             input_size: int, hidden_size: int, device: str,
             deterministic=True, weights_init=functions.init_params,
             prevbatch=False, conv=False, seed=0):
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    ModelState.__init__(
        self,
        Network(input_size, hidden_size, activation_func,
                weights_init=weights_init, prevbatch=prevbatch,
                conv=conv).to(device),
        optimizer,
        lr,
        title,
        {"train loss": np.zeros(0), "test loss": np.zeros(0)},
        device)
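# Usage sketch (an assumption, not part of the repo): how the constructor above
# might be called for 784-dimensional MNIST frames. `RNNModelState` stands in
# for whatever ModelState subclass defines this __init__; the ReLU activation
# and Adam optimizer are likewise illustrative choices, not the repo's defaults.
def _make_example_model_state(device: str = "cpu"):
    return RNNModelState(                 # hypothetical class name
        activation_func=torch.relu,       # assumed activation
        optimizer=torch.optim.Adam,       # assumed optimizer class
        lr=1e-4,
        title="example-rnn",
        input_size=784,                   # one flattened MNIST frame
        hidden_size=1024,
        device=device,
    )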
def train_batch(ms: ModelState,
                batch: torch.FloatTensor,
                loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor],
                state):
    """Run one batch through the model, update parameters, and return
    (loss value, results, new state)."""
    loss, res, state = ms.run(batch, loss_fn, state)
    ms.step(loss)      # backward pass + optimizer step
    ms.zero_grad()     # reset gradients for the next batch
    return loss.item(), res, state
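# Sketch of how train_batch composes into an epoch (illustrative only; the
# repo's own train_epoch, called by train() below, is the authoritative
# version). The (nb, time, batch, features) layout of `batches` is assumed
# from Dataset.create_batches as used elsewhere in this file.
def _example_epoch(ms: ModelState, batches: torch.FloatTensor, loss_fn) -> float:
    state = ms.model.init_state(batches.shape[2])  # assumed batch dimension
    total = 0.0
    for batch in batches:
        loss, _, state = train_batch(ms, batch, loss_fn, state)
        total += loss
    return total / len(batches)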
def example_sequence_state(net: ModelState, dataset: Dataset, latent=False,
                           seed=2553, save=True):
    """Visualises input and internal drive for a sample sequence."""
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    batches, _ = dataset.create_batches(batch_size=1, sequence_length=10,
                                        shuffle=False)
    seq = batches[0, :, :, :]
    input_size = seq.shape[-1]  # make sure we only visualize input units and no latent resources

    X = []  # input frames
    P = []  # internal drive onto the input units
    L = []  # internal drive onto the latent units
    h = net.model.init_state(seq.shape[1])
    for x in seq:
        p = net.predict(h, latent)
        h, l_a = net.model(x, state=h)

        X.append(x[:, :input_size].mean(dim=0).detach().cpu())
        P.append(p[:, :input_size].mean(dim=0).detach().cpu())
        if latent:  # look at latent unit drive
            L.append(p[:, input_size:].mean(dim=0).detach().cpu())

    if latent:
        fig, axes = display(X + P + L, shape=(10, 3), figsize=(3, 3),
                            axes_visible=False, layout='tight')
    else:
        fig, axes = display(X + P, shape=(10, 2), figsize=(3, 3),
                            axes_visible=False, layout='tight')

    if save:
        save_fig(fig, "example_sequence_state", bbox_inches='tight')
def pred_after_timestep(net: ModelState, dataset: Dataset, mask=None,
                        seed=2553, save=True):
    """Visualises internal drive after 0-9 preceding frames."""
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    imgs = []
    for d in range(10):
        # Column 0: drive from an empty (zero) state; columns 1-9: drive after
        # i preceding frames of a sequence ending in digit d.
        imgs = imgs + [net.predict(torch.zeros(1, 784))] + [
            net.predict(
                helper._run_seq_from_digit(d, i, net, dataset, mask=mask)
            ).mean(dim=0)
            for i in range(1, 10)
        ]

    fig, axes = display(imgs, shape=(10, 10), axes_visible=False)
    for i in range(10):
        axes[i][0].set_ylabel(str(i) + ":", ha='center', fontsize=60,
                              rotation=0, labelpad=30)
    fig.tight_layout()

    if save:
        save_fig(fig,
                 net.title + "/pred-after-timestep" +
                 ("-lesioned" if mask is not None else ""),
                 bbox_inches='tight')
def train(ms: ModelState,
          train_ds: Dataset,
          test_ds: Dataset,
          loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor],
          num_epochs: int = 1,
          batch_size: int = 32,
          sequence_length: int = 3,
          verbose=False):
    for epoch in range(ms.epochs + 1, ms.epochs + 1 + num_epochs):
        print("Epoch {}".format(epoch))

        train_loss, train_res = train_epoch(ms, train_ds, loss_fn,
                                            batch_size, sequence_length,
                                            verbose=verbose)
        test_loss, test_res = test_epoch(ms, test_ds, loss_fn,
                                         batch_size, sequence_length)
        ms.on_results(epoch, train_res, test_res)
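# Usage sketch, assuming `ms`, `train_ds`, and `test_ds` already exist. The L1
# penalty on the prediction error is an illustrative stand-in for the repo's
# actual loss function.
def _example_training_run(ms: ModelState, train_ds: Dataset, test_ds: Dataset):
    def l1_loss(pred: torch.FloatTensor, target: torch.FloatTensor) -> torch.FloatTensor:
        return (pred - target).abs().mean()

    train(ms, train_ds, test_ds, l1_loss,
          num_epochs=10, batch_size=32, sequence_length=10, verbose=True)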
def _run_seq_from_digit(digit, steps, net: ModelState, dataset: Dataset, mask=None):
    """Run sequences with the same starting digit through a model and return
    the hidden state.

    Parameters:
        digit: the last digit in the sequence
        steps: sequence length, or steps before the sequence gets to the 'digit'
        net: model
        dataset: dataset to use
        mask: can be used to turn off (i.e. lesion) certain units
    """
    fixed_starting_point = (digit - steps) % 10
    b, _ = dataset.create_batches(batch_size=-1, sequence_length=steps,
                                  shuffle=True,
                                  fixed_starting_point=fixed_starting_point)
    batch = b.squeeze(0)

    h = net.model.init_state(1)
    for i in range(steps):
        h, l_a = net.model(batch[i], state=h)
        if mask is not None:
            h = h * mask  # lesion: silence masked units at every step

    return h.detach()
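# Usage sketch: the hidden state after a 5-step sequence ending in digit 3,
# with a handful of units silenced. The lesioned unit indices are hypothetical;
# the repo's _pred_mask (used in model_activity_lesioned below) builds the
# real mask.
def _example_lesioned_state(net: ModelState, dataset: Dataset) -> torch.Tensor:
    mask = torch.ones(net.model.hidden_size)
    mask[[0, 1, 2]] = 0.0  # hypothetical units to lesion
    return _run_seq_from_digit(digit=3, steps=5, net=net, dataset=dataset, mask=mask)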
def compute_preact_stats(net: ModelState, dataset: Dataset, nclasses=10, ntime=10):
    """Compute, for each unit, the median and MAD of the final-time-point
    preactivation for each class.

    Output:
        preact_stats: matrix of shape n_units x nclasses x 2
    """
    preact_stats = torch.zeros((net.model.hidden_size, nclasses, 2))

    # Generate sequences that end in the same class.
    for t in [ntime - 1]:  # only look at the final time point (0-indexed)
        for category in range(nclasses):
            starting_point = int(category - t + ntime)
            if starting_point > (ntime - 1):  # cycle back
                starting_point -= ntime
            data, labels = dataset.create_batches(-1, ntime, shuffle=False,
                                                  fixed_starting_point=starting_point)
            data = data.squeeze(0)
            labels = labels.squeeze(0)
            batch_size = data.shape[1]

            h_net = net.model.init_state(batch_size)
            for i in range(data.shape[0]):
                # run the sequence of this category up until time t
                x = data[i]
                h_net, l_net = net.model(x, state=h_net)
                if i == t:
                    med = l_net[0].median(dim=0).values
                    # scipy.stats; renamed median_abs_deviation in newer SciPy
                    mad = torch.tensor(st.median_absolute_deviation(
                        l_net[0].detach().numpy(), axis=0))
                    preact_stats[:, category, 0] = med
                    preact_stats[:, category, 1] = mad

    return preact_stats.detach()
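# Usage sketch: a robust-outlier rule on the statistics above, flagging units
# whose median preactivation for some class deviates from that unit's
# across-class median by more than k MADs. The threshold k=3 is an assumption,
# not a value taken from the repo.
def _example_class_selective_units(net: ModelState, dataset: Dataset, k: float = 3.0):
    stats = compute_preact_stats(net, dataset)          # n_units x nclasses x 2
    med, mad = stats[:, :, 0], stats[:, :, 1]
    overall = med.median(dim=1, keepdim=True).values    # per-unit median across classes
    selective = (med - overall).abs() > k * mad         # n_units x nclasses bool mask
    return selective.any(dim=1)                         # units selective for any class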
def model_activity_lesioned(net: ModelState, training_set: Dataset, test_set: Dataset,
                            seq_length=10, save=True, latent=False,
                            data_type='mnist', reverse=False):
    """Calculates model preactivation and preactivation bounds for lesioned models."""
    mask = _pred_mask(net, test_set, training_set=training_set,
                      latent=latent, reverse=reverse)
    nclass = 10  # change this if you want to change the number of classes

    # meds: class-specific medians; global_median: median of the entire data set
    if data_type == 'mnist':
        meds = mnist.medians(training_set)
        global_median = training_set.x.median(dim=0).values
    elif data_type == 'cifar':
        meds = cifar.medians(training_set)
        global_median = training_set.x.median(dim=0).values

    with torch.no_grad():
        data, labels = test_set.create_batches(-1, seq_length, shuffle=True)
        nb, ntime, batch_size, ninputs = data.shape
        data = data.squeeze(0)
        labels = labels.squeeze(0)
        batch_size = data.shape[1]

        # result lists
        mu_notn = []    # untransformed input
        mu_meds = []    # input minus class-specific median
        mu_gmed = []    # input minus global median
        mu_net = []     # intact network
        mu_netles = []  # lesioned network
        mu_input = []   # lesioned network, input units only
        mu_latent = []  # lesioned network, latent units only

        h_net = net.model.init_state(batch_size)
        h_netles = net.model.init_state(batch_size)

        for t in range(data.shape[0]):
            x = data[t]
            y = labels[t]

            # repeat the global median for each input image
            gmedian = torch.zeros_like(x)
            gmedian[:, :] = global_median

            # find the corresponding class median for each input image
            median = torch.zeros_like(x)
            for i in range(nclass):
                median[y == i, :] = meds[i]

            # calculate hidden states
            h_meds = (x - median)
            h_gmed = (x - gmedian)
            h_net, l_net = net.model(x, state=h_net)
            h_netles = h_netles * mask  # perform lesion
            h_netles, l_netles = net.model(x, state=h_netles)

            # calculate L1 loss for each unit, assuming equal numbers of units in each model
            m_notn = x.abs().sum(dim=1) / net.model.input_size
            m_meds = h_meds.abs().sum(dim=1) / net.model.input_size
            m_gmed = h_gmed.abs().sum(dim=1) / net.model.input_size
            m_net = torch.cat([a[:, :ninputs] for a in l_net], dim=1).abs().mean(dim=1)
            m_netles = torch.cat([a[:, :ninputs] for a in l_netles], dim=1).abs().mean(dim=1)
            m_input = torch.cat([a[:, :ninputs] for a in l_netles], dim=1).abs().mean(dim=1).mean()
            m_latent = torch.cat([a[:, ninputs:] for a in l_netles], dim=1).abs().mean(dim=1).mean()

            # record the batch means
            mu_notn.append(m_notn.mean().cpu().item())
            mu_meds.append(m_meds.mean().cpu().item())
            mu_gmed.append(m_gmed.mean().cpu().item())
            mu_net.append(m_net.mean().cpu().item())
            mu_netles.append(m_netles.mean().cpu().item())
            mu_input.append(m_input.cpu().item())
            mu_latent.append(m_latent.cpu().item())

    return (data, np.array(mu_notn), np.array(mu_meds), np.array(mu_gmed),
            np.array(mu_net), np.array(mu_netles), np.array(mu_input),
            np.array(mu_latent))
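# Usage sketch: comparing intact vs. lesioned preactivation over time. The
# unpacking order matches the return statement above; the plotting specifics
# are illustrative, not the repo's figure code.
def _example_lesion_plot(net: ModelState, train_ds: Dataset, test_ds: Dataset):
    (_, mu_notn, mu_meds, mu_gmed,
     mu_net, mu_netles, _, _) = model_activity_lesioned(net, train_ds, test_ds)
    plt.plot(mu_net, label="intact")
    plt.plot(mu_netles, label="lesioned")
    plt.plot(mu_gmed, label="global-median bound")
    plt.xlabel("time step")
    plt.ylabel("mean |preactivation|")
    plt.legend()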
def test_batch(ms: ModelState,
               batch: torch.FloatTensor,
               loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor],
               state):
    """Run one batch without a parameter update; returns
    (loss value, results, new state)."""
    loss, res, state = ms.run(batch, loss_fn, state)
    return loss.item(), res, state