import json

import h5py
import numpy as np
import torch
from tqdm import tqdm

# `fname2mname` (hdf5 filename -> model name) is assumed to be defined
# elsewhere in this module.


def custom_load_representations(representation_fname_l, limit=None, layerspec_l=None,
                                first_half_only_l=False, second_half_only_l=False):
    """
    Load representations from disk. Options control which layers and which
    halves of each representation are taken.

    Params:
    ----
    representation_fname_l : list<str>
        List of hdf5 files containing representations.
    limit : int or None
        Limit on the number of representations to take.
    layerspec_l : list
        Specification for each model. May be an integer (the layer to take),
        "all" (take every layer separately), or "full" (concatenate all
        layers together).
    first_half_only_l : list<bool>
        Only take the first half of the representations for a given model.
        If given a single value, it is copied into a list of the correct length.
    second_half_only_l : list<bool>
        Only take the second half of the representations for a given model.
        If given a single value, it is copied into a list of the correct length.

    Returns:
    ----
    num_neurons_d : {str : int}
        {network : number of neurons}. Here a network could be a layer, the
        stack of all layers, etc. A network is what's being correlated as a
        single unit.
    representations_d : {str : tensor}
        {network : activations}.
    """
    # Edit args
    l = len(representation_fname_l)
    if layerspec_l is None:
        layerspec_l = ['all'] * l
    if type(first_half_only_l) is not list:
        first_half_only_l = [first_half_only_l] * l
    if type(second_half_only_l) is not list:
        second_half_only_l = [second_half_only_l] * l

    # Main loop
    num_neurons_d = {}
    representations_d = {}
    for loop_var in tqdm(zip(representation_fname_l, layerspec_l,
                             first_half_only_l, second_half_only_l)):
        fname, layerspec, first_half_only, second_half_only = loop_var

        # Set `activations_h5`, `indices`
        activations_h5 = h5py.File(fname, 'r')
        indices = list(activations_h5.keys())[:limit]

        # Set `num_layers`, `num_neurons`, `layers`
        s = activations_h5[indices[0]].shape
        num_layers = 1 if len(s) == 2 else s[0]
        num_neurons = s[-1]
        if layerspec == "all":
            layers = list(range(num_layers))
        elif layerspec == "full":
            layers = ["full"]
        else:
            layers = [layerspec]

        # Set `num_neurons_d`, `representations_d`
        for layer in layers:
            # Create `representations_l`
            representations_l = []
            word_count = 0
            for sentence_ix in indices:
                # Set `dim`, `n_word`, update `word_count`
                shape = activations_h5[sentence_ix].shape
                dim = len(shape)
                if not (dim == 2 or dim == 3):
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(activations_h5[sentence_ix].shape))
                if dim == 3:
                    n_word = shape[1]
                elif dim == 2:
                    n_word = shape[0]
                word_count += n_word

                # Create `activations`, shaped (n_word, dimension)
                if layer == "full":
                    activations = torch.FloatTensor(activations_h5[sentence_ix])
                    if dim == 3:
                        activations = activations.permute(1, 0, 2)
                        activations = activations.contiguous().view(n_word, -1)
                else:
                    activations = torch.FloatTensor(
                        activations_h5[sentence_ix][layer] if dim == 3
                        else activations_h5[sentence_ix])

                # Create `representations`
                representations = activations
                if first_half_only:
                    representations = torch.chunk(representations, chunks=2,
                                                  dim=-1)[0]
                elif second_half_only:
                    representations = torch.chunk(representations, chunks=2,
                                                  dim=-1)[1]
                representations_l.append(representations)
                # print("{mname}_{layer}".format(mname=fname2mname(fname), layer=layer),
                #       representations.shape)

                # Early stop
                if limit is not None and word_count >= limit:
                    break

            # Main update
            network = "{mname}_{layer}".format(mname=fname2mname(fname),
                                               layer=layer)
            num_neurons_d[network] = representations_l[0].size()[-1]
            representations_d[network] = torch.cat(representations_l)[:limit]

    return num_neurons_d, representations_d
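
# A minimal usage sketch for `custom_load_representations`. The hdf5 paths and
# layerspecs below are hypothetical placeholders (not files provided by this
# repo); they only illustrate the documented call signature. Each returned
# entry is keyed by "<model name>_<layer>".
def _example_load_representations():
    repr_files = ["bert_activations.hdf5", "elmo_activations.hdf5"]  # placeholder paths
    num_neurons_d, representations_d = custom_load_representations(
        repr_files,
        limit=1000,              # cap on the number of words kept per network
        layerspec_l=["all", 2],  # every layer of the first model; layer 2 of the second
        first_half_only_l=False,
        second_half_only_l=False)
    for network, reps in representations_d.items():
        # `reps` is a (num_words, num_neurons) float tensor
        print(network, tuple(reps.shape), num_neurons_d[network])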
def load_attentions(attention_fname_l, limit=None, layerspec_l=None, ar_mask=False):
    """
    Load attentions from disk. Options control which layers are taken and
    whether an autoregressive mask is applied.

    Params:
    ----
    attention_fname_l : list<str>
        List of hdf5 files containing attentions.
    limit : int or None
        Limit on the number of attentions to take.
    layerspec_l : list
        Specification for each model. May be an integer (the layer to take)
        or "all" (take every layer separately).
    ar_mask : bool
        Whether to mask the future when loading. Some models (e.g. GPT) do
        this automatically.

    Returns:
    ----
    num_heads_d : {str : int}
        {network : number of heads}. Here a network could be a layer, the
        stack of all layers, etc. A network is what's being correlated as a
        single unit.
    attentions_d : {str : list<tensor>}
        {network : attentions}. The attentions are kept as a list because
        each sentence may have a different length.
    """
    # Edit args
    l = len(attention_fname_l)
    if layerspec_l is None:
        layerspec_l = ['all'] * l

    # Main loop
    num_heads_d = {}
    attentions_d = {}
    for loop_var in tqdm(zip(attention_fname_l, layerspec_l), desc='load'):
        fname, layerspec = loop_var

        # Set `attentions_h5`, `sentence_d`, `indices`
        attentions_h5 = h5py.File(fname, 'r')
        sentence_d = json.loads(attentions_h5['sentence_to_index'][0])
        temp = {}  # TODO: make this more elegant?
        for k, v in sentence_d.items():
            temp[v] = k
        sentence_d = temp  # {str ix : sentence}
        indices = list(sentence_d.keys())[:limit]

        # Set `num_layers`, `num_heads`, `layers`
        s = attentions_h5[indices[0]].shape
        num_layers = s[0]
        num_heads = s[1]
        if layerspec == "all":
            layers = list(range(num_layers))
        else:
            layers = [layerspec]

        # Set `num_heads_d`, `attentions_d`
        for layer in layers:
            # Create `attentions_l`
            attentions_l = []
            word_count = 0
            for sentence_ix in indices:
                # Set `dim`, `n_word`, update `word_count`
                shape = attentions_h5[sentence_ix].shape
                dim = len(shape)
                if not (dim == 4):
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(attentions_h5[sentence_ix].shape))
                n_word = shape[2]
                word_count += n_word

                # Create `attentions`
                if ar_mask:
                    # Zero out attention to future positions, then renormalize
                    # each row so it still sums to one.
                    attentions = np.tril(attentions_h5[sentence_ix][layer])
                    attentions = attentions / np.sum(attentions, axis=-1,
                                                     keepdims=True)
                    attentions = torch.FloatTensor(attentions)
                else:
                    attentions = torch.FloatTensor(
                        attentions_h5[sentence_ix][layer])

                # Update `attentions_l`
                attentions_l.append(attentions)

                # Early stop
                if limit is not None and word_count >= limit:
                    break

            # Main update
            network = "{mname}_{layer}".format(mname=fname2mname(fname),
                                               layer=layer)
            num_heads_d[network] = attentions_l[0].shape[0]
            attentions_d[network] = attentions_l[:limit]

    return num_heads_d, attentions_d
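
# A minimal usage sketch for `load_attentions`, mirroring the sketch above.
# The hdf5 path is a hypothetical placeholder; `ar_mask=True` lower-triangles
# and renormalizes each attention map, which is only appropriate for
# autoregressive models whose stored attentions do not already mask the future.
def _example_load_attentions():
    attn_files = ["gpt2_attentions.hdf5"]  # placeholder path
    num_heads_d, attentions_d = load_attentions(
        attn_files,
        limit=1000,           # cap on the number of words kept per network
        layerspec_l=["all"],  # every layer separately
        ar_mask=True)
    for network, attns in attentions_d.items():
        # `attns` is a list of (num_heads, n_word, n_word) tensors, one per sentence
        print(network, num_heads_d[network], len(attns), tuple(attns[0].shape))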