Code example #1
    def __init__(self,
                 activation_func,
                 optimizer,
                 lr: float,
                 title: str,
                 input_size: int,
                 hidden_size: int,
                 device: str,
                 deterministic=True,
                 weights_init=functions.init_params,
                 prevbatch=False,
                 conv=False,
                 seed=0):

        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        ModelState.__init__(
            self,
            Network(input_size,
                    hidden_size,
                    activation_func,
                    weights_init=weights_init,
                    prevbatch=prevbatch,
                    conv=conv).to(device),
            optimizer,
            lr,
            title,
            {"train loss": np.zeros(0), "test loss": np.zeros(0)},
            device)
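
The constructor seeds both the PyTorch and NumPy RNGs before the network is built, so runs with the same seed start from identical weights. A minimal self-contained sketch of that pattern (the toy nn.Linear layer is illustrative, not the Network above):

import numpy as np
import torch
import torch.nn as nn

def make_layer(seed=0):
    # seeding both RNGs before construction makes the random
    # weight initialisation reproducible across runs
    torch.manual_seed(seed)
    np.random.seed(seed)
    return nn.Linear(4, 4)

a = make_layer(seed=0)
b = make_layer(seed=0)
assert torch.equal(a.weight, b.weight)  # identical initial weights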
Code example #2
def train_batch(ms: ModelState, batch: torch.FloatTensor,
                loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor],
                                  torch.FloatTensor], state) -> tuple:

    loss, res, state = ms.run(batch, loss_fn, state)

    # backpropagate and update parameters, then clear gradients for the next batch
    ms.step(loss)
    ms.zero_grad()
    return loss.item(), res, state
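
train_batch hands the forward pass to ModelState.run, then applies the optimiser step and clears gradients; returning loss.item() gives the caller a plain Python float detached from the graph. A minimal sketch of the same step/zero_grad pattern with a bare optimizer (the toy linear model and random data are assumptions for illustration):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # toy stand-in for the Network above
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(32, 10), torch.randn(32, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()        # accumulate gradients
optimizer.step()       # update parameters
optimizer.zero_grad()  # clear gradients for the next batch
print(loss.item())     # .item() yields a detached Python float for logging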
Code example #3
def example_sequence_state(net: ModelState,
                           dataset: Dataset,
                           latent=False,
                           seed=2553,
                           save=True):
    """
    visualises input and internal drive for a sample sequence
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    batches, _ = dataset.create_batches(batch_size=1,
                                        sequence_length=10,
                                        shuffle=False)

    seq = batches[0, :, :, :]
    # make sure we only visualise input units and no latent resources
    input_size = seq.shape[-1]
    X = []
    P = []
    L = []

    h = net.model.init_state(seq.shape[1])

    for x in seq:

        p = net.predict(h, latent)
        h, l_a = net.model(x, state=h)

        X.append(x[:input_size].mean(dim=0).detach().cpu())
        P.append(p[:, :input_size].mean(dim=0).detach().cpu())
        if latent:  # look at latent unit drive
            L.append(p[:, input_size:].mean(dim=0).detach().cpu())

    if latent:
        fig, axes = display(X + P + L,
                            shape=(10, 3),
                            figsize=(3, 3),
                            axes_visible=False,
                            layout='tight')
    else:
        fig, axes = display(X + P,
                            shape=(10, 2),
                            figsize=(3, 3),
                            axes_visible=False,
                            layout='tight')

    if save:
        save_fig(fig, "example_sequence_state", bbox_inches='tight')
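
display is a project helper; a minimal sketch of the same idea with plain matplotlib, laying a list of flattened frames out on a subplot grid (the 28x28 reshape assumes MNIST-sized inputs, and the row-major fill order is an assumption about how display arranges frames):

import matplotlib.pyplot as plt
import torch

def show_grid(frames, shape, side=28):
    # frames: flattened image tensors, filled onto the grid row-major
    rows, cols = shape
    fig, axes = plt.subplots(rows, cols, figsize=(3, 3))
    for ax, frame in zip(axes.ravel(), frames):
        ax.imshow(frame.reshape(side, side), cmap='gray')
        ax.set_axis_off()
    return fig, axes

fig, _ = show_grid([torch.rand(784) for _ in range(20)], shape=(10, 2))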
Code example #4
def pred_after_timestep(net: ModelState,
                        dataset: Dataset,
                        mask=None,
                        seed=2553,
                        save=True):
    """
    visualises internal drive after 0-9 preceding frames.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    imgs = []

    # one row per digit d: the model's prediction given all zeros, then its
    # predictions after 1-9 preceding frames of a sequence ending in d
    for d in range(10):
        imgs = imgs + [net.predict(torch.zeros(1, 784))] + [
            net.predict(
                helper._run_seq_from_digit(d, i, net, dataset,
                                           mask=mask)).mean(dim=0)
            for i in range(1, 10)
        ]

    fig, axes = display(imgs, shape=(10, 10), axes_visible=False)

    for i in range(10):
        axes[i][0].set_ylabel(str(i) + ":",
                              ha='center',
                              fontsize=60,
                              rotation=0,
                              labelpad=30)

    fig.tight_layout()

    if save:
        save_fig(fig,
                 net.title + "/pred-after-timestep" +
                 ("-lesioned" if mask is not None else ""),
                 bbox_inches='tight')
Code example #5
def train(ms: ModelState,
          train_ds: Dataset,
          test_ds: Dataset,
          loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor],
                            torch.FloatTensor],
          num_epochs: int = 1,
          batch_size: int = 32,
          sequence_length: int = 3,
          verbose=False):

    for epoch in range(ms.epochs + 1, ms.epochs + 1 + num_epochs):
        print("Epoch {}".format(epoch))

        train_loss, train_res = train_epoch(ms,
                                            train_ds,
                                            loss_fn,
                                            batch_size,
                                            sequence_length,
                                            verbose=verbose)

        test_loss, test_res = test_epoch(ms, test_ds, loss_fn, batch_size,
                                         sequence_length)

        ms.on_results(epoch, train_res, test_res)
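
The results dict initialised with np.zeros(0) in code example #1 suggests per-epoch losses are appended as this loop runs; a minimal sketch of that bookkeeping pattern (keys match example #1, the loss values are illustrative):

import numpy as np

results = {"train loss": np.zeros(0), "test loss": np.zeros(0)}

for epoch in range(3):
    train_loss, test_loss = 1.0 / (epoch + 1), 1.2 / (epoch + 1)  # illustrative
    # np.append returns a new array, growing the empty buffers by one per epoch
    results["train loss"] = np.append(results["train loss"], train_loss)
    results["test loss"] = np.append(results["test loss"], test_loss)

print(results["train loss"])  # one entry per epoch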
Code example #6
def _run_seq_from_digit(digit, steps, net: ModelState, dataset: Dataset, mask=None):
    """Run sequences that end in the same digit through the model and return the hidden state

    Parameters:
        - digit: the last digit in the sequence
        - steps: sequence length, i.e. steps before the sequence reaches 'digit'
        - net: model
        - dataset: dataset to use
        - mask: mask can be used to turn off (i.e. lesion) certain units
    """
    fixed_starting_point = (digit - steps) % 10
    b, _ = dataset.create_batches(batch_size=-1,
                                  sequence_length=steps,
                                  shuffle=True,
                                  fixed_starting_point=fixed_starting_point)
    batch = b.squeeze(0)

    h = net.model.init_state(1)
    for i in range(steps):
        h, l_a = net.model(batch[i], state=h)
        if mask is not None:
            h = h * mask

    return h.detach()
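
Lesioning here is just an element-wise multiply of the hidden state by a binary mask, so the masked units contribute exactly zero at every step. A minimal sketch (the sizes and lesioned indices are illustrative):

import torch

hidden_size = 8
h = torch.randn(1, hidden_size)  # illustrative hidden state

mask = torch.ones(hidden_size)
mask[[2, 5]] = 0.0               # lesion units 2 and 5

h_lesioned = h * mask            # lesioned units output exactly zero
print(h_lesioned[0, 2].item(), h_lesioned[0, 5].item())  # 0.0 0.0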
Code example #7
def compute_preact_stats(net: ModelState, dataset: Dataset, nclasses=10, ntime=10):
    """
    Compute for each unit the median preactivation and MAD at the final
    time point, for each class

    Output: preact_stats matrix n_units x nclasses x 2
    """

    preact_stats = torch.zeros((net.model.hidden_size, nclasses, 2))
    
    # generate sequences that end in the same class
    for t in [ntime - 1]: # only look at final time point (0-indexed)
        for category in range(nclasses):
            starting_point = int(category - t + ntime)
            if starting_point > (ntime - 1): # cycle back
                starting_point -= ntime
            
            data, labels = dataset.create_batches(-1, ntime, shuffle=False,
                                                  fixed_starting_point=starting_point)

            data = data.squeeze(0)  # (ntime, batch_size, ninputs)
            labels = labels.squeeze(0)
            batch_size = data.shape[1]
            h_net = net.model.init_state(batch_size)
           
            for i in range(data.shape[0]):  # run the sequence up until time t
                x = data[i]
                h_net, l_net = net.model(x, state=h_net)
                if i == t:
                    med = l_net[0].median(axis=0).values
                    mad = torch.tensor(
                        st.median_absolute_deviation(l_net[0].detach().numpy(),
                                                     axis=0))

            preact_stats[:, category, 0] = med
            preact_stats[:, category, 1] = mad
            
    return preact_stats.detach()  
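
The statistics themselves are an ordinary per-unit median and MAD over the batch axis. A minimal sketch with torch alone (note: scipy's median_absolute_deviation, used above, scales by ~1.4826 by default and has been removed from recent SciPy releases in favour of median_abs_deviation; the raw MAD below is unscaled):

import torch

# toy preactivations: batch of 5 samples x 3 units
acts = torch.tensor([[0.1, 1.0, -2.0],
                     [0.2, 1.1, -1.9],
                     [0.0, 0.9, -2.1],
                     [0.3, 1.2, -1.8],
                     [0.1, 1.0, -2.0]])

med = acts.median(dim=0).values                # per-unit median over the batch
mad = (acts - med).abs().median(dim=0).values  # raw (unscaled) MAD per unit
print(med, mad)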
Code example #8
def model_activity_lesioned(net: ModelState, training_set: Dataset, test_set: Dataset,
                            seq_length=10, save=True, latent=False,
                            data_type='mnist', reverse=False):
    """
    calculates model preactivation and preactivation bounds
    for lesioned models
    """
    mask = _pred_mask(net, test_set, training_set=training_set, latent=latent,
                      reverse=reverse)
    nclass = 10  # change this if you want to change the number of classes
    # meds: class-specific medians, global_median: median of the entire data set
    if data_type == 'mnist':
        meds = mnist.medians(training_set)
        global_median = training_set.x.median(dim=0).values

    elif data_type == 'cifar':
        meds = cifar.medians(training_set)
        global_median = training_set.x.median(dim=0).values 
  
    with torch.no_grad():
        data, labels = test_set.create_batches(-1, seq_length, shuffle=True)
        nb, ntime, batch_size, ninputs = data.shape
        data = data.squeeze(0)
        labels = labels.squeeze(0)
        batch_size = data.shape[1]
       
        # result lists
        mu_notn = []
        mu_meds = []
        mu_gmed = []
        mu_net = []
        mu_netles = []
        mu_input = []
        mu_latent = []

        h_net = net.model.init_state(batch_size)
        h_netles = net.model.init_state(batch_size)
        
       
        for t in range(data.shape[0]):
            x = data[t]
            y = labels[t]
            
            # repeat the global median for each input image
            gmedian = torch.zeros_like(x)
            gmedian[:, :] = global_median

            # find the corresponding class median for each input image
            median = torch.zeros_like(x)
            for i in range(nclass):
                median[y == i, :] = meds[i]

            # calculate hidden states
            h_meds = (x - median)
            h_gmed = (x - gmedian)
            h_net, l_net = net.model(x, state=h_net)
            h_netles = h_netles * mask  # perform lesion
            h_netles, l_netles = net.model(x, state=h_netles)

            # calculate L1 loss for each unit, assuming equal numbers of units in each model
            m_notn = x.abs().sum(dim=1) / net.model.input_size
            m_meds = h_meds.abs().sum(dim=1) / net.model.input_size
            m_gmed = h_gmed.abs().sum(dim=1) / net.model.input_size
            m_net = torch.cat([a[:, :ninputs] for a in l_net], dim=1).abs().mean(dim=1)
            m_netles = torch.cat([a[:, :ninputs] for a in l_netles], dim=1).abs().mean(dim=1)
            m_input = torch.cat([a[:, :ninputs] for a in l_netles], dim=1).abs().mean(dim=1)
            m_latent = torch.cat([a[:, ninputs:] for a in l_netles], dim=1).abs().mean(dim=1)

            # store the batch means as plain floats
            mu_notn.append(m_notn.mean().cpu().item())
            mu_meds.append(m_meds.mean().cpu().item())
            mu_gmed.append(m_gmed.mean().cpu().item())
            mu_net.append(m_net.mean().cpu().item())
            mu_netles.append(m_netles.mean().cpu().item())
            mu_input.append(m_input.mean().cpu().item())
            mu_latent.append(m_latent.mean().cpu().item())
            
         
           
    return (data, np.array(mu_notn), np.array(mu_meds), np.array(mu_gmed),
            np.array(mu_net), np.array(mu_netles), np.array(mu_input),
            np.array(mu_latent))
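
The class-median baselines above reduce to subtracting a per-class median image from each input and taking a normalised L1 norm; a minimal sketch with toy data (all shapes and values are illustrative):

import torch

ninputs, nclass = 4, 2
x = torch.randn(6, ninputs)                           # toy batch of inputs
y = torch.tensor([0, 1, 0, 1, 0, 1])                  # toy class labels
meds = [torch.randn(ninputs) for _ in range(nclass)]  # per-class median images

median = torch.zeros_like(x)
for i in range(nclass):
    median[y == i, :] = meds[i]  # broadcast each class median to its rows

residual = (x - median).abs().sum(dim=1) / ninputs  # per-sample L1 baseline
print(residual)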
Code example #9
def test_batch(ms: ModelState, batch: torch.FloatTensor,
               loss_fn: Callable[[torch.FloatTensor, torch.FloatTensor],
                                 torch.FloatTensor], state) -> tuple:
    # evaluation only: run the model and loss without an optimiser step
    loss, res, state = ms.run(batch, loss_fn, state)
    return loss.item(), res, state
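
test_batch mirrors train_batch but skips the optimiser step; a caller would typically also evaluate under torch.no_grad(), as model_activity_lesioned does above. A minimal sketch (the toy model and data are assumptions):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

with torch.no_grad():  # no graph is built, so no step/zero_grad is needed
    loss = nn.functional.mse_loss(model(x), y)
print(loss.item())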