def custom_load_representations(representation_fname_l,
                                limit=None,
                                layerspec_l=None,
                                first_half_only_l=False,
                                second_half_only_l=False):
    """
    Load representations from hdf5 files, with options controlling which
    layers and which portion of each representation are loaded.

    Params:
    ----
    representation_fname_l : list<str>
        List of hdf5 files containing representations
    limit : int or None
        Cap on the amount of data loaded: at most `limit` sentences are
        read, and at most `limit` word representations are kept per
        network.
    layerspec_l : list
        Specification for each model. May be an integer (layer to take),
        or "all" or "full". "all" means take all layers. "full" means to
        concatenate all layers together.
    first_half_only_l : list<bool>
        For each model, keep only the first half of the neuron
        (feature) dimension of its representations.

        If given a single value, it is copied into a list of the
        correct length.
    second_half_only_l : list<bool>
        For each model, keep only the second half of the neuron
        (feature) dimension of its representations.

        If given a single value, it is copied into a list of the
        correct length.

    Returns:
    ----
    num_neurons_d : {str : int}
        {network : number of neurons}. Here a network could be a layer,
        or the stack of all layers, etc. A network is what's being
        correlated as a single unit.
    representations_d : {str : tensor}
        {network : activations}. Each tensor has shape
        (num_words, num_neurons).
    """

    # Edit args
    l = len(representation_fname_l)
    if layerspec_l is None:
        layerspec_l = ['all'] * l
    if not isinstance(first_half_only_l, list):
        first_half_only_l = [first_half_only_l] * l
    if not isinstance(second_half_only_l, list):
        second_half_only_l = [second_half_only_l] * l

    # Main loop
    num_neurons_d = {}
    representations_d = {}
    for loop_var in tqdm(
            zip(representation_fname_l, layerspec_l, first_half_only_l,
                second_half_only_l), desc='load', total=l):
        fname, layerspec, first_half_only, second_half_only = loop_var

        # Set `activations_h5`, `indices`
        activations_h5 = h5py.File(fname, 'r')
        indices = list(activations_h5.keys())[:limit]

        # Set `num_layers`, `num_neurons`, `layers`
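        # Representations are stored per sentence as either
        # (num_layers, n_word, dim) or, for single-layer models, (n_word, dim).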
        s = activations_h5[indices[0]].shape
        num_layers = 1 if len(s) == 2 else s[0]
        num_neurons = s[-1]
        if layerspec == "all":
            layers = list(range(num_layers))
        elif layerspec == "full":
            layers = ["full"]
        else:
            layers = [layerspec]

        # Set `num_neurons_d`, `representations_d`
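        # Each (model, layer) pair becomes one "network", keyed "<mname>_<layer>".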
        for layer in layers:
            # Create `representations_l`
            representations_l = []
            word_count = 0
            for sentence_ix in indices:
                # Set `dim`, `n_word`, update `word_count`
                shape = activations_h5[sentence_ix].shape
                dim = len(shape)
                if dim not in (2, 3):
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(activations_h5[sentence_ix].shape))
                if dim == 3:
                    n_word = shape[1]
                elif dim == 2:
                    n_word = shape[0]
                word_count += n_word

                # Create `activations`
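                # "full": flatten all layers into one long feature vector per
                # word: (num_layers, n_word, dim) -> (n_word, num_layers * dim).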
                if layer == "full":
                    activations = torch.FloatTensor(
                        activations_h5[sentence_ix])
                    if dim == 3:
                        activations = activations.permute(1, 0, 2)
                        activations = activations.contiguous().view(n_word, -1)
                else:
                    activations = torch.FloatTensor(
                        activations_h5[sentence_ix][layer] if dim ==
                        3 else activations_h5[sentence_ix])

                # Create `representations`
                representations = activations
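                # Optionally keep only half of the neuron (feature) dimension.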
                if first_half_only:
                    representations = torch.chunk(representations,
                                                  chunks=2,
                                                  dim=-1)[0]
                elif second_half_only:
                    representations = torch.chunk(representations,
                                                  chunks=2,
                                                  dim=-1)[1]
                representations_l.append(representations)
                # print("{mname}_{layer}".format(mname=fname2mname(fname), layer=layer),
                #     representations.shape)

                # Early stop
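                # Stop reading sentences once `limit` words have been collected.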
                if limit is not None and word_count >= limit:
                    break

            # Main update
            network = "{mname}_{layer}".format(mname=fname2mname(fname),
                                               layer=layer)
            num_neurons_d[network] = representations_l[0].size()[-1]
            representations_d[network] = torch.cat(representations_l)[:limit]

    return num_neurons_d, representations_d


def load_attentions(attention_fname_l,
                    limit=None,
                    layerspec_l=None,
                    ar_mask=False):
    """
    Load attentions from hdf5 files, with options controlling which layers
    are loaded and whether an autoregressive mask is applied.

    Params:
    ----
    attention_fname_l : list<str>
        List of hdf5 files containing attentions
    limit : int or None
        Cap on the amount of data loaded: at most `limit` sentences are
        read, and loading stops early once `limit` words have been seen.
    layerspec_l : list
        Specification for each model. May be an integer (layer to take),
        or "all". "all" means take all layers. 
    ar_mask : bool
        Whether to apply an autoregressive (lower-triangular) mask when
        loading, zeroing attention to future tokens and renormalizing
        each row. Some models (e.g. GPT) already attend this way.

    Returns:
    ----
    num_heads_d : {str : int}
        {network : number of heads}. Here a network could be a layer,
        or the stack of all layers, etc. A network is what's being
        correlated as a single unit.
    attentions_d : {str : list<tensor>}
        {network : attentions}. attentions is a list of per-sentence
        tensors of shape (num_heads, n_word, n_word), kept as a list
        because sentences may differ in length.
    """
    # Edit args
    l = len(attention_fname_l)
    if layerspec_l is None:
        layerspec_l = ['all'] * l

    # Main loop
    num_heads_d = {}
    attentions_d = {}
    for loop_var in tqdm(zip(attention_fname_l, layerspec_l), desc='load'):
        fname, layerspec = loop_var

        # Set `attentions_h5`, `sentence_d`, `indices`
        attentions_h5 = h5py.File(fname, 'r')
        sentence_d = json.loads(attentions_h5['sentence_to_index'][0])
        # Invert to map {str index : sentence}
        sentence_d = {ix: sent for sent, ix in sentence_d.items()}
        indices = list(sentence_d.keys())[:limit]

        # Set `num_layers`, `num_heads`, `layers`
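        # Attention arrays are stored per sentence as
        # (num_layers, num_heads, n_word, n_word).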
        s = attentions_h5[indices[0]].shape
        num_layers = s[0]
        num_heads = s[1]
        if layerspec == "all":
            layers = list(range(num_layers))
        else:
            layers = [layerspec]

        # Set `num_heads_d`, `attentions_d`
        for layer in layers:
            # Create `attentions_l`
            attentions_l = []
            word_count = 0
            for sentence_ix in indices:
                # Set `dim`, `n_word`, update `word_count`
                shape = attentions_h5[sentence_ix].shape
                dim = len(shape)
                if dim != 4:
                    raise ValueError('Improper array dimension in file: ' +
                                     fname + "\nShape: " +
                                     str(attentions_h5[sentence_ix].shape))
                n_word = shape[2]
                word_count += n_word

                # Create `attentions`
                if ar_mask:
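                    # Zero out attention to future tokens (keep the lower
                    # triangle), then renormalize each row to sum to 1.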
                    attentions = np.tril(attentions_h5[sentence_ix][layer])
                    attentions = attentions / np.sum(
                        attentions, axis=-1, keepdims=True)
                    attentions = torch.FloatTensor(attentions)
                else:
                    attentions = torch.FloatTensor(
                        attentions_h5[sentence_ix][layer])

                # Update `attentions_l`
                attentions_l.append(attentions)

                # Early stop
                if limit is not None and word_count >= limit:
                    break

            # Main update
            network = "{mname}_{layer}".format(mname=fname2mname(fname),
                                               layer=layer)
            num_heads_d[network] = attentions_l[0].shape[0]
            attentions_d[network] = attentions_l[:limit]

    return num_heads_d, attentions_d
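

# Usage sketch (illustrative only): the hdf5 paths below are hypothetical
# placeholders for activation/attention files produced by an extraction
# pipeline; substitute your own.
if __name__ == "__main__":
    repr_files = ["model_a.hdf5", "model_b.hdf5"]
    num_neurons_d, representations_d = custom_load_representations(
        repr_files,
        limit=1000,               # keep at most 1000 word representations
        layerspec_l=["all", 5])   # all layers of model_a, layer 5 of model_b
    for network, tensor in representations_d.items():
        print(network, num_neurons_d[network], tuple(tensor.shape))

    attn_files = ["model_a_attn.hdf5"]
    num_heads_d, attentions_d = load_attentions(attn_files,
                                                limit=1000,
                                                ar_mask=False)
    for network, attn_l in attentions_d.items():
        print(network, num_heads_d[network], len(attn_l))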