def train_model(exp_name,
                data_dict_path,
                model_path,
                cond_name=None,
                cond=None,
                run=None,
                hid_layers=1,
                units_per_layer=200,
                act_func='sigmoid',
                serial_recall=False,
                y_1hot=True,
                output_units='n_cats',
                generator=True,
                x_data_type='dist_letter_X',
                end_seq_cue=False,
                max_epochs=100,
                use_optimizer='adam',
                loss_target=0.01,
                min_loss_change=0.0001,
                batch_size=32,
                augmentation=True,
                grey_image=False,
                use_batch_norm=True,
                use_dropout=0.0,
                use_val_data=True,
                timesteps=1,
                train_cycles=False,
                weight_init='GlorotUniform',
                init_range=1,
                lr=0.001,
                unroll=False,
                LENS_states=False,
                exp_root='/home/nm13850/Documents/PhD/python_v2/experiments/',
                verbose=False,
                test_run=False):
    """
    Script to train a recurrent neural network on a short-term memory task.


    1. get details - dset_name, model, other key variables
    2. load datasets
    3. compile model
    4. fit/evaluate model
    5. make plots
    6. output:
        training plots
        model accuracy etc.

    Training will stop if:
    a. loss does not improve by [min_loss_change] for [patience] epochs
    b. accuracy reaches 100%
    c. max_epochs is reached


    :param exp_name: name of this experiment or set of model/dataset pairs
    :param data_dict_path: path to data dict
    :param model_path: dir for model
    :param cond_name: name of this condition
    :param cond: number for this condition
    :param run: number for this run (e.g., multiple runs of the same condition from different initializations)
    :param hid_layers: number of hidden layers
    :param units_per_layer: number of units in each layer
    :param act_func: activation function to use
    :param serial_recall: if True, output a sequence, else use dist output and no repeats in data.
    :param y_1hot: if True, output is 1hot, if False, output is multi_label.
    :param output_units: default='n_cats', e.g., number of categories.
                        If 'x_size', will be same as number of input feats.
                        Can also accept int.
    :param generator: If true, will generate training data, else load data as usual.
    :param x_data_type: input coding: local words (1hot), local letters (3hot), dist letters (9hot)
    :param end_seq_cue: Add input unit to cue recall
    :param max_epochs: Stop training after this many epochs
    :param use_optimizer: Optimizer to use
    :param loss_target: stop training when this target is reached
    :param min_loss_change: stop training if loss does not improve by this much
    :param batch_size: number of items loaded in at once
    :param augmentation: whether data aug is used (for images)
    :param grey_image: whether the images are grey (if false, they are colour)
    :param use_batch_norm: use batch normalization
    :param use_dropout: use dropout
    :param use_val_data: use validation set (either separate set, or train/val split)
    :param timesteps: if RNN length of sequence
    :param train_cycles: if False, all lists lengths = timesteps.
                        If True, train on varying length, [1, 2, 3,...timesteps].
    :param weight_init: change the initialization of the weights
    :param init_range: range for the random uniform initializer (+/-, e.g., from -1 to +1)
    :param lr: set the learning rate for the optimizer
    :param unroll: whether to unroll the recurrent model.
    :param LENS_states: if not False, build a stateful model and set hidden states with a callback.
    :param exp_root: root directory for saving experiments

    :param verbose: if 0/False, not verbose; if 1/True, print basics; if 2, print everything

    :return: training_info csv
    :return: sim_dict with dataset info, model info and training info

    """

    print("\n\n\nTraining a new model\n********************")

    dset_dir, data_dict_name = os.path.split(data_dict_path)
    dset_dir, dset_name = os.path.split(dset_dir)
    model_dir, model_name = os.path.split(model_path)

    print(f"dset_dir: {dset_dir}\ndset_name: {dset_name}")
    print(f"model_dir: {model_dir}\nmodel_name: {model_name}")

    # Output files
    if not cond_name:
        # output_filename = f"{exp_name}_{model_name}_{dset_name}"
        output_filename = f"{model_name}_{dset_name}"

    else:
        # output_filename = f"{exp_name}_{cond_name}"
        output_filename = cond_name

    print(f"\noutput_filename: {output_filename}")

    # # get info from dict
    if os.path.isfile(data_dict_path):
        data_dict = load_dict(data_dict_path)
    elif os.path.isfile(
            os.path.join('/home/nm13850/Documents/PhD/python_v2/datasets/',
                         data_dict_path)):
        # work computer
        data_dict_path = os.path.join(
            '/home/nm13850/Documents/PhD/python_v2/datasets/', data_dict_path)
        data_dict = load_dict(data_dict_path)
    elif os.path.isfile(
            os.path.join('/Users/nickmartin/Documents/PhD/python_v2/datasets',
                         data_dict_path)):
        # laptop
        data_dict_path = os.path.join(
            '/Users/nickmartin/Documents/PhD/python_v2/datasets',
            data_dict_path)
        data_dict = load_dict(data_dict_path)
    else:
        raise FileNotFoundError(data_dict_path)

    if verbose:
        # # print basic details
        print("\n**** STUDY DETAILS ****")
        print(
            f"output_filename: {output_filename}\ndset_name: {dset_name}\nmodel: {model_name}\n"
            f"max_epochs: {max_epochs}\nuse_optimizer: {use_optimizer}\n"
            f"lr: {lr}\n"
            f"loss_target: {loss_target}\nmin_loss_change: {min_loss_change}\n"
            f"batch_norm: {use_batch_norm}\nval_data: {use_val_data}\naugemntation: {augmentation}\n"
        )
        focussed_dict_print(data_dict, 'data_dict')

    n_cats = data_dict["n_cats"]
    x_size = data_dict['X_size']
    if end_seq_cue:
        x_size = x_size + 1

    if not generator:
        x_load = np.load(
            '/home/nm13850/Documents/PhD/python_v2/datasets/RNN/bowers14_rep/'
            'test_dist_letter_X_10batches_4seqs_3ts_31feat.npy')
        # print(f"x_load: {np.shape(x_load)}")
        y_load = np.load(
            '/home/nm13850/Documents/PhD/python_v2/datasets/RNN/bowers14_rep/'
            'test_freerecall_Y_10batches_4seqs_3ts.npy')
        # print(f"y_load: {np.shape(y_load)}")
        labels_load = np.load(
            '/home/nm13850/Documents/PhD/python_v2/datasets/RNN/bowers14_rep/'
            'test_freerecall_Y_labels_10batches_4seqs_3ts.npy')
        # print(f"labels_load: {np.shape(labels_load)}")
        x_train = x_load
        y_train = y_load

        n_items = np.shape(x_train)[0]
    else:
        # # if generator is true
        x_data_path = 'RNN_STM_tools/generate_STM_RNN_seqs'
        y_data_path = 'RNN_STM_tools/generate_STM_RNN_seqs'
        n_items = 'unknown'

    # # save path
    exp_cond_path = os.path.join(exp_root, exp_name, output_filename)

    train_folder = 'training'

    exp_cond_path = os.path.join(exp_cond_path, train_folder)

    # # check which machine i am on
    if running_on_laptop():
        exp_cond_path = switch_home_dirs(exp_cond_path)

    if not os.path.exists(exp_cond_path):
        # print(f'exp_cond_path: {exp_cond_path}')
        os.makedirs(exp_cond_path)
    os.chdir(exp_cond_path)
    print(f"\nsaving to exp_cond_path: {exp_cond_path}")

    # # The Model

    if train_cycles:
        train_ts = None
    else:
        train_ts = timesteps

    output_type = 'classes'
    if type(output_units) is int:
        n_output_units = output_units
        if n_output_units == x_size:
            output_type = 'letters'
        elif n_output_units == n_cats:
            output_type = 'classes'
        else:
            raise ValueError(
                f"n_output_units ({n_output_units}) does not match x_size or n_cats,\n"
                f"so output_type cannot be inferred as 'classes' or 'letters'")

    elif type(output_units) is str:
        if output_units.lower() == 'n_cats':
            n_output_units = n_cats
        elif output_units.lower() == 'x_size':
            n_output_units = x_size
            output_type = 'letters'
        else:
            raise ValueError(
                f'output_units should be specified as an int, '
                f'or with a string "n_cats" or "x_size"\noutput_units: {output_units}'
            )
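    # For example, with illustrative values x_size=31 and n_cats=30:
    #   output_units='n_cats' (or 30) -> 30 output units, output_type='classes'
    #   output_units='x_size' (or 31) -> 31 output units, output_type='letters'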

    if 'rnn' in model_dir:
        print("\nloading a recurrent model")
        augmentation = False

        stateful = False
        if LENS_states != False:
            stateful = True

        print(f"\nserial_recall: {serial_recall}\n"
              f"y_1hot: {y_1hot}\n"
              f"batch_size: {batch_size}\n"
              f"unroll: {unroll}\n"
              f"LENS_states: {LENS_states}\n"
              f"timesteps: {timesteps}")

        models_dict = {
            'Bowers14rnn': Bowers14rnn,
            'SimpleRNNn': SimpleRNNn,
            'Bowers_14_Elman': Bowers_14_Elman,
            'Bowers_14_Elman2': Bowers_14_Elman2,
            'GRUn': GRUn,
            'LSTMn': LSTMn,
            'Seq2Seq': Seq2Seq
        }

        model = models_dict[model_name].build(features=x_size,
                                              classes=n_output_units,
                                              timesteps=train_ts,
                                              batch_size=batch_size,
                                              n_layers=hid_layers,
                                              serial_recall=serial_recall,
                                              units_per_layer=units_per_layer,
                                              act_func=act_func,
                                              y_1hot=y_1hot,
                                              dropout=use_dropout,
                                              masking=train_cycles,
                                              weight_init=weight_init,
                                              init_range=init_range,
                                              unroll=unroll,
                                              stateful=stateful)
    else:
        raise ValueError(f"model_dir not recognised: {model_dir}")

    # # loss
    loss_func = 'categorical_crossentropy'
    if not y_1hot:
        loss_func = 'binary_crossentropy'

    # # sort optimizers
    # optimizer
    sgd = SGD(lr=lr, momentum=.9)  # decay=sgd_lr / max_epochs)
    this_optimizer = sgd

    if use_optimizer == 'SGD_no_momentum':
        this_optimizer = SGD(lr=lr, momentum=0.0,
                             nesterov=False)  # decay=sgd_lr / max_epochs)
    elif use_optimizer == 'SGD_Nesterov':
        this_optimizer = SGD(lr=lr, momentum=.1,
                             nesterov=True)  # decay=sgd_lr / max_epochs)
    elif use_optimizer == 'SGD_mom_clip':
        this_optimizer = SGD(lr=lr, momentum=.9,
                             clipnorm=1.)  # decay=sgd_lr / max_epochs)
    elif use_optimizer == 'dougs':
        this_optimizer = dougsMomentum(lr=lr, momentum=.9)

    elif use_optimizer == 'adam':
        this_optimizer = Adam(lr=lr, amsgrad=False)
    elif use_optimizer == 'adam_amsgrad':
        # simulations run prior to 05122019 did not have this option, and may have used amsgrad under the name 'adam'
        this_optimizer = Adam(lr=lr, amsgrad=True)

    elif use_optimizer == 'RMSprop':
        this_optimizer = RMSprop(lr=lr)
    elif use_optimizer == 'Adagrad':
        this_optimizer = Adagrad()
    elif use_optimizer == 'Adadelta':
        this_optimizer = Adadelta()
    elif use_optimizer == 'Adamax':
        this_optimizer = Adamax(lr=lr)
    elif use_optimizer == 'Nadam':
        this_optimizer = Nadam()

    # # metrics
    main_metric = 'binary_accuracy'
    if y_1hot:
        main_metric = 'categorical_accuracy'

    model.compile(loss=loss_func,
                  optimizer=this_optimizer,
                  metrics=[main_metric])

    optimizer_details = model.optimizer.get_config()
    # print_nested_round_floats(model_details)
    focussed_dict_print(optimizer_details, 'optimizer_details')

    # # get model dict
    model_info = get_model_dict(model)  # , verbose=True)
    # print("\nmodel_info:")
    print_nested_round_floats(model_info, 'model_info')
    tf.compat.v1.keras.utils.plot_model(model,
                                        to_file=f'{model_name}_diag.png',
                                        show_shapes=True)

    # # call backs and training parameters
    checkpoint_path = f'{output_filename}_model.hdf5'

    checkpoint_mon = 'loss'
    if use_val_data:
        checkpoint_mon = 'val_loss'

    # checkpointing.  Save model and weights with best val loss.
    checkpointer = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor=checkpoint_mon,
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
        mode='auto',
        load_weights_on_restart=True)

    # patience_for_loss_change: wait this long to see if loss improves
    patience_for_loss_change = int(max_epochs / 50)

    if patience_for_loss_change < 5:
        patience_for_loss_change = 5
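    # Worked example: max_epochs=100 gives int(100 / 50) = 2, which is raised to
    # the floor of 5; max_epochs=500 would give a patience of 10 epochs.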

    # early_stop_plateau - if there is no improvement
    early_stop_plateau = tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        min_delta=min_loss_change,
        patience=patience_for_loss_change,
        verbose=1,
        mode='min')

    val_early_stop_plateau = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=min_loss_change,
        patience=patience_for_loss_change,
        verbose=verbose,
        mode='min')

    date_n_time = int(datetime.datetime.now().strftime("%Y%m%d%H%M"))
    tensorboard_path = os.path.join(exp_cond_path, 'tb', str(date_n_time))

    tensorboard = TensorBoard(log_dir=tensorboard_path)

    print('\n\nto access tensorboard, in terminal use\n'
          f'tensorboard --logdir={tensorboard_path}'
          '\nthen click link'
          '')

    callbacks_list = [early_stop_plateau, checkpointer, tensorboard]
    val_callbacks_list = [val_early_stop_plateau, checkpointer, tensorboard]

    if LENS_states:
        callbacks_list = [
            early_stop_plateau, checkpointer, tensorboard,
            SetHiddedStatesCallback()
        ]

    ############################
    # # train model
    print("\n**** TRAINING ****")
    if augmentation:
        # # construct the image generator for data augmentation
        aug = ImageDataGenerator(
            rotation_range=
            10,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=
            0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=
            0.1,  # randomly shift images vertically (fraction of total height)
            shear_range=0.1,  # set range for random shear (tilt image)
            zoom_range=0.1,  # set range for random zoom
            # horizontal_flip=True,
            fill_mode="nearest")

        if use_val_data:
            fit_model = model.fit_generator(
                aug.flow(x_train, y_train, batch_size=batch_size),
                validation_data=(x_val, y_val),
                # steps_per_epoch=len(x_train) // batch_size,
                epochs=max_epochs,
                verbose=1,
                callbacks=val_callbacks_list)
        else:
            fit_model = model.fit_generator(
                aug.flow(x_train, y_train, batch_size=batch_size),
                # steps_per_epoch=len(x_train) // batch_size,
                epochs=max_epochs,
                verbose=1,
                callbacks=callbacks_list)

    else:
        if use_val_data:
            fit_model = model.fit(x_train,
                                  y_train,
                                  validation_data=(x_val, y_val),
                                  epochs=max_epochs,
                                  batch_size=batch_size,
                                  verbose=1,
                                  callbacks=val_callbacks_list)
        elif not generator:
            print("Using loaded data")
            print(f"x: {np.shape(x_train)}, y: {np.shape(y_train)}")

            fit_model = model.fit(x_train,
                                  y_train,
                                  epochs=max_epochs,
                                  batch_size=batch_size,
                                  verbose=1,
                                  callbacks=callbacks_list)
        else:
            # # use generator
            print("Using data generator")
            generate_data = generate_STM_RNN_seqs(
                data_dict=data_dict,
                seq_len=timesteps,
                batch_size=batch_size,
                serial_recall=serial_recall,
                output_type=output_type,
                x_data_type=x_data_type,
                end_seq_cue=end_seq_cue,
                train_cycles=train_cycles,
                verbose=False  # verbose
            )

            fit_model = model.fit_generator(generate_data,
                                            steps_per_epoch=100,
                                            epochs=max_epochs,
                                            callbacks=callbacks_list,
                                            shuffle=False)

    ########################################################
    print("\n**** TRAINING COMPLETE ****")

    print(f"\nModel name: {checkpoint_path}")

    # # plot the training loss and accuracy
    fig, (ax1, ax2) = plt.subplots(2, sharex=True)
    ax1.plot(fit_model.history[main_metric])
    if use_val_data:
        ax1.plot(fit_model.history[f'val_{main_metric}'])
    ax1.set_title(f'{main_metric} (top); loss (bottom)')
    ax1.set_ylabel(f'{main_metric}')
    ax1.set_xlabel('epoch')
    ax2.plot(fit_model.history['loss'])
    if use_val_data:
        ax2.plot(fit_model.history['val_loss'])
    ax2.set_ylabel('loss')
    ax2.set_xlabel('epoch')
    fig.legend(['train', 'val'], loc='upper left')
    plt.savefig(str(output_filename) + '_training.png')
    plt.close()

    # # get best epoch number
    if use_val_data:
        # print(fit_model.history['val_loss'])
        trained_for = int(fit_model.history['val_loss'].index(
            min(fit_model.history['val_loss'])))
        end_val_loss = float(fit_model.history['val_loss'][trained_for])
        end_val_acc = float(fit_model.history[f'val_{main_metric}'][trained_for])
    else:
        # print(fit_model.history['loss'])
        trained_for = int(fit_model.history['loss'].index(
            min(fit_model.history['loss'])))
        end_val_loss = np.nan
        end_val_acc = np.nan

    end_loss = float(fit_model.history['loss'][trained_for])
    end_acc = float(fit_model.history[main_metric][trained_for])
    print(f'\nTraining Info\nbest loss after {trained_for} epochs\n'
          f'end loss: {end_loss}\nend acc: {end_acc}\n')
    if use_val_data:
        print(f'end val loss: {end_val_loss}\nend val acc: {end_val_acc}')

    # # # # PART 3 get_scores() # # #
    """accuracy can be two things
    1. What proportion of all sequences are entirely correct
    2. what is the average proportion of each sequence that is correct
    e.g., might get none of the sequences correct but on average get 50% of each sequence correct
    """

    # # load test label seqs
    data_path = data_dict['data_path']

    if not os.path.exists(data_path):
        if os.path.exists(switch_home_dirs(data_path)):
            data_path = switch_home_dirs(data_path)
        else:
            raise FileNotFoundError(f'data path not found: {data_path}')

    print(f'data_path: {data_path}\n')

    if train_cycles:
        timesteps = 7

    # test_filename = f'seq{timesteps}_v{n_cats}_960_test_seq_labels.npy'
    test_filename = f'seq{timesteps}_v{n_cats}_1per_ts_test_seq_labels.npy'

    test_seq_path = os.path.join(data_path, test_filename)

    # if not os.path.isfile(test_seq_path):
    #     if os.path.isfile(switch_home_dirs(test_seq_path)):
    #         test_seq_path = switch_home_dirs(test_seq_path)
    # if os.path.isfile(test_seq_path):
    test_label_seqs = np.load(test_seq_path)

    print(f'test_label_seqs: {np.shape(test_label_seqs)}\n{test_label_seqs}\n')

    # # get test accuracy with get_test_scores()
    scores_dict = get_test_scores(model=model,
                                  data_dict=data_dict,
                                  test_label_seqs=test_label_seqs,
                                  serial_recall=serial_recall,
                                  x_data_type=x_data_type,
                                  output_type=output_type,
                                  end_seq_cue=end_seq_cue,
                                  batch_size=batch_size,
                                  verbose=verbose)

    mean_IoU = scores_dict['mean_IoU']
    prop_seq_corr = scores_dict['prop_seq_corr']

    trained_date = int(datetime.datetime.now().strftime("%y%m%d"))
    trained_time = int(datetime.datetime.now().strftime("%H%M"))
    model_info['overview'] = {
        'model_type': model_dir,
        'model_name': model_name,
        "trained_model": checkpoint_path,
        "hid_layers": hid_layers,
        "units_per_layer": units_per_layer,
        'act_func': act_func,
        "serial_recall": serial_recall,
        "generator": generator,
        "x_data_type": x_data_type,
        "end_seq_cue": end_seq_cue,
        "use_val_data": use_val_data,
        "weight_init": weight_init,
        "optimizer": use_optimizer,
        'learning_rate': lr,
        "loss_func": loss_func,
        "use_batch_norm": use_batch_norm,
        "batch_size": batch_size,
        "augmentation": augmentation,
        "grey_image": grey_image,
        "use_dropout": use_dropout,
        "loss_target": loss_target,
        "min_loss_change": min_loss_change,
        "max_epochs": max_epochs,
        'timesteps': timesteps,
        'unroll': unroll,
        'y_1hot': y_1hot,
        'LENS_states': LENS_states
    }

    git_repository = '/home/nm13850/Documents/PhD/code/library'
    if os.path.isdir('/Users/nickmartin/Documents/PhD/code/library'):
        git_repository = '/Users/nickmartin/Documents/PhD/code/library'

    repo = git.Repo(git_repository)
    # repo = "git.Repo('/home/nm13850/Documents/PhD/code/library')"

    sim_dict_name = f"{output_filename}_sim_dict.txt"

    sim_dict_path = os.path.join(exp_cond_path, sim_dict_name)

    # # simulation_info_dict
    sim_dict = {
        "topic_info": {
            "output_filename": output_filename,
            "cond": cond,
            "run": run,
            "data_dict_path": data_dict_path,
            "model_path": model_path,
            "exp_cond_path": exp_cond_path,
            'exp_name': exp_name,
            'cond_name': cond_name
        },
        "data_info": data_dict,
        "model_info": model_info,
        "training_info": {
            "trained_for": trained_for,
            "loss": end_loss,
            "acc": end_acc,
            'use_val_data': use_val_data,
            "end_val_acc": end_val_acc,
            "end_val_loss": end_val_loss,
            'scores': scores_dict,
            "trained_date": trained_date,
            "trained_time": trained_time,
            'x_data_path': x_data_path,
            'y_data_path': y_data_path,
            'sim_dict_path': sim_dict_path,
            'tensorboard_path': tensorboard_path,
            'commit': repo.head.object.hexsha,
        }
    }

    if not use_val_data:
        sim_dict['training_info']['end_val_acc'] = 'NaN'
        sim_dict['training_info']['end_val_loss'] = 'NaN'

    with open(sim_dict_name, 'w') as fp:
        json.dump(sim_dict, fp, indent=4, separators=(',', ':'))
    """converts lists of units per layer [32, 64, 128] to str "32-64-128".
    Convert these strings back to lists of ints with:
    back_to_ints = [int(i) for i in str_upl.split(sep='-')]
    """
    str_upl = "-".join(
        map(str, model_info['layers']['hid_layers']['hid_totals']['UPL']))
    str_fpl = "-".join(
        map(str, model_info['layers']['hid_layers']['hid_totals']['FPL']))

    # record training info for comparisons
    training_info = [
        output_filename, cond, run, dset_name, x_size, n_cats, timesteps,
        n_items, model_dir, model_name,
        model_info['layers']['totals']['hid_layers'], str_upl,
        model_info['layers']['hid_layers']['hid_totals']['analysable'],
        x_data_type, act_func, serial_recall, weight_init, lr, use_optimizer,
        use_batch_norm, use_dropout, batch_size, augmentation, grey_image,
        use_val_data, loss_target, min_loss_change, max_epochs, trained_for,
        end_acc, end_loss, end_val_acc, end_val_loss, checkpoint_path,
        trained_date, trained_time, mean_IoU, prop_seq_corr, unroll, y_1hot,
        LENS_states
    ]

    # exp_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
    # os.chdir(exp_path)

    # # save training summary in exp folder not condition folder
    exp_path = find_path_to_dir(long_path=exp_cond_path, target_dir=exp_name)
    os.chdir(exp_path)

    print(f"save_summaries: {exp_path}")

    # check if training_info.csv exists
    if not os.path.isfile(f"{exp_name}_training_summary.csv"):

        headers = [
            "file", "cond", "run", "dataset", "x_size", "n_cats", 'timesteps',
            "n_items", "model_type", "model", 'hid_layers', "UPL",
            "analysable", "x_data_type", "act_func", "serial_recall",
            "weight_init", 'LR', "optimizer", "batch_norm", "dropout",
            "batch_size", "aug", "grey_image", "val_data", "loss_target",
            "min_loss_change", "max_epochs", "trained_for", "end_acc",
            "end_loss", "end_val_acc", "end_val_loss", "model_file", "date",
            "time", "mean_IoU", "prop_seq_corr", "unroll", "y_1hot",
            "LENS_states"
        ]

        training_overview = open(f"{exp_name}_training_summary.csv", 'w')
        mywriter = csv.writer(training_overview)
        mywriter.writerow(headers)
    else:
        training_overview = open(f"{exp_name}_training_summary.csv", 'a')
        mywriter = csv.writer(training_overview)

    mywriter.writerow(training_info)
    training_overview.close()

    if verbose:
        focussed_dict_print(sim_dict, 'sim_dict')

    print('\n\nto access tensorboard, in terminal use\n'
          f'tensorboard --logdir={tensorboard_path}'
          '\nthen click link')

    print("\ntrain_model() finished")

    return sim_dict
def ff_gha(
        sim_dict_path,
        # get_classes=("Conv2D", "Dense", "Activation"),
        gha_incorrect=True,
        use_dataset='train_set',
        save_2d_layers=True,
        save_4d_layers=False,
        exp_root='/home/nm13850/Documents/PhD/python_v2/experiments/',
        verbose=False,
        test_run=False):
    """
    gets activations from hidden units.

    1. load simulation dict (with data info) (*_load_dict.pickle)
        sim_dict can be fed in from sim script, or loaded separately
    2. load model - get structure and details
    3. run dataset through once, recording accuracy per item/class
    4. run on 2nd model to get hid acts

    :param sim_dict_path: path to the dictionary for this experiment condition
    :param get_classes: which types of layer are we interested in?
            I've changed this to just use certain layer names rather than layer classes.
    :param gha_incorrect: GHA for ALL items (True) or just correct items (False)
    :param use_dataset: GHA for train/test data
    :param save_2d_layers: get 1 value per kernel for conv/pool layers
    :param save_4d_layers: keep original shape of conv/pool layers (for other analysis maybe?)
    :param exp_root: root to save experiments
    :param verbose:
    :param test_run: Set test_run = True to just do one unit per layer

    :return: dict with hid acts per layer.  saved as dict so different shaped arrays don't matter too much
    """

    print('**** ff_gha GHA() ****')

    # # # PART 1 # # #
    # # load details from dict
    if os.path.isfile(sim_dict_path):
        print(f"sim_dict_path: {sim_dict_path}")
        sim_dict = load_dict(sim_dict_path)
        full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)

    elif os.path.isfile(os.path.join(exp_root, sim_dict_path)):
        sim_dict_path = os.path.join(exp_root, sim_dict_path)
        print(f"sim_dict_path: {sim_dict_path}")
        sim_dict = load_dict(sim_dict_path)
        full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)
    else:
        raise FileNotFoundError(sim_dict_path)

    os.chdir(full_exp_cond_path)
    print(f"set_path to full_exp_cond_path: {full_exp_cond_path}")

    focussed_dict_print(sim_dict, 'sim_dict')

    # # # load datasets
    n_items = sim_dict['data_info']["n_items"]
    n_cats = sim_dict['data_info']["n_cats"]
    hdf5_path = sim_dict['topic_info']["dataset_path"]

    x_data_path = hdf5_path
    y_data_path = '/home/nm13850/Documents/PhD/python_v2/datasets/' \
                  'objects/ILSVRC2012/imagenet_hdf5/y_df.csv'

    # # # data preprocessing
    # # # if network is cnn but data is 2d (e.g., MNIST)
    # if len(np.shape(x_data)) != 4:
    #     if sim_dict['model_info']['overview']['model_type'] == 'cnn':
    #         width, height = sim_dict['data_info']['image_dim']
    #         x_data = x_data.reshape(x_data.shape[0], width, height, 1)
    #         print(f"\nRESHAPING x_data to: {np.shape(x_data)}")

    # Output files
    output_filename = sim_dict["topic_info"]["output_filename"]
    print(f"\nOutput file: {output_filename}")

    # # # # PART 2 # # #
    print("\n**** THE MODEL ****")
    # model_name = sim_dict['model_info']['overview']['trained_model']
    loaded_model = VGG16(weights='imagenet')
    model_details = loaded_model.get_config()
    print_nested_round_floats(model_details)

    n_layers = len(model_details['layers'])
    model_dict = dict()

    # # turn off "trainable" and get useful info
    for layer in range(n_layers):
        # set to not train
        model_details['layers'][layer]['config']['trainable'] = False

        if verbose:
            print(f"Model layer {layer}: {model_details['layers'][layer]}")

        # # get useful info
        layer_dict = {
            'layer': layer,
            'name': model_details['layers'][layer]['config']['name'],
            'class': model_details['layers'][layer]['class_name']
        }

        if 'units' in model_details['layers'][layer]['config']:
            layer_dict['units'] = model_details['layers'][layer]['config'][
                'units']
        if 'activation' in model_details['layers'][layer]['config']:
            layer_dict['act_func'] = model_details['layers'][layer]['config'][
                'activation']
        if 'filters' in model_details['layers'][layer]['config']:
            layer_dict['filters'] = model_details['layers'][layer]['config'][
                'filters']
        if 'kernel_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'kernel_size'][0]
        if 'pool_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'pool_size'][0]
        if 'strides' in model_details['layers'][layer]['config']:
            layer_dict['strides'] = model_details['layers'][layer]['config'][
                'strides'][0]
        if 'rate' in model_details['layers'][layer]['config']:
            layer_dict["rate"] = model_details['layers'][layer]['config'][
                'rate']

        # # set and save layer details
        model_dict[layer] = layer_dict

    # # my model summary
    model_df = pd.DataFrame.from_dict(
        data=model_dict,
        orient='index',
        columns=[
            'layer', 'name', 'class', 'act_func', 'units', 'filters', 'size',
            'strides', 'rate'
        ],
    )

    # # just the layers specified in get_layer_list (used to select by get_classes; now uses layer names)
    get_layers_dict = sim_dict['model_info']['VGG16_GHA_layer_dict']
    get_layer_list = [get_layers_dict[key]['name'] for key in get_layers_dict]
    key_layers_df = model_df.loc[model_df['name'].isin(get_layer_list)]

    key_layers_df.reset_index(inplace=True)
    del key_layers_df['index']
    key_layers_df.index.name = 'index'
    key_layers_df = key_layers_df.drop(columns=['size', 'strides', 'rate'])

    # # add column ('n_units_filts') to say how many things need gha per layer (number of units or filters)
    # # add zeros to rows with no units or filters
    key_layers_df.loc[:, 'n_units_filts'] = key_layers_df.units.fillna(
        0) + key_layers_df.filters.fillna(0)

    print(f"\nkey_layers_df:\n{key_layers_df}")

    key_layers_df.loc[:,
                      "n_units_filts"] = key_layers_df["n_units_filts"].astype(
                          int)

    # # get the total number of units or filters in key layers of the network
    key_n_units_fils = sum(key_layers_df['n_units_filts'])

    print(f"\nkey_layers_df:\n{key_layers_df.head()}")
    print(f"key_n_units_fils: {key_n_units_fils}")
    '''I currently include the output layer; keep this in to make sure I can do class correlation.'''

    # # # set dir to save gha stuff # # #
    hid_act_items = 'all'
    if not gha_incorrect:
        hid_act_items = 'correct'

    gha_folder = f'{hid_act_items}_{use_dataset}_gha'

    if test_run:
        gha_folder = os.path.join(gha_folder, 'test')
    gha_path = os.path.join(full_exp_cond_path, gha_folder)

    if not os.path.exists(gha_path):
        os.makedirs(gha_path)
    os.chdir(gha_path)
    print(f"saving hid_acts to: {gha_path}")

    # # # PART 3 get_scores() # # #
    print("\ngetting predicted outputs with hdf_pred_scores()")

    item_correct_df, scores_dict, incorrect_items = hdf_pred_scores(
        model=loaded_model,
        output_filename=output_filename,
        test_run=test_run,
        verbose=verbose)

    if verbose:
        focussed_dict_print(scores_dict, 'Scores_dict')

    # # PART 5
    print("\n**** Get Hidden unit activations ****")
    hid_act_2d_dict = dict()  # # to use to get 2d hid acts (e.g., means from 4d layers)
    hid_act_any_d_dict = dict()  # # to use to get all hid acts (e.g., both 2d and 4d layers)

    # # loop through key layers df
    gha_key_layers = []
    for index, row in key_layers_df.iterrows():
        if test_run:
            if index > 3:
                continue

        layer_number, layer_name, layer_class = row['layer'], row['name'], row[
            'class']
        print(f"\n{layer_number}. name {layer_name} class {layer_class}")

        # if layer_class not in get_classes:  # no longer using this - skip class types not in list
        if layer_name not in get_layer_list:  # skip layers/classes not in list
            continue

        else:
            # print('getting layer')

            hid_acts_dict = hdf_gha(model=loaded_model,
                                    layer_name=layer_name,
                                    layer_class=layer_class,
                                    layer_number=index,
                                    output_filename=output_filename,
                                    gha_incorrect=gha_incorrect,
                                    test_run=test_run,
                                    verbose=verbose)

            hid_act_2d_dict[index] = hid_acts_dict

    print("\n**** saving info to summary page and dictionary ****")
    hid_act_filenames = {'2d': None, 'any_d': None}

    # # # keep these as some analysis scripts will call them for something?
    if save_2d_layers:
        dict_2d_save_name = f'{output_filename}_hid_act_2d.pickle'
        with open(dict_2d_save_name,
                  "wb") as pkl:  # 'wb' mean 'w'rite the file in 'b'inary mode
            pickle.dump(hid_act_2d_dict, pkl)
        # np.save(dict_2d_save_name, hid_act_2d_dict)
        hid_act_filenames['2d'] = dict_2d_save_name

    if save_4d_layers:
        dict_4dsave_name = f'{output_filename}_hid_act_any_d.pickle'
        with open(dict_4dsave_name,
                  "wb") as pkl:  # 'wb' mean 'w'rite the file in 'b'inary mode
            pickle.dump(hid_act_any_d_dict, pkl)
        # np.save(dict_4dsave_name, hid_act_any_d_dict)
        hid_act_filenames['any_d'] = dict_4dsave_name

    cond = sim_dict["topic_info"]["cond"]
    run = sim_dict["topic_info"]["run"]
    if test_run:
        run = 'test'

    hid_units = sim_dict['model_info']['layers']['hid_layers']['hid_totals'][
        'analysable']

    trained_for = sim_dict["training_info"]["trained_for"]
    end_accuracy = sim_dict["training_info"]["acc"]
    dataset = sim_dict["data_info"]["dataset"]
    gha_date = int(datetime.datetime.now().strftime("%y%m%d"))
    gha_time = int(datetime.datetime.now().strftime("%H%M"))

    gha_acc = scores_dict['gha_acc']
    n_cats_correct = scores_dict['n_cats_correct']

    # # GHA_info_dict
    gha_dict_name = f"{output_filename}_GHA_dict.pickle"
    gha_dict_path = os.path.join(gha_path, gha_dict_name)

    gha_dict = {
        "topic_info": sim_dict['topic_info'],
        "data_info": sim_dict['data_info'],
        "model_info": sim_dict['model_info'],
        "training_info": sim_dict['training_info'],
        "GHA_info": {
            "use_dataset": use_dataset,
            'x_data_path': x_data_path,
            'y_data_path': y_data_path,
            'gha_path': gha_path,
            'gha_dict_path': gha_dict_path,
            "gha_incorrect": gha_incorrect,
            "hid_act_files": hid_act_filenames,
            'gha_key_layers': gha_key_layers,
            'key_n_units_fils': key_n_units_fils,
            "gha_date": gha_date,
            "gha_time": gha_time,
            "scores_dict": scores_dict,
            "model_dict": model_dict
        }
    }

    with open(gha_dict_name, "wb") as pickle_out:
        pickle.dump(gha_dict, pickle_out)

    if verbose:
        # focussed_dict_print(gha_dict, 'gha_dict', ['GHA_info', "scores_dict"])
        focussed_dict_print(gha_dict, 'gha_dict', ['GHA_info'])

    # make a list of dict names to do sel on
    if not os.path.isfile(f"{output_filename}_dict_list_for_sel.csv"):
        dict_list = open(f"{output_filename}_dict_list_for_sel.csv", 'w')
        mywriter = csv.writer(dict_list)
    else:
        dict_list = open(f"{output_filename}_dict_list_for_sel.csv", 'a')
        mywriter = csv.writer(dict_list)

    mywriter.writerow([gha_dict_name[:-7]])
    dict_list.close()

    print(f"\nadded to list for selectivity analysis: {gha_dict_name[:-7]}")

    gha_info = [
        cond, run, output_filename, n_layers, hid_units, dataset, use_dataset,
        gha_incorrect, n_cats, trained_for, end_accuracy, gha_acc,
        n_cats_correct, test_run, gha_date, gha_time
    ]

    # # check if gha_summary.csv exists
    # # save summary file in exp folder (grandparent dir to gha folder: exp/cond/gha)
    # to move up to parent just use '..' rather than '../..'

    # exp_name = exp_dir.strip('/')
    exp_name = sim_dict['topic_info']['exp_name']

    os.chdir('../..')
    exp_path = os.getcwd()

    if not os.path.isfile(exp_name + "_GHA_summary.csv"):
        gha_summary = open(exp_name + "_GHA_summary.csv", 'w')
        mywriter = csv.writer(gha_summary)
        summary_headers = [
            "cond", "run", 'filename', "n_layers", "hid_units", "dataset",
            "GHA_on", 'incorrect', "n_cats", "trained_for", "train_acc",
            "gha_acc", 'n_cats_correct',
            "test_run", "gha_date", "gha_time"
        ]

        mywriter.writerow(summary_headers)
        print(f"creating summary csv at: {exp_path}")

    else:
        gha_summary = open(exp_name + "_GHA_summary.csv", 'a')
        mywriter = csv.writer(gha_summary)
        print(f"appending to summary csv at: {exp_path}")

    mywriter.writerow(gha_info)
    gha_summary.close()

    print("\nend of ff_gha")

    return gha_info, gha_dict
def hdf_pred_scores(
    model,
    output_filename,
    data_hdf_path='/home/nm13850/Documents/PhD/python_v2/datasets/'
    'objects/ILSVRC2012/imagenet_hdf5/imageNet2012Val.h5',
    total_items=50000,
    batch_size=16,
    x_path='x_data',
    y_df_path='y_df',
    use_vgg_colours=True,
    df_name='item_correct_df',
    test_run=False,
    verbose=False,
):
    """
    Script to get predictions from slices of X data on a model.
    For each slice, also record which items were correct.

    save pred_out, item correct etc into new hdf file as I go.

    keep a counter of n_items and n_correct to give accuracy,
    or just get acc on the whole thing at the end.

    :param model:
    :param output_filename:
    :param data_hdf_path: path to hdf5 file
    :param total_items: all items in dataset
    :param batch_size:
    :param x_path: on hdf5 file
    :param y_df_path: on hdf5 file (note if made with Pandas, it might also need ['table'])
    :param test_run: If True, don't run whole dataset, just first 64 items
    :param use_vgg_colours: preprocess RGB to BGR

    :param verbose: If True, print details to screen

    :return:
    """

    if test_run:
        total_items = 64

    batchsize = batch_size
    batches_in_data = total_items // batchsize
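    # e.g., with the defaults (total_items=50000, batch_size=16) this gives 3125
    # slices; a test_run caps it at 64 items, i.e. 4 slices of 16.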

    # # list of all incorrect items, built up slice-by-slice
    incorrect_items = []

    for i in range(batches_in_data):
        # # step through the data in slices/batches

        with h5py.File(data_hdf_path, 'r') as dataset:
            # # open the hdf with the x and y data

            # # get indices to slice to/from
            idx_from = i * batchsize
            idx_to = (i + 1) * batchsize
            print(f"\n{i}: from {idx_from} to: {idx_to}")

            # # slice x_data
            x_data = dataset[x_path][idx_from:idx_to, ...]

            # # preprocess colours from RGB to BGR
            if use_vgg_colours:
                x_data = preprocess_input(x_data)

            # # slice y data
            y_df_tuples = dataset[y_df_path]['table'][idx_from:idx_to, ...]
            # convert list of tuples to list of lists
            y_df_lists = [list(elem) for elem in y_df_tuples]

            y_df = pd.DataFrame(
                y_df_lists, columns=['item', 'cat', 'filename', 'class_name'])
            y_df = y_df.set_index('item')

            if verbose:
                print(f"x_data: {x_data.shape}")
                print(f"y_df: {y_df.shape}")
                print(f"y_df: {y_df}")

            # yield x_data, y_df

            # # get the true cat labels for this slice
            true_cat = [int(i) for i in y_df['cat'].to_numpy()]

            # # get predictions (per cat) and then pred_labels
            pred_vals = model.predict(x_data)
            pred_cat = np.argmax(pred_vals, axis=1)

            # # # get item correct and scores (per cat and total)
            n_items, n_cats = np.shape(pred_vals)

            slice_incorrect_items = [
                x for x, y in zip(pred_cat, true_cat) if x != y
            ]
            incorrect_items.extend(slice_incorrect_items)
            item_score = [
                1 if x == y else 0 for x, y in zip(pred_cat, true_cat)
            ]

            # # append item correct to new hdf file
            item_correct_df = y_df  # .copy()

            # # convert troublesome columns to string
            item_correct_df["filename"] = item_correct_df["filename"].map(str)
            item_correct_df["class_name"] = item_correct_df["class_name"].map(
                str)

            # # add item_correct column ['full_model'] to item_correct_df
            item_correct_df.insert(2, column="full_model", value=item_score)

            if verbose:
                print("item_correct_df.shape: {}".format(
                    item_correct_df.shape))
                print("len(item_score): {}".format(len(item_score)))
                print(item_correct_df.dtypes)
                # print(item_correct_df.head())

            # # make output hdf to store item_correct df
            with pd.HDFStore(f"{output_filename}_gha.h5") as store:

                line_len_dict = {
                    # 'item': 0,
                    'cat': 0,
                    'full_model': 0,
                    'filename': 35,
                    'class_name': 35
                }

                print(store.keys())
                if f"/{df_name}" not in store.keys():
                    print(f"creating blank df in {df_name} on store")

                    store.put(f'{df_name}',
                              pd.DataFrame(data=None,
                                           columns=[
                                               'item', 'cat', 'full_model',
                                               'filename', 'class_name'
                                           ]),
                              format='t',
                              append=True,
                              min_itemsize=line_len_dict)

                # # I'm having problems with line length
                # # trying to work out why
                line_len_check_dict = {}
                for c in item_correct_df:
                    if item_correct_df[c].dtype == 'object':
                        max_len = item_correct_df[c].map(len).max()
                        print(f'Max length of column {c}: {max_len}')
                        line_len_check_dict[c] = max_len
                    else:
                        max_len = 0
                        print(f'Not a string column {c}: {max_len}')
                        line_len_check_dict[c] = max_len

                line_lengths = list(line_len_check_dict.values())
                max_line = max(line_lengths)

                if max_line > 30:
                    focussed_dict_print(line_len_check_dict,
                                        'line_len_check_dict')

                store.append(f'/{df_name}',
                             item_correct_df,
                             min_itemsize=line_len_dict)

                if verbose:
                    print(
                        f"store['item_correct_df'].shape: {store[f'/{df_name}'].shape}"
                    )

    print("\nfinished looping through dataset")
    incorrect_items_s = pd.Series(incorrect_items)

    # add incorrect items to output hdf
    incorrect_items_name = 'incorrect_items'
    if df_name != 'item_correct_df':
        incorrect_items_name = f'incorrect_{df_name}'

    with pd.HDFStore(f"{output_filename}_gha.h5") as store:

        store.put(incorrect_items_name, incorrect_items_s)

        print(f"store.keys(): {store.keys()}")

        # if df_name in store.keys():
        #     pass
        # elif f"/{df_name}" in store.keys():
        #     df_name = f'/{df_name}'

        item_correct_df = store[f'/{df_name}']

    print(f"item_correct_df.shape: {item_correct_df.shape}")
    # print(item_correct_df.head())

    full_model = [int(i) for i in item_correct_df['full_model'].to_numpy()]
    fm_correct = np.sum(full_model)
    fm_items = len(full_model)

    gha_acc = np.around(fm_correct / fm_items, decimals=3)

    print("\nitems: {}\ncorrect: {}\nincorrect: {}\naccuracy: {}".format(
        fm_items, fm_correct, fm_items - fm_correct, gha_acc))

    # # get count_correct_per_class
    corr_per_cat_dict = dict()
    for cat in range(n_cats):
        corr_per_cat_dict[cat] = len(
            item_correct_df[(item_correct_df['cat'] == cat)
                            & (item_correct_df['full_model'] == 1)])

    # # # are any categories missing?
    category_fail = sum(value == 0 for value in corr_per_cat_dict.values())
    category_low = sum(value < 3 for value in corr_per_cat_dict.values())
    n_cats_correct = n_cats - category_fail

    scores_dict = {
        "n_items": fm_items,
        "n_correct": fm_correct,
        "gha_acc": gha_acc,
        "category_fail": category_fail,
        "category_low": category_low,
        "n_cats_correct": n_cats_correct,
        "corr_per_cat_dict": corr_per_cat_dict,
        # "item_correct_name": item_correct_name,
        # "flat_conf_name": flat_conf_name,
        "scores_date": tools_date,
        'scores_time': tools_time
    }

    return item_correct_df, scores_dict, incorrect_items
def lesion_study(gha_dict_path,
                 get_classes=("Conv2D", "Dense", "Activation"),
                 verbose=False,
                 test_run=False):
    """
        lesion study
    1. load dict from study (should run with sim, GHA or sel dict)
    2. from dict get x, y, num of items, IPC etc
    3. load original model and weights (model from, num hid units, num hid outputs)
    4. run get scores on ALL items - record total acc, class acc, item success
    5. loop through:
        lesion unit (inputs, bias, outputs)
        test on ALL items - record total acc, class acc, item success (pass/fail)
    6. output:
        overall acc change per unit
        class acc change per unit
        item success per unit (pass, l_fail, l_pass, fail) l_fail if pass on full network, fail when lesioned

    :param gha_dict_path: path to GHA dict - ideally should work for be GHA, sel or sim
    :param get_classes: which types of layer are we interested in?
    :param verbose: will print less if False, otherwise will print everything
    :param test_run: just run a few units for a test, print lots of output

    :return: lesion_dict:   lesion_path: path to dir where everything is saved,
                            loaded_dict: name of lesion dict, 
                            x_data_path, y_data_path: paths to data used, 
                            key_layer_classes: classes of layers lesioned,
                            key_lesion_layers_list: layer names of lesioned layers (excludes output)
                            total_units_filts: total number of lesionable units
                            lesion_highlights: dict with biggest total and class increase and decrease per layer and 
                                                for the whole model
                            lesion_means_dict: 'mean total change' and 'mean max class drop' per layer

    """

    print('\n**** lesion_21052019 lesion_study() ****')

    # # # chdir to this folder
    full_exp_cond_gha_path, gha_dict_name = os.path.split(gha_dict_path)
    training_dir, _ = os.path.split(full_exp_cond_gha_path)
    if not os.path.exists(full_exp_cond_gha_path):
        raise FileNotFoundError(f"path for this experiment not found: {full_exp_cond_gha_path}")
    os.chdir(full_exp_cond_gha_path)
    print(f"set_path to full_exp_cond_gha_path: {full_exp_cond_gha_path}")

    # # # PART 1 # # #
    # # load details from dict
    if type(gha_dict_path) is str:
        gha_dict = load_dict(gha_dict_path)
    focussed_dict_print(gha_dict, 'gha_dict')

    # # # load datasets
    # use_dataset = gha_dict['GHA_info']['use_dataset']

    # # check for training data
    n_items = gha_dict["data_info"]["n_items"]
    items_per_cat = gha_dict["data_info"]["items_per_cat"]
    x_data_path = gha_dict["data_info"]["dataset_path"]
    y_data_path = gha_dict["data_info"]["dataset_path"]

    # if use_dataset in gha_dict['data_info']:
    #     x_data_path = os.path.join(gha_dict['data_info']['data_path'], gha_dict['data_info'][use_dataset]['X_data'])
    #     y_data_path = os.path.join(gha_dict['data_info']['data_path'], gha_dict['data_info'][use_dataset]['Y_labels'])
    #     n_items = gha_dict["data_info"][use_dataset]["n_items"]
    #     items_per_cat = gha_dict["data_info"][use_dataset]["items_per_cat"]
    # else:
    #     x_data_path = os.path.join(gha_dict['data_info']['data_path'], gha_dict['data_info']['X_data'])
    #     y_data_path = os.path.join(gha_dict['data_info']['data_path'], gha_dict['data_info']['Y_labels'])
    #     n_items = gha_dict["data_info"]["n_items"]
    #     items_per_cat = gha_dict["data_info"]["items_per_cat"]

    n_cats = gha_dict['data_info']["n_cats"]
    if type(items_per_cat) is int:
        items_per_cat = dict(zip(list(range(n_cats)),
                                 [items_per_cat] * n_cats))

    if gha_dict['GHA_info']['gha_incorrect'] == 'False':
        # # only gha for correct items
        n_items = gha_dict['GHA_info']['scores_dict']['n_correct']
        items_per_cat = gha_dict['GHA_info']['scores_dict'][
            'corr_per_cat_dict']

    # x_data = load_x_data(x_data_path)
    # y_df, y_label_list = load_y_data(y_data_path)
    #
    # if verbose is True:
    #     print(f"y_df: {y_df.shape}\n{y_df.head()}\n"
    #           f"y_df dtypes: {y_df.dtypes}\n"
    #           f"y_label_list:\n{y_label_list[:10]}")
    #
    # # # data preprocessing
    # # # if network is cnn but data is 2d (e.g., MNIST)
    # if len(np.shape(x_data)) != 4:
    #     if gha_dict['model_info']['overview']['model_type'] == 'cnn':
    #         width, height = gha_dict['data_info']['image_dim']
    #         x_data = x_data.reshape(x_data.shape[0], width, height, 1)
    #         print(f"\nRESHAPING x_data to: {np.shape(x_data)}")

    output_filename = gha_dict["topic_info"]["output_filename"]
    print(f"\nOutput file: {output_filename}")

    # # set up dicts to save stuff
    # # # set dir to save lesion stuff stuff # # #
    lesion_path = os.path.join(os.getcwd(), 'lesion')
    if test_run is True:
        lesion_path = os.path.join(lesion_path, 'test')
    if not os.path.exists(lesion_path):
        os.makedirs(lesion_path)
    os.chdir(lesion_path)
    print(f"\nsaving lesion data to: {lesion_path}")

    # count_per_cat_dict is for storing n_items_correct for the lesion study
    count_p_cat_dict_name = f"{lesion_path}/{output_filename}_count_p_cat_dict.pickle"
    if not os.path.isfile(count_p_cat_dict_name):
        count_per_cat_dict = dict()
        with open(count_p_cat_dict_name, "wb") as pickle_out:
            pickle.dump(count_per_cat_dict, pickle_out)

    # prop_change dict - to compare with Zhou_2018
    prop_change_dict_name = f"{lesion_path}/{output_filename}_prop_change_dict.pickle"
    if not os.path.isfile(prop_change_dict_name):
        prop_change_dict = dict()
        with open(prop_change_dict_name, "wb") as pickle_out:
            pickle.dump(prop_change_dict, pickle_out)

    # #  item change dict, stores per item whether it changed (correct/incorrect) when unit lesioned
    item_change_dict_name = f"{lesion_path}/{output_filename}_item_change_dict.pickle"
    if not os.path.isfile(item_change_dict_name):
        item_change_dict = dict()
        with open(item_change_dict_name, "wb") as pickle_out:
            pickle.dump(item_change_dict, pickle_out)

    # # lesioning highlights dict has biggest total and per class change
    highlights_dict_name = f"{lesion_path}/{output_filename}_les_highlights.pickle"
    if not os.path.isfile(highlights_dict_name):
        lesion_highlights_dict = dict()
        lesion_highlights_dict["highlights"] = {
            "total_increase": ("None", 0),
            "total_decrease": ("None", 0),
            "class_increase": ("None", 0),
            "class_decrease": ("None", 0)
        }
        with open(highlights_dict_name, "wb") as pickle_out:
            pickle.dump(lesion_highlights_dict, pickle_out)

    # # mean scores per (layer?)
    les_means_dict_name = f"{lesion_path}/{output_filename}_les_means_dict.pickle"
    if not os.path.isfile(les_means_dict_name):
        lesion_means_dict = dict()
        with open(les_means_dict_name, "wb") as pickle_out:
            pickle.dump(lesion_means_dict, pickle_out)

    # # # # PART 2 # # #
    print("\n**** load original trained MODEL ****")
    model_architecture_name = gha_dict['model_info']['overview']['model_name']
    trained_model_name = gha_dict['model_info']['overview']['trained_model']

    optimizer = gha_dict['model_info']['overview']['optimizer']

    print(f"model_architecture_name: {model_architecture_name}")
    if model_architecture_name == 'VGG16':
        original_model = VGG16(weights='imagenet')
        # x_data = preprocess_input(x_data)  # preprocess the inputs loaded as RGB to BGR
    else:
        model_path = os.path.join(training_dir, trained_model_name)
        original_model = load_model(model_path)

    if verbose is True:
        print(f"original_model.summary: {original_model.summary()}")

    model_details = original_model.get_config()
    print_nested_round_floats(model_details, 'model_details')

    n_layers = len(model_details['layers'])
    model_dict = dict()

    weights_layer_counter = 0

    # # turn off "trainable" and get useful info
    for layer in range(n_layers):
        # set to not train
        model_details['layers'][layer]['config']['trainable'] = False

        if verbose is True:
            print(f"Model layer {layer}: {model_details['layers'][layer]}")

        # # get useful info
        layer_dict = {
            'layer': layer,
            'name': model_details['layers'][layer]['config']['name'],
            'class': model_details['layers'][layer]['class_name']
        }

        if 'units' in model_details['layers'][layer]['config']:
            layer_dict['units'] = int(
                model_details['layers'][layer]['config']['units'])
        if 'activation' in model_details['layers'][layer]['config']:
            layer_dict['act_func'] = model_details['layers'][layer]['config'][
                'activation']
        if 'filters' in model_details['layers'][layer]['config']:
            layer_dict['filters'] = int(
                model_details['layers'][layer]['config']['filters'])
        if 'kernel_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'kernel_size'][0]
        if 'pool_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'pool_size'][0]
        if 'strides' in model_details['layers'][layer]['config']:
            layer_dict['strides'] = model_details['layers'][layer]['config'][
                'strides'][0]
        if 'rate' in model_details['layers'][layer]['config']:
            layer_dict["rate"] = model_details['layers'][layer]['config'][
                'rate']

        # # record which layers of the weights matrix apply to this layer
        if layer_dict['class'] in ["Conv2D", 'Dense']:
            layer_dict["weights_layer"] = [
                weights_layer_counter, weights_layer_counter + 1
            ]
            weights_layer_counter += 2  # weights and biases
        elif layer_dict['class'] == 'BatchNormalization':  # BN weights: [gamma, beta, mean, std]
            layer_dict["weights_layer"] = [
                weights_layer_counter, weights_layer_counter + 1,
                weights_layer_counter + 2, weights_layer_counter + 3
            ]
            weights_layer_counter += 4
        elif layer_dict['class'] in [
                "Dropout", 'Activation', 'MaxPooling2D', "Flatten"
        ]:
            layer_dict["weights_layer"] = []

        # # set and save layer details
        model_dict[layer] = layer_dict

    # # my model summary
    model_df = pd.DataFrame.from_dict(data=model_dict,
                                      orient='index',
                                      columns=[
                                          'layer', 'name', 'class', 'act_func',
                                          'units', 'filters', 'size',
                                          'strides', 'rate', 'weights_layer'
                                      ])

    key_layers_df = model_df.loc[model_df['class'].isin(get_classes)]
    key_layers_df = key_layers_df.drop(columns=['size', 'strides', 'rate'])

    # # just classes of layers specified in get_classes
    if 'VGG16_GHA_layer_dict' in gha_dict['model_info']:
        get_layers_dict = gha_dict['model_info']['VGG16_GHA_layer_dict']
        print(f'get_layers_dict: {get_layers_dict}')
        get_layer_names = []
        for k, v in get_layers_dict.items():
            if v['name'] != 'predictions':
                get_layer_names.append(v['name'])
        key_layers_df = model_df.loc[model_df['name'].isin(get_layer_names)]

    # # add column ('n_units_filts') to say how many things need lesioning per layer (number of units or filters)
    # # add zeros to rows with no units or filters
    key_layers_df.loc[:, 'n_units_filts'] = key_layers_df.units.fillna(
        0) + key_layers_df.filters.fillna(0)

    key_lesion_layers_list = key_layers_df['name'].to_list()

    # # remove output layers from key layers list
    if any("utput" in s for s in key_lesion_layers_list):
        output_layers = [s for s in key_lesion_layers_list if "utput" in s]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_lesion_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_lesion_layers_list = key_lesion_layers_list[:min_out_idx]
        key_layers_df = key_layers_df.loc[~key_layers_df['name'].
                                          isin(output_layers)]

    if any("predictions" in s for s in key_lesion_layers_list):
        output_layers = [
            s for s in key_lesion_layers_list if "predictions" in s
        ]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_lesion_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_lesion_layers_list = key_lesion_layers_list[:min_out_idx]
        key_layers_df = key_layers_df.loc[~key_layers_df['name'].
                                          isin(output_layers)]

    key_layers_df.reset_index(inplace=True)

    total_units_filts = key_layers_df['n_units_filts'].sum()

    if verbose is True:
        print(
            f"\nmodel_df:\n{model_df}\n"
            f"{len(key_lesion_layers_list)} key_lesion_layers_list: {key_lesion_layers_list}"
        )

    print(f"\nkey_layers_df:\n{key_layers_df}")

    # get original values for weights
    print("\n**** load trained weights ****")
    full_weights = original_model.get_weights()
    n_weight_arrays = len(full_weights)
    print(f"n_weight_arrays: {n_weight_arrays}")

    if test_run is True:
        if verbose is True:
            print(f"full_weights: {np.shape(full_weights)}")
            # for w_array in range(n_weight_arrays):
            #     print(f"\nfull_weights{w_array}: {np.shape(full_weights[w_array])}\n"
            #           f"{full_weights[w_array]}")

    original_model.compile(loss="categorical_crossentropy",
                           optimizer=optimizer,
                           metrics=['accuracy'])
    print(
        f"\nLoaded '{model_architecture_name}' model with original weights: {trained_model_name}"
    )
    """# # save this in case I come back ina few months and need a refersher.
    # no all network layers have associated weights (only learnable ones)
    # some layers have multiple layers of weights associated.
    # conv2d layers have 2 sets of arrays [connection weights, biases] 
    # dense layers have 2 sets of arrays [connection weights, biases]
    # batch_norm has 4 sets of weights [gamma, beta, running_mean, running_std]
    # print("Figuring out layers")
    # print("weights shape {}".format(trained_model_name))
    # for index, layer in enumerate(original_model.layers):
    #     g = layer.get_config()
    #     h = layer.get_weights()
    #     if h:
    #         q = 'has shape'
    #         s = len(h)
    #     if not h:
    #         q = 'not h'
    #         s = (0, )
    #     print("\n{}. {}  {}\n{}\n".format(index, s, g, h))"""

    # # 4. run get scores on ALL items - record total acc, class acc, item success

    # # count_per_cat_dict is for storing n_items_correct for the lesion study
    with open(count_p_cat_dict_name, "rb") as pickle_load:
        count_per_cat_dict = pickle.load(pickle_load)

        count_per_cat_dict['dataset'] = items_per_cat
        count_per_cat_dict['dataset']['total'] = n_items

        with open(count_p_cat_dict_name, "wb") as pickle_out:
            pickle.dump(count_per_cat_dict, pickle_out)

    print("\n**** Get original model scores ****")
    # predicted_outputs = original_model.predict(x_data)
    #
    # if model_architecture_name == 'VGG16':
    #     item_correct_df, scores_dict, incorrect_items = VGG_get_scores(predicted_outputs, y_df, output_filename,
    #                                                                    save_all_csvs=True)
    # else:
    #     item_correct_df, scores_dict, incorrect_items = get_scores(predicted_outputs, y_df, output_filename,
    #                                                                save_all_csvs=False, return_flat_conf=True)

    item_correct_df, scores_dict, incorrect_items = hdf_pred_scores(
        model=original_model,
        output_filename=output_filename,
        df_name='full_model',
        test_run=test_run,
        verbose=verbose)

    if verbose is True:
        focussed_dict_print(scores_dict, 'Scores_dict')

    # # # get scores per class for full model
    full_model_CPC = scores_dict['corr_per_cat_dict']
    full_model_CPC['total'] = scores_dict['n_correct']

    with open(count_p_cat_dict_name, "rb") as pickle_load:
        count_per_cat_dict = pickle.load(pickle_load)
        count_per_cat_dict['full_model'] = full_model_CPC
        with open(count_p_cat_dict_name, "wb") as pickle_out:
            pickle.dump(count_per_cat_dict, pickle_out)

    # # use these (unlesioned) pages as the basis for LAYER pages
    # flat_conf_MASTER = scores_dict.loc[:, 'flat_conf']
    if model_architecture_name != 'VGG16':
        flat_conf_MASTER = scores_dict['flat_conf']

    # # save item_correct page.
    item_correct_MASTER = copy.copy(item_correct_df)

    # # # PART 5 # # #
    # # loop through key layers df
    # #     lesion unit (inputs, bias, outputs)
    # #     test on ALL items - record total acc, class acc, item success (pass/fail)
    # print("\n'BEFORE - full_weights'{} {}\n{}\n\n".format(np.shape(full_weights), type(full_weights), full_weights))

    print("\n\n\n**** loop through key layers df ****")
    for index, row in key_layers_df.iterrows():

        if test_run is True:
            if index > 3:
                print(
                    f"\tskip this layer!: test_run, only running subset of layers"
                )
                continue

        layer_number, layer_name, layer_class, n_units_filts = \
            row['layer'], row['name'], row['class'], row['n_units_filts']
        print(
            f"\n{layer_number}. name {layer_name}, class {layer_class}, n_units_filts {n_units_filts}"
        )

        if layer_class not in get_classes:  # no longer using this - skip class types not in list
            # if layer_name not in get_layer_list:  # skip layers/classes not in list
            print(f"\tskip this layer!: {layer_class} not in {get_classes}")
            continue

        # # load places to save layer details
        with open(count_p_cat_dict_name, "rb") as pickle_load:
            count_per_cat_dict = pickle.load(pickle_load)
            count_per_cat_dict[layer_name] = dict()
            with open(count_p_cat_dict_name, "wb") as pickle_out:
                pickle.dump(count_per_cat_dict, pickle_out)

        with open(prop_change_dict_name, "rb") as pickle_load:
            # read dict as it is so far
            prop_change_dict = pickle.load(pickle_load)
            prop_change_dict[layer_name] = dict()
            with open(prop_change_dict_name, "wb") as pickle_out:
                pickle.dump(prop_change_dict, pickle_out)

        with open(les_means_dict_name, "rb") as pickle_load:
            lesion_means_dict = pickle.load(pickle_load)
            lesion_means_dict[layer_name] = dict()
            with open(les_means_dict_name, "wb") as pickle_out:
                pickle.dump(lesion_means_dict, pickle_out)

        item_correct_LAYER = copy.copy(item_correct_MASTER)

        layer_total_change_list = []
        layer_max_drop_list = []

        if model_architecture_name != 'VGG16':
            flat_conf_LAYER = copy.copy(flat_conf_MASTER)

        weights_n_biases = row['weights_layer']
        print(f"weights_n_biases: {weights_n_biases}")

        if not weights_n_biases:  # if empty list []
            print("skip this")
            continue

        weights_layer = weights_n_biases[0]
        biases_layer = weights_n_biases[1]

        for unit in range(int(n_units_filts)):

            if test_run is True:
                if unit > 3:
                    continue

            layer_and_unit = f"{layer_name}_{unit}"
            print(
                f"\n\n**** lesioning layer {layer_number}. ({layer_class}) {layer_and_unit} of {int(n_units_filts)}****"
            )
            # # load original weights each time
            edit_full_weights = copy.deepcopy(full_weights)

            # print("type(edit_full_weights): {}".format(type(edit_full_weights)))
            # print("np.shape(edit_full_weights): {}".format(np.shape(edit_full_weights)))
            # print("np.shape(edit_full_weights[0]): {}".format(np.shape(edit_full_weights[0])))
            # print("888888888888\n\t88888888888888888\n\t\t8888888888\n\t\t\t\t888888")

            # # change input to hid unit
            if layer_class == 'Conv2D':
                edit_full_weights[weights_layer][:, :, :, unit] = 0.0
                '''eg: for 20 conv filters of shape 5 x 5, layer shape (5, 5, 1, 20)'''
            else:
                edit_full_weights[weights_layer][:, unit] = 0.0
            # # zero the unit's bias (second entry of weights_n_biases)
            edit_full_weights[biases_layer][unit] = 0.0

            # if test_run is True:
            #     print(f"\n'AFTER l{layer}h{unit}")
            #     for array in range(n_weight_arrays):
            #         print(edit_full_weights[array])

            original_model.set_weights(edit_full_weights)
            original_model.compile(loss="categorical_crossentropy",
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

            # # # get scores
            # predicted_outputs = original_model.predict(x_data)
            #
            # if model_architecture_name == 'VGG16':
            #     item_correct_df, scores_dict, incorrect_items = VGG_get_scores(predicted_outputs, y_df, output_filename,
            #                                                                    save_all_csvs=True)
            # else:
            #     item_correct_df, scores_dict, incorrect_items = get_scores(predicted_outputs, y_df, output_filename,
            #                                                                save_all_csvs=False, return_flat_conf=True)

            item_correct_df, scores_dict, incorrect_items = hdf_pred_scores(
                model=original_model,
                output_filename=output_filename,
                df_name=layer_and_unit,
                test_run=test_run,
                verbose=verbose)
            if verbose is True:
                focussed_dict_print(scores_dict, 'scores_dict')
                print(item_correct_df.head())

            # # # get scores per class for this layer
            corr_per_cat_dict = scores_dict['corr_per_cat_dict']
            corr_per_cat_dict['total'] = scores_dict['n_correct']

            with open(count_p_cat_dict_name, "rb") as pickle_load:
                count_per_cat_dict = pickle.load(pickle_load)
                count_per_cat_dict[layer_name][unit] = corr_per_cat_dict
                with open(count_p_cat_dict_name, "wb") as pickle_out:
                    pickle.dump(count_per_cat_dict, pickle_out)

            item_correct_LAYER[layer_and_unit] = item_correct_df['full_model']
            # item_correct_LAYER.to_csv(f"{output_filename}_{layer_name}_item_correct.csv", index=False)
            nick_to_csv(item_correct_LAYER,
                        f"{output_filename}_{layer_name}_item_correct.csv")

            if model_architecture_name != 'VGG16':
                flat_conf_LAYER[layer_and_unit] = scores_dict['flat_conf'][
                    'full_model']
                # flat_conf_LAYER.to_csv("{}_{}_flat_conf.csv".format(output_filename, layer_name))
                nick_to_csv(flat_conf_LAYER,
                            f"{output_filename}_{layer_name}_flat_conf.csv")

            # # make item change per layer df
            """# # four possible states
                full model      after_lesion    code
            1.  1 (correct)     0 (wrong)       -1
            2.  0 (wrong)       0 (wrong)       0
            3.  1 (correct)     1 (correct)     1
            4.  0 (wrong)       1 (correct)     2

            """
            # # get column names/unit ids
            all_column_list = list(item_correct_LAYER)
            column_list = all_column_list[3:]
            # print(len(column_list), column_list)

            # # get columns to be used in new df
            full_model = item_correct_LAYER['full_model'].to_list()
            item_id = item_correct_LAYER.index.to_list()
            classes = item_correct_LAYER['cat'].to_list()
            # print("full model: {}\n{}".format(len(full_model), full_model))

            # # set up dict to make new df
            with open(item_change_dict_name, "rb") as pickle_load:
                # read dict as it is so far
                item_change_dict = pickle.load(pickle_load)
            item_change_dict[layer_name] = dict()
            item_change_dict[layer_name]['item'] = item_id
            item_change_dict[layer_name]['cat'] = classes
            item_change_dict[layer_name]['full_model'] = full_model

            # # loop through old item_correct_LAYER
            for idx in column_list:
                item_change_list = []
                for item in item_correct_LAYER[idx].items():  # (index, value) pairs; .iteritems() was removed in pandas 2
                    index, lesnd_acc = item
                    unlesnd_acc = full_model[index]

                    # # four possible states
                    if unlesnd_acc == 1 and lesnd_acc == 0:  # lesioning causes failure
                        item_change = -1
                    elif unlesnd_acc == 0 and lesnd_acc == 0:  # no effect of lesioning, still incorrect
                        item_change = 0
                    elif unlesnd_acc == 1 and lesnd_acc == 1:  # no effect of lesioning, still correct
                        item_change = 1
                    elif unlesnd_acc == 0 and lesnd_acc == 1:  # lesioning causes failed item to pass
                        item_change = 2
                    else:
                        item_change = 'ERROR'
                    # print("{}. unlesnd_acc: {} lesnd_acc: {} item_change: {}".format(index, unlesnd_acc,
                    #                                                                  lesnd_acc, item_change))
                    item_change_list.append(item_change)
                item_change_dict[layer_name][idx] = item_change_list

            with open(item_change_dict_name, "wb") as pickle_out:
                pickle.dump(item_change_dict, pickle_out)

            # # # get class_change scores for this layer
            print("\tget class_change scores:")
            # proportion change = (after_lesion/unlesioned) - 1
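            # # e.g., 60 correct after lesioning vs 80 in the full model: (60 / 80) - 1 = -0.25, a 25% drop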
            unit_prop_change_dict = dict()
            for (fk, fv), (k2, v2) in zip(full_model_CPC.items(),
                                          corr_per_cat_dict.items()):
                if fv == 0:
                    prop_change = -1
                else:
                    prop_change = (v2 / fv) - 1
                unit_prop_change_dict[fk] = prop_change
                # print(fk, 'v2: ', v2, '/ fv: ', fv, '= pc: ', prop_change)

            with open(prop_change_dict_name, "rb") as pickle_load:
                prop_change_dict = pickle.load(pickle_load)

                prop_change_dict[layer_name][unit] = unit_prop_change_dict

                with open(prop_change_dict_name, "wb") as pickle_out:
                    pickle.dump(prop_change_dict, pickle_out)

            # # tuning
            """Rather than correlating max_class_drop with selectivity,
            #  I need a measure of how selectively impaired the unit was after lesioning.
            #  tuning = max_class_drop / sum(all_classes_that_dropped)"""
            # focussed_dict_print('unit_prop_change_dict', unit_prop_change_dict)

            # todo: Rather than correlating max_class_drop with selectivity,
            #  I need a measure of how selectively impaired the unit was after lesioning.
            #  One option is to use max_class_drop / sum(all_classes_that_dropped)

            # todo:  I need something about how often the most selective class was the max_class_drop class
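            # # minimal sketch of the proposed tuning measure (not yet used; names are illustrative):
            # drops = [v for k, v in unit_prop_change_dict.items() if k != 'total' and v < 0]
            # max_class_drop = min(drops) if drops else 0
            # tuning = max_class_drop / sum(drops) if drops else 0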

            # # get layer means
            layer_total_change_list.append(unit_prop_change_dict['total'])
            layer_max_drop_list.append(
                min(list(unit_prop_change_dict.values())[:-1]))

        with open(les_means_dict_name, "rb") as pickle_load:
            lesion_means_dict = pickle.load(pickle_load)
            lesion_means_dict[layer_name]['mean_total'] = np.mean(
                layer_total_change_list)
            lesion_means_dict[layer_name]['mean_max_drop'] = np.mean(
                layer_max_drop_list)
            with open(les_means_dict_name, "wb") as pickle_out:
                pickle.dump(lesion_means_dict, pickle_out)

        # # save layer info
        print(f"\n**** save layer info for {layer_name} ****")

        count_per_cat_df = pd.DataFrame.from_dict(
            count_per_cat_dict[layer_name])
        count_per_cat_df.to_csv(
            f"{output_filename}_{layer_name}_count_per_cat.csv")
        with open(count_p_cat_dict_name, "wb") as pickle_out:
            pickle.dump(count_per_cat_dict, pickle_out)

        prop_change_df = pd.DataFrame.from_dict(prop_change_dict[layer_name])
        prop_change_df.to_csv(
            f"{output_filename}_{layer_name}_prop_change.csv")
        # nick_to_csv(prop_change_df, "{}_{}_prop_change.csv".format(output_filename, layer_name))
        with open(prop_change_dict_name, "wb") as pickle_out:
            pickle.dump(prop_change_dict, pickle_out)

        # # convert item_change_dict to df
        item_change_df = pd.DataFrame.from_dict(item_change_dict[layer_name])
        item_change_df.to_csv(
            f"{output_filename}_{layer_name}_item_change.csv")
        # nick_to_csv(item_change_df, "{}_{}_item_change.csv".format(output_filename, layer_name))
        with open(item_change_dict_name, "wb") as pickle_out:
            pickle.dump(item_change_dict, pickle_out)

        if verbose:
            print(f"\n\ncount_per_cat_df:\n{count_per_cat_df.head()}")
            print(f"\n\nprop_change_df:\n{prop_change_df.head()}")
            print(f"\n\nitem_change_df:\n{item_change_df.head()}")

        # # HIGHLIGHTS dict (for each layer)
        layer_highlights_dict = dict()

        # 1. 3 units with biggest total increase
        total_series = prop_change_df.loc['total', :]
        total_biggest_units = total_series.sort_values(
            ascending=False).head(3).index.to_list()
        total_biggest_vals = total_series.sort_values(
            ascending=False).head(3).to_list()

        # total_increase_dict = {total_biggest_units[i]: total_biggest_vals[i]
        #                        for i in range(sum(1 for x in total_biggest_vals if x > 0))}
        total_increase_dict = {}
        for unit, value in zip(total_biggest_units, total_biggest_vals):
            if unit not in total_increase_dict.keys():
                if value > 0:
                    total_increase_dict[unit] = value
        layer_highlights_dict["total_increase"] = total_increase_dict

        # # check current model highlight values ([1] refers to the value in the tuple)
        # # if the best model highlight is less impressive than layer-highlight, update model highlight dict
        with open(highlights_dict_name, "rb") as pickle_load:
            lesion_highlights_dict = pickle.load(pickle_load)

        if lesion_highlights_dict["highlights"]["total_increase"][
                1] < total_biggest_vals[0]:
            lesion_highlights_dict["highlights"]["total_increase"] = \
                (f'{layer_name}.{total_biggest_units[0]}', total_biggest_vals[0])

        if verbose:
            print(f"\ntotal_increase_dict: {total_increase_dict}")

        # # 2. 3 units with biggest total decrease
        # total_smallest = total_series.nsmallest(n=3, columns=cols)
        total_smallest_units = total_series.sort_values().head(
            3).index.to_list()
        total_smallest_vals = list(total_series.sort_values().head(3))

        # total_decrease_dict = {total_smallest_units[i]: total_smallest_vals[i]
        #                        for i in range(sum(1 for x in total_biggest_vals if x < 0.0))}
        total_decrease_dict = {}
        for unit, value in zip(total_smallest_units, total_smallest_vals):
            if unit not in total_decrease_dict.keys():
                if value < 0:
                    total_decrease_dict[unit] = value
        layer_highlights_dict["total_decrease"] = total_decrease_dict

        # # update model highlights if necessary
        if lesion_highlights_dict["highlights"]["total_decrease"][
                1] > total_smallest_vals[0]:
            lesion_highlights_dict["highlights"]["total_decrease"] = \
                (f'{layer_name}.{total_smallest_units[0]}', total_smallest_vals[0])

        if verbose:
            print(f"\ntotal_decrease_dict: {total_decrease_dict}")

        # # drop the 'totals' column from df so I can just get class scores
        get_class_highlights = prop_change_df.drop('total')

        # biggest class increase
        top3_vals = sorted(set(get_class_highlights.to_numpy().ravel()),
                           reverse=True)[:3]
        top3_tup = []
        for val in top3_vals:
            units = get_class_highlights.columns[get_class_highlights.isin(
                [val]).any()].to_list()
            for idx in units:
                top3_tup.append((idx, val))

        # class_increase_dict = {top3_tup[i][0]: top3_tup[i][1]
        #                        for i in range(sum(1 for x in top3_vals if x > 0.0))}

        # # Note: units = [i[0] for i in top3_tup], values = [i[1] for i in top3_tup]
        class_increase_dict = {}
        for unit, value in zip([i[0] for i in top3_tup],
                               [i[1] for i in top3_tup]):
            if unit not in class_increase_dict.keys():
                if value > 0:
                    class_increase_dict[unit] = value
        layer_highlights_dict["class_increase"] = class_increase_dict

        # # update model highlights if necessary
        if lesion_highlights_dict["highlights"]["class_increase"][
                1] < top3_vals[0]:
            lesion_highlights_dict["highlights"]["class_increase"] = \
                (f'{layer_name}.{top3_tup[0][0]}', top3_vals[0])

        if verbose:
            print(f"\nclass_increase_dict: {class_increase_dict}")

        # biggest class decrease
        bottom3_vals = sorted(set(get_class_highlights.to_numpy().ravel()))[:3]
        bottom3_tup = []
        for val in bottom3_vals:
            units = get_class_highlights.columns[get_class_highlights.isin(
                [val]).any()].to_list()
            for idx in units:
                bottom3_tup.append((idx, val))

        # class_decrease_dict = {bottom3_tup[i][0]: bottom3_tup[i][1]
        #                        for i in range(sum(1 for x in bottom3_vals if x < 0.0))}
        class_decrease_dict = {}
        for unit, value in zip([i[0] for i in bottom3_tup],
                               [i[1] for i in bottom3_tup]):
            if unit not in class_decrease_dict.keys():
                if value < 0:
                    class_decrease_dict[unit] = value
        layer_highlights_dict["class_decrease"] = class_decrease_dict

        # # update model highlights if necessary
        if lesion_highlights_dict["highlights"]["class_decrease"][
                1] > bottom3_vals[0]:
            lesion_highlights_dict["highlights"]["class_decrease"] = \
                (f'{layer_name}.{bottom3_tup[0][0]}', bottom3_vals[0])

        if verbose:
            print(f"\nclass_decrease_dict: {class_decrease_dict}")

        # # save layer highlights to highlights dict
        lesion_highlights_dict[layer_name] = layer_highlights_dict
        with open(highlights_dict_name, "wb") as pickle_out:
            pickle.dump(lesion_highlights_dict, pickle_out)

    # # 6. output:
    print("\n**** make output files and save ****")
    date = int(datetime.datetime.now().strftime("%y%m%d"))
    time = int(datetime.datetime.now().strftime("%H%M"))

    lesion_summary_dict = gha_dict
    lesion_info = {
        "lesion_path": lesion_path,
        "loaded_dict": gha_dict_name,
        "x_data_path": x_data_path,
        "y_data_path": y_data_path,
        "key_layer_classes": get_classes,
        "key_lesion_layers_list": key_lesion_layers_list,
        'total_units_filts': total_units_filts,
        "sel_date": date,
        "sel_time": time,
        'lesion_highlights': lesion_highlights_dict,
        'lesion_means_dict': lesion_means_dict
    }

    lesion_summary_dict["lesion_info"] = lesion_info

    print(f"Saving dict to: {lesion_path}")
    lesion_dict_name = f"{lesion_path}/{output_filename}_lesion_dict.pickle"
    pickle_out = open(lesion_dict_name, "wb")
    pickle.dump(lesion_summary_dict, pickle_out)
    pickle_out.close()

    focussed_dict_print(lesion_summary_dict,
                        'lesion_summary_dict',
                        focus_list=['lesion_info'])
    # print_nested_round_floats(lesion_summary_dict, 'lesion_summary_dict')

    # # lesion summary page
    # lesion_summary_path = '/home/nm13850/Documents/PhD/python_v2/experiments/lesioning/lesion_summary.csv'
    exp_path, cond_name = os.path.split(
        lesion_summary_dict['topic_info']['exp_cond_path'])
    lesion_summary_path = os.path.join(exp_path, 'lesion_summary.csv')

    run = lesion_summary_dict['topic_info']['run']

    if test_run:
        output_filename = f'{output_filename}_test'
        run = 'test'

    ls_info = [
        date, time, output_filename, run,
        lesion_summary_dict['data_info']['dataset'],
        lesion_summary_dict['GHA_info']['use_dataset'],
        lesion_summary_dict['topic_info']['model_path'],
        lesion_highlights_dict['highlights']["total_increase"][0],
        lesion_highlights_dict['highlights']["total_increase"][1],
        lesion_highlights_dict['highlights']["total_decrease"][0],
        lesion_highlights_dict['highlights']["total_decrease"][1],
        lesion_highlights_dict['highlights']["class_increase"][0],
        lesion_highlights_dict['highlights']["class_increase"][1],
        lesion_highlights_dict['highlights']["class_decrease"][0],
        lesion_highlights_dict['highlights']["class_decrease"][1]
    ]

    if not os.path.isfile(lesion_summary_path):
        ls_headers = [
            'date', 'time', 'filename', 'run', 'data', 'dset', 'model',
            "tot_incre_unit", "tot_incre_val", "tot_decre_unit",
            "tot_decre_val", "cat_incre_unit", "cat_incre_val",
            "cat_decre_unit", "cat_decrea_val"
        ]

        print("\ncreating summary csv at: {}".format(lesion_summary_path))
        lesion_summary = pd.DataFrame([ls_info], columns=ls_headers)
        nick_to_csv(lesion_summary, lesion_summary_path)
    else:
        lesion_summary = nick_read_csv(lesion_summary_path)
        ls_cols = list(lesion_summary)  # 14 columns
        for_df = dict(zip(ls_cols, ls_info))
        lesion_summary = pd.concat([lesion_summary, pd.DataFrame([for_df])], ignore_index=True)  # DataFrame.append was removed in pandas 2
        lesion_summary.to_csv(lesion_summary_path)

    print(f"\nlesion_summary:\n{lesion_summary.tail()}")

    print("\nscript_finished\n\n")

    return lesion_summary_dict
def plot_all_units(sel_dict_path,
                   measure='b_sel',
                   letter_sel=True,
                   correct_items_only=True,
                   just_1st_ts=False,
                   verbose=True,
                   test_run=False,
                   show_plots=False):
    """

    given a cond name

        load dicts and other info (BOTH Letter and word sel dicts)
        choose sel per unit (letter if > 0, else word)

        specify grid shape
        loop thru units, appending to axis

        plot units vertically, timesteps horizontally
        e.g., unit 0: ts0, ts1, ts2, ts3, ts4, ts5, ts6, ts7


    :param sel_dict_path: or gha_dict
    :param measure: selectivity measure to focus on if hl_dict provided
    :param letter_sel: focus on level of words or letters
    :param correct_items_only: remove items that were incorrect
    :param just_1st_ts: if True, plot just the 1st timestep (as in Bowers); if False, plot all timesteps.
    :param verbose:
    :param test_run: just 9 plots
    :param show_plots:

    :return:
    """

    print(f"\n**** running plot_all_units({sel_dict_path}) ****")

    if os.path.isfile(sel_dict_path):
        exp_cond_gha_path, gha_dict_name = os.path.split(sel_dict_path)
        os.chdir(exp_cond_gha_path)
        gha_dict = load_dict(sel_dict_path)

    elif type(sel_dict_path) is dict:
        gha_dict = sel_dict_path
        exp_cond_gha_path = os.getcwd()

    else:
        raise FileNotFoundError(sel_dict_path)

    if verbose:
        focussed_dict_print(gha_dict, 'gha_dict')

    # get topic_info from dict
    output_filename = gha_dict["topic_info"]["output_filename"]
    if letter_sel:
        output_filename = f"{output_filename}_lett"

    # # where to save files
    plots_folder = 'plots'
    cond_name = gha_dict['topic_info']['output_filename']
    condition_path = find_path_to_dir(long_path=exp_cond_gha_path,
                                      target_dir=cond_name)
    plots_path = os.path.join(condition_path, plots_folder)
    if not os.path.exists(plots_path):
        os.makedirs(plots_path)

    if verbose:
        print(f"\noutput_filename: {output_filename}")
        print(f"plots_path (to save): {plots_path}")
        print(f"os.getcwd(): {os.getcwd()}")

    # # get data info from dict
    n_words = gha_dict["data_info"]["n_cats"]
    n_letters = gha_dict["data_info"]["X_size"]
    n_units = gha_dict['model_info']['layers']['hid_layers']['hid_totals'][
        'analysable']

    if verbose:
        print(f"the are {n_words} word classes")

    if letter_sel:
        n_letters = gha_dict['data_info']["X_size"]
        n_words = n_letters
        print(
            f"there are {n_letters} letter classes\nn_words now set as n_letters"
        )

        letter_id_dict = load_dict(
            os.path.join(gha_dict['data_info']['data_path'],
                         'letter_id_dict.txt'))
        print(f"\nletter_id_dict:\n{letter_id_dict}")

    # # get model info from dict
    # model_dict = gha_dict['model_info']['config']
    # if verbose:
    #     focussed_dict_print(model_dict, 'model_dict')

    timesteps = gha_dict['model_info']["overview"]["timesteps"]
    vocab_dict = load_dict(
        os.path.join(gha_dict['data_info']["data_path"],
                     gha_dict['data_info']["vocab_dict"]))

    # '''Part 2 - load y, sort out incorrect resonses'''
    # print("\n\nPart 2: loading labels")
    # # # load y_labels to go with hid_acts and item_correct for sequences
    # if 'seq_corr_list' in gha_dict['GHA_info']['scores_dict']:
    #     n_seqs = gha_dict['GHA_info']['scores_dict']['n_seqs']
    #     n_seq_corr = gha_dict['GHA_info']['scores_dict']['n_seq_corr']
    #     n_incorrect = n_seqs - n_seq_corr
    #
    #     test_label_seq_name = gha_dict['GHA_info']['y_data_path']
    #     seqs_corr = gha_dict['GHA_info']['scores_dict']['seq_corr_list']
    #
    #     test_label_seqs = np.load(f"{test_label_seq_name}labels.npy")
    #
    #     if verbose:
    #         print(f"test_label_seqs: {np.shape(test_label_seqs)}")
    #         print(f"seqs_corr: {np.shape(seqs_corr)}")
    #         print(f"n_seq_corr: {n_seq_corr}")
    #
    #     if letter_sel:
    #         # # get 1hot item vectors for 'words' and 3 hot for letters
    #         '''Always use serial_recall True. as I want a separate 1hot vector for each item.
    #         Always use x_data_type 'local_letter_X' as I want 3hot vectors'''
    #         y_letters = []
    #         y_words = []
    #         for this_seq in test_label_seqs:
    #             get_letters, get_words = get_X_and_Y_data_from_seq(vocab_dict=vocab_dict,
    #                                                                seq_line=this_seq,
    #                                                                serial_recall=True,
    #                                                                end_seq_cue=False,
    #                                                                x_data_type='local_letter_X')
    #             y_letters.append(get_letters)
    #             y_words.append(get_words)
    #
    #         y_letters = np.array(y_letters)
    #         y_words = np.array(y_words)
    #         if verbose:
    #             print(f"\ny_letters: {type(y_letters)}  {np.shape(y_letters)}")
    #             print(f"y_words: {type(y_words)}  {np.shape(y_words)}")
    #
    #     y_df_headers = [f"ts{i}" for i in range(timesteps)]
    #     y_scores_df = pd.DataFrame(data=test_label_seqs, columns=y_df_headers)
    #     y_scores_df['full_model'] = seqs_corr
    #     if verbose:
    #         print(f"\ny_scores_df: {y_scores_df.shape}\n{y_scores_df.head()}")
    #
    #
    # # # if not sequence data, load y_labels to go with hid_acts and item_correct for items
    # elif 'item_correct_name' in gha_dict['GHA_info']['scores_dict']:
    #     # # load item_correct (y_data)
    #     item_correct_name = gha_dict['GHA_info']['scores_dict']['item_correct_name']
    #     # y_df = pd.read_csv(item_correct_name)
    #     y_scores_df = nick_read_csv(item_correct_name)
    #
    # """# # get rid of incorrect items if required"""
    # print("\n\nRemoving incorrect responses")
    # # # # get values for correct/incorrect items (1/0 or True/False)
    # item_correct_list = y_scores_df['full_model'].tolist()
    # full_model_values = list(set(item_correct_list))
    #
    # correct_symbol = 1
    # if len(full_model_values) != 2:
    #     TypeError(f"TYPE_ERROR!: what are the scores/acc for items? {full_model_values}")
    # if 1 not in full_model_values:
    #     if True in full_model_values:
    #         correct_symbol = True
    #     else:
    #         TypeError(f"TYPE_ERROR!: what are the scores/acc for items? {full_model_values}")
    #
    # print(f"len(full_model_values): {len(full_model_values)}")
    # print(f"correct_symbol: {correct_symbol}")
    #
    # # # i need to check whether this analysis should include incorrect items (True/False)
    # gha_incorrect = gha_dict['GHA_info']['gha_incorrect']
    #
    # # get item indeces for correct and incorrect items
    # item_index = list(range(n_seq_corr))
    #
    # incorrect_items = []
    # correct_items = []
    # for index in range(len(item_correct_list)):
    #     if item_correct_list[index] == 0:
    #         incorrect_items.append(index)
    #     else:
    #         correct_items.append(index)
    # if correct_items_only:
    #     item_index == correct_items
    #
    # if gha_incorrect:
    #     if correct_items_only:
    #         if verbose:
    #             print("\ngha_incorrect: True (I have incorrect responses)\n"
    #                   "correct_items_only: True (I only want correct responses)")
    #             print(f"remove {n_incorrect} incorrect from hid_acts & output using y_scores_df.")
    #             print("use y_correct for y_df")
    #
    #         y_correct_df = y_scores_df.loc[y_scores_df['full_model'] == correct_symbol]
    #         y_df = y_correct_df
    #
    #         mask = np.ones(shape=len(seqs_corr), dtype=bool)
    #         mask[incorrect_items] = False
    #         test_label_seqs = test_label_seqs[mask]
    #
    #         if letter_sel:
    #             y_letters = y_letters[mask]
    #
    #     else:
    #         if verbose:
    #             print("\ngha_incorrect: True (I have incorrect responses)\n"
    #                   "correct_items_only: False (I want incorrect responses)")
    #             print("no changes needed - don't remove anything from hid_acts, output and "
    #                   "use y scores as y_df")
    # else:
    #     if correct_items_only:
    #         if verbose:
    #             print("\ngha_incorrect: False (I only have correct responses)\n"
    #                   "correct_items_only: True (I only want correct responses)")
    #             print("no changes needed - don't remove anything from hid_acts or output.  "
    #                   "Use y_correct as y_df")
    #         y_correct_df = y_scores_df.loc[y_scores_df['full_model'] == correct_symbol]
    #         y_df = y_correct_df
    #     else:
    #         if verbose:
    #             print("\ngha_incorrect: False (I only have correct responses)\n"
    #                   "correct_items_only: False (I want incorrect responses)")
    #             raise TypeError("I can not complete this as desried"
    #                             "change correct_items_only to True"
    #                             "for analysis  - don't remove anything from hid_acts, output and "
    #                             "use y scores as y_df")
    #
    #         # correct_items_only = True
    #
    # if verbose is True:
    #     print(f"\ny_df: {y_df.shape}\n{y_df.head()}")
    #     print(f"\ntest_label_seqs: {np.shape(test_label_seqs)}")  # \n{test_label_seqs}")
    #     # if letter_sel:
    #     #     y_letters = np.asarray(y_letters)
    #     #     print(f"y_letters: {np.shape(y_letters)}")  # \n{test_label_seqs}")
    #
    #
    # # # load test seqs
    # n_correct, timesteps = np.shape(test_label_seqs)
    # corr_test_seq_name = f"{output_filename}_{n_correct}_corr_test_label_seqs.npy"
    # np.save(corr_test_seq_name, test_label_seqs)
    # corr_test_letters_name = 'not_processed_yet'
    # if letter_sel:
    #     corr_test_letters_name = f"{output_filename}_{n_correct}_corr_test_letter_seqs.npy"
    #     np.save(corr_test_letters_name, y_letters)
    #
    #
    # # # get items per class
    # IPC_dict = seq_items_per_class(label_seqs=test_label_seqs, vocab_dict=vocab_dict)
    # focussed_dict_print(IPC_dict, 'IPC_dict')
    # corr_test_IPC_name = f"{output_filename}_{n_correct}_corr_test_IPC.pickle"
    # with open(corr_test_IPC_name, "wb") as pickle_out:
    #     pickle.dump(IPC_dict, pickle_out, protocol=pickle.HIGHEST_PROTOCOL)
    #
    # # # how many times is each item represented at each timestep.
    # word_p_class_p_ts = IPC_dict['word_p_class_p_ts']
    # letter_p_class_p_ts = IPC_dict['letter_p_class_p_ts']
    #
    # for i in range(timesteps):
    #     n_words_p_ts = len(word_p_class_p_ts[f"ts{i}"].keys())
    #     n_letters_p_ts = len(letter_p_class_p_ts[f"ts{i}"].keys())
    #
    #     print(f"ts{i}) words:{n_words_p_ts}/{n_words}\tletters: {n_letters_p_ts}/{n_letters}")
    #     # print(word_p_class_p_ts[f"ts{i}"].keys())

    # get max sel per unit for words or letters
    combo_dict = word_letter_combo_dict(sel_dict_path, measure=measure)
    focussed_dict_print(combo_dict, 'combo_dict')
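    # # combo_dict is nested as combo_dict[layer_name][unit][ts_name] -> {'level', 'sel', 'feat'},
    # # which is how the per-unit plotting below reads it.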
    '''save results
    either make a new empty place to save,
    or load the previous version and get the units already completed'''
    os.chdir(plots_path)

    # # arrangement of subplots
    print("\narrangement of subplots")
    # # get max plots per page
    max_rows = 20
    max_cols = 10
    test_run_value = 9

    # # get required number of plots
    total_plots = n_units * timesteps
    n_cols = timesteps
    if just_1st_ts:
        total_plots = n_units
        n_cols = max_cols

    max_page_plots = max_rows * n_cols

    if test_run:
        total_plots = test_run_value

    total_rows = -(-total_plots // n_cols)  # double negation rounds up.
    n_pages = -(-total_plots // max_page_plots)
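    # # e.g., -(-7 // 3) == 3, so 7 plots over 3 columns need 3 rows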

    last_page_plots = max_page_plots
    last_page_rows = max_rows
    if total_plots % max_page_plots != 0:
        last_page_plots = total_plots % max_page_plots
        last_page_rows = total_rows % max_rows

    print(
        f'\nn_units: {n_units}, timesteps: {timesteps}, just_1st_ts: {just_1st_ts}\n'
        f'n_cols: {n_cols}, total_rows: {total_rows}, max_page_plots: {max_page_plots}\n'
        f'total_plots: {total_plots}, n_pages: {n_pages}, last_page_plots: {last_page_plots}\n'
    )
    '''
    part 3   - get gha for each unit
    '''
    loop_gha = loop_thru_acts(gha_dict_path=sel_dict_path,
                              correct_items_only=correct_items_only,
                              letter_sel=letter_sel,
                              verbose=verbose,
                              test_run=test_run)

    for page in range(n_pages):
        page_num = page + 1
        # # get number of plots on this page
        page_n_plots = max_page_plots
        page_n_rows = max_rows
        if page_num == n_pages:
            page_n_plots = last_page_plots
            page_n_rows = last_page_rows

        page_start = page * page_n_plots
        page_ends = page_start + page_n_plots

        print(f"\n\nNEW PAGE\t\tpage: {page_num} of {n_pages}.\n"
              f"Plots {page_start} - {page_ends} of {total_plots}\n"
              f"page_n_plots: {page_n_plots}, page_n_rows: {page_n_rows}")

        fig, axes = plt.subplots(nrows=page_n_rows,
                                 ncols=n_cols,
                                 sharex=True,
                                 sharey=True,
                                 constrained_layout=True)  # , squeeze=True)

        fig_height = page_n_rows / 2
        fig_width = 5  # n_cols / 2
        print(f'fig_height: {fig_height}, fig_width: {fig_width}')

        fig.set_size_inches(fig_width, fig_height, forward=True)

        axes_zip = list(zip(range(1, page_n_plots + 1), axes.flatten()))

        # # fig title
        fig_title = f'{cond_name}\nAll units & timesteps'
        if just_1st_ts:
            fig_title = f'{cond_name}\nAll units, first timestep'
        if n_pages > 1:
            fig_title = f'{fig_title} {page_num} of {n_pages}'

        fig.suptitle(fig_title)

        plot_counter = 0

        # # note, iter_idx restarts from zero for each page,
        # # whilst unit_gha continues from where it left off
        for iter_idx, unit_gha in enumerate(loop_gha):

            if just_1st_ts:
                if unit_gha["timestep"] != 0:
                    continue
            # else:
            # stop printing after one page.
            # if iter_idx + 1 > page_n_plots:
            #     continue

            print(f"\nnew-subplot: iter_idx: {iter_idx}")

            layer_name = unit_gha["layer_name"]
            unit_index = unit_gha["unit_index"]
            timestep = unit_gha["timestep"]
            ts_name = f"ts{timestep}"
            item_act_label_array = unit_gha["item_act_label_array"]

            print(f"unit_index: {unit_index}, ts_name: {ts_name}")

            # focussed_dict_print(unit_gha, 'unit_gha')

            # #  make df
            this_unit_acts = pd.DataFrame(
                data=item_act_label_array,
                columns=['item', 'activation', 'label'])
            this_unit_acts_df = this_unit_acts.astype({
                'item': 'int32',
                'activation': 'float',
                'label': 'int32'
            })

            # # where to put plot
            if just_1st_ts:
                ax_idx = axes_zip[unit_index][0]
                ax = axes_zip[unit_index][1]
            else:
                ax_idx = axes_zip[iter_idx][0]
                ax = axes_zip[iter_idx][1]

            print(f"ax_idx: {ax_idx}: ax: {ax}"
                  f"\nsubplot: row{unit_index} col{timestep} ")

            # # for this unit - get sel stats from combo dict
            ts_dict = combo_dict[layer_name][unit_index][ts_name]
            sel_level = ts_dict['level']
            sel_value = round(ts_dict['sel'], 3)
            sel_feat = ts_dict['feat']
            print(
                f"sel_level: {sel_level}, sel_value: {sel_value}, sel_feat: {sel_feat}"
            )

            # # get sel_feat
            # # selective_for_what
            sel_idx = sel_feat
            if sel_level == 'letter':
                sel_item = letter_id_dict[sel_feat]
            else:
                sel_item = vocab_dict[sel_feat]['word']

            if sel_level == 'letter':
                label_list = this_unit_acts_df['label'].to_list()
                sel_item_list = letter_in_seq(letter=sel_feat,
                                              test_label_seqs=label_list,
                                              vocab_dict=vocab_dict)
                # print(f"\n\nsel_item_list: {np.shape(sel_item_list)}\n{sel_item_list}")

                # y_letters_1ts = np.array(y_letters[:, timestep])
                # print(f"y_letters_1ts: {np.shape(y_letters_1ts)}")
                # use this to just get a binary array of whether a letter is present?
                # sel_item_list = y_letters_1ts[:, sel_idx]
            else:
                # # sort class label list
                class_labels = this_unit_acts['label'].to_list()

                # sel_item_list = [1 if x == sel_item else 0 for x in seq_words_list]
                sel_item_list = [
                    1 if x == sel_feat else 0 for x in class_labels
                ]

            this_unit_acts_df['sel_item'] = sel_item_list
            # print(f"this_unit_acts_df:\n{this_unit_acts_df}")

            sns.catplot(
                x='activation',
                y="label",
                hue='sel_item',
                data=this_unit_acts_df,
                ax=ax,
                orient='h',
                kind="strip",
                jitter=1,
                dodge=True,
                linewidth=.5,
                s=3,

                # palette="Set2", marker="D", edgecolor="gray"
            )  # , alpha=.25)

            ax.set_xlim([0, 1])
            # ax.set_ylim([-1, n_words+1])
            # print(f"y_lim: {ax.get_ylim()}")
            ax.margins(y=.05)
            ax.set_yticks([])
            # ax.set_title(f'u{unit_index}-{timestep}\n{sel_item}: {sel_value}', fontsize=8)
            ax.get_legend().set_visible(False)

            ax.set(xlabel='', ylabel='')

            # sort labels for left and bottom plots
            if just_1st_ts:
                if unit_index % max_cols == 0:
                    ax.set_ylabel(f"U{unit_index}-{unit_index+max_cols-1}",
                                  rotation='horizontal',
                                  ha='right')
            else:
                if iter_idx % timesteps == 0:
                    ax.set_ylabel(f"U {unit_index}",
                                  rotation='horizontal',
                                  ha='right')
                if iter_idx >= page_n_plots - timesteps:
                    ax.set_xlabel(f"{ts_name}")

            plt.close()

            # # stop if done enough plots
            plot_counter += 1
            print(f'plot_counter: {plot_counter}')

            if test_run:
                if plot_counter == test_run_value:
                    print(f'\nEnd of {test_run_value} test_run plots')
                    break

            if plot_counter == max_page_plots:
                print(f'\nEnd of page {page_num} of {n_pages}\n')
                break

        # # once broken out of plots loop (e.g., at end of page)
        # # save name
        if just_1st_ts:
            save_name = f"{plots_path}/{output_filename}_all_U_1ts.png"
        else:
            save_name = f"{plots_path}/{output_filename}_all plots.png"
        if n_pages > 1:
            save_name = f'{save_name[:-4]}_{page_num}of{n_pages}.png'
        plt.savefig(save_name)

        if show_plots:
            plt.show()

    print("\nend of plot_all_units script")
def raincloud_w_fail(sel_dict_path, lesion_dict_path, plot_type='classes', coi_measure='c_informed', top_layers='all',
                     selected_units=False,
                     plots_dir='simple_rain_plots',
                     plot_fails=False,
                     plot_class_change=False,
                     normed_acts=False,
                     layer_act_dist=False,
                     verbose=False, test_run=False,
                     ):
    """
    Visualise units with a raincloud plot.  Has distributions (cloud), individual activations (raindrops) and a boxplot
     to give the median and interquartile range.  Also has a plot of zero activations, scaled by class size.  Will show
     items that are affected by lesioning in different colours.

     I only have lesion data for [conv2d, dense] layers
     I have GHA and sel data from  [conv2d, activation, max_pooling2d, dense] layers

     so for each lesioned layer [conv2d, dense] I will use the following activation layer to take GHA and sel data from.

     Join these into groups using the activation number as the layer numbers.
     e.g., layer 1 (first conv layer) = conv2d_1 & activation_1.  layer 7 (first fc layer) = dense_1 & activation_7.

    :param sel_dict_path:  path to selectivity dict
    :param lesion_dict_path: path to lesion dict
    :param plot_type: all classes or OneVsAll.  if n_cats > 10, should automatically revert to oneVsAll.
    :param coi_measure: measure to use when choosing which class should be the coi.  Either the best performing sel
            measures (c_informed, c_ROC) or max class drop from lesioning.
    :param top_layers: if int, it will just do the top n layers (excluding output).  If not int, will do all layers.
    :param selected_units: default is to test all units on all layers.  But If I just want individual units, I should be
                    able to input a dict with layer names as keys and a list for each unit on that layer.
                    e.g., to just get unit 216 from 'fc_1' use selected_units={'fc_1': [216]}.
    :param plots_dir: where to save plots
    :param plot_fails: If False, just plots correct items, if true, plots items that failed after lesioning in RED
    :param plot_class_change: if True, plots proportion of items correct per class.
    :param normed_acts: if False use actual activation values, if True, normalize activations 0-1
    :param layer_act_dist: plot the distribution of all activations on a given layer.
                                This should already have been done in GHA
    :param verbose: how much to print to screen
    :param test_run: if True, just plot two units from two layers, if False, plot all (or selected units)

    returns nothing, just saves the plots
    """

    print("\n**** running visualise_units()****")

    if not selected_units:
        print(f"selected_units?: {selected_units}\n"
              "running ALL layers and units")
    else:
        focussed_dict_print(selected_units, 'selected_units')
    # if type(selected_units) is dict:
    #     print("dict found")

    # # lesion dict
    lesion_dict = load_dict(lesion_dict_path)
    focussed_dict_print(lesion_dict, 'lesion_dict')

    # # get key_lesion_layers_list
    lesion_info = lesion_dict['lesion_info']
    lesion_path = lesion_info['lesion_path']
    lesion_highlights = lesion_info["lesion_highlights"]
    key_lesion_layers_list = list(lesion_highlights.keys())

    # # remove unnecessary items from key layers list
    if 'highlights' in key_lesion_layers_list:
        key_lesion_layers_list.remove('highlights')
    # if 'output' in key_lesion_layers_list:
    #     key_lesion_layers_list.remove('output')
    # if 'Output' in key_lesion_layers_list:
    #     key_lesion_layers_list.remove('Output')

    # # remove output layers from key layers list
    if any("utput" in s for s in key_lesion_layers_list):
        output_layers = [s for s in key_lesion_layers_list if "utput" in s]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_lesion_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_lesion_layers_list = key_lesion_layers_list[:min_out_idx]

    # # remove output layers from key layers list
    if any("predictions" in s for s in key_lesion_layers_list):
        output_layers = [s for s in key_lesion_layers_list if "predictions" in s]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_lesion_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_lesion_layers_list = key_lesion_layers_list[:min_out_idx]

    class_labels = list(lesion_dict['data_info']['cat_names'].values())

    # # sel_dict
    sel_dict = load_dict(sel_dict_path)
    if key_lesion_layers_list[0] in sel_dict['sel_info']:
        print('\nfound old sel dict layout')
        key_gha_sel_layers_list = list(sel_dict['sel_info'].keys())
        old_sel_dict = True
        sel_info = sel_dict['sel_info']
        # short_sel_measures_list = list(sel_info[key_lesion_layers_list[0]][0]['sel'].keys())
        # csb_list = list(sel_info[key_lesion_layers_list[0]][0]['class_sel_basics'].keys())
        # sel_measures_list = short_sel_measures_list + csb_list
    else:
        print('\nfound NEW sel dict layout')
        old_sel_dict = False
        sel_info = load_dict(sel_dict['sel_info']['sel_per_unit_pickle_name'])
        # sel_measures_list = list(sel_info[key_lesion_layers_list[0]][0].keys())
        key_gha_sel_layers_list = list(sel_info.keys())
        # print(sel_info.keys())

    # # get key_gha_sel_layers_list
    # # # remove unnecesary items from key layers list
    # if 'sel_analysis_info' in key_gha_sel_layers_list:
    #     key_gha_sel_layers_list.remove('sel_analysis_info')
    # if 'output' in key_gha_sel_layers_list:
    #     output_idx = key_gha_sel_layers_list.index('output')
    #     key_gha_sel_layers_list = key_gha_sel_layers_list[:output_idx]
    # if 'Output' in key_gha_sel_layers_list:
    #     output_idx = key_gha_sel_layers_list.index('Output')
    #     key_gha_sel_layers_list = key_gha_sel_layers_list[:output_idx]

    # # remove output layers from key layers list
    if any("utput" in s for s in key_gha_sel_layers_list):
        output_layers = [s for s in key_gha_sel_layers_list if "utput" in s]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_gha_sel_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_gha_sel_layers_list = key_gha_sel_layers_list[:min_out_idx]
        # key_layers_df = key_layers_df.loc[~key_layers_df['name'].isin(output_layers)]

    # # remove prediction layers from key gha sel layers list
    if any("predictions" in s for s in key_gha_sel_layers_list):
        output_layers = [s for s in key_gha_sel_layers_list if "predictions" in s]
        output_idx = []
        for out_layer in output_layers:
            output_idx.append(key_gha_sel_layers_list.index(out_layer))
        min_out_idx = min(output_idx)
        key_gha_sel_layers_list = key_gha_sel_layers_list[:min_out_idx]

    # # put together lists of 1. sel_gha_layers, 2. key_lesion_layers_list.
    n_activation_layers = sum("activation" in layers for layers in key_gha_sel_layers_list)
    n_lesion_layers = len(key_lesion_layers_list)

    if n_activation_layers == n_lesion_layers:
        # # for models where activation and conv (or dense) are separate layers
        n_layers = n_activation_layers
        activation_layers = [layers for layers in key_gha_sel_layers_list if "activation" in layers]
        link_layers_dict = dict(zip(reversed(activation_layers), reversed(key_lesion_layers_list)))

    elif n_activation_layers == 0:
        print("\nno separate activation layers found - use key_lesion_layers_list")
        n_layers = len(key_lesion_layers_list)
        link_layers_dict = dict(zip(reversed(key_lesion_layers_list), reversed(key_lesion_layers_list)))

    else:
        print(f"n_activation_layers: {n_activation_layers}\n{key_gha_sel_layers_list}")
        print("n_lesion_layers: {n_lesion_layers}\n{key_lesion_layers_list}")
        raise TypeError('should be same number of activation layers and lesioned layers')

    if verbose is True:
        focussed_dict_print(link_layers_dict, 'link_layers_dict')
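    # # e.g., (hypothetical layer names, not from the original experiments) if the last
    # # gha sel layers are ['activation_1', 'activation_2'] and the last lesion layers
    # # are ['conv2d_1', 'conv2d_2'], zipping the reversed lists pairs the deepest
    # # layers first: link_layers_dict = {'activation_2': 'conv2d_2',
    # #                                   'activation_1': 'conv2d_1'}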

    # # # get info
    exp_cond_path = sel_dict['topic_info']['exp_cond_path']
    output_filename = sel_dict['topic_info']['output_filename']

    # hid acts hdf
    hdf_name = f'{output_filename}_gha.h5'


    # # load data
    # # check for training data
    use_dataset = sel_dict['GHA_info']['use_dataset']

    n_cats = sel_dict['data_info']["n_cats"]

    if use_dataset in sel_dict['data_info']:
        # n_items = sel_dict["data_info"][use_dataset]["n_items"]
        items_per_cat = sel_dict["data_info"][use_dataset]["items_per_cat"]
    else:
        # n_items = sel_dict["data_info"]["n_items"]
        items_per_cat = sel_dict["data_info"]["items_per_cat"]
    if type(items_per_cat) is int:
        items_per_cat = dict(zip(list(range(n_cats)), [items_per_cat] * n_cats))
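    # # e.g., (purely illustrative numbers) if items_per_cat is 500 and n_cats is 3,
    # # this gives items_per_cat = {0: 500, 1: 500, 2: 500}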

    if plot_type != 'OneVsAll':
        if n_cats > 20:
            plot_type = 'OneVsAll'
            print("\n\n\nWARNING!  There are lots of classes, it might make a messy plot"
                  "Switching to OneVsAll\n")

    if sel_dict['GHA_info']['gha_incorrect'] == 'False':
        # # only gha for correct items
        # n_items = sel_dict['GHA_info']['scores_dict']['n_correct']
        items_per_cat = sel_dict['GHA_info']['scores_dict']['corr_per_cat_dict']

    # # load hid acts dict called hid_acts.pickle
    """
    Hid_acts dict has numbers as the keys for each layer.
    Some layers will be missing, as acts are only recorded from some layers (e.g., [17, 19, 20, 22, 25, 26, 29, 30])
    hid_acts_dict.keys(): dict_keys([0, 1, 3, 5, 6, 8, 9, 11, 13, 14, 16, 17, 19, 20, 22, 25, 26, 29, 30])
    hid_acts_dict[0].keys(): dict_keys(['layer_name', 'layer_class', 'layer_shape', '2d_acts', 'converted_to_2d'])
    In each layer there is ['layer_name', 'layer_class', 'layer_shape', '2d_acts']
    For 4d layers (conv, pool) there is also, key, value 'converted_to_2d': True
    """

    # # check if I have saved the location to this file
    hid_acts_pickle_name = sel_dict["GHA_info"]["hid_act_files"]['2d']
    if 'gha_path' in sel_dict['GHA_info']:
        gha_path = sel_dict['GHA_info']['gha_path']
        hid_acts_path = os.path.join(gha_path, hid_acts_pickle_name)
    else:
        hid_act_items = 'all'
        if not sel_dict['GHA_info']['gha_incorrect']:
            hid_act_items = 'correct'

        gha_folder = f'{hid_act_items}_{use_dataset}_gha'
        hid_acts_path = os.path.join(exp_cond_path, gha_folder, hid_acts_pickle_name)
    with open(hid_acts_path, 'rb') as pkl:
        hid_acts_dict = pickle.load(pkl)
    print("\nopened hid_acts.pickle")

    # # # visualizing distribution of activations
    # if layer_act_dist:
    #     print("\nPlotting the distributions of activations for each layer")
    #     for k, v in hid_acts_dict.items():
    #         print("\nPlotting distribution of layer acts")
    #         layer_act_dist_dir = 'layer_act_dist'
    #         print(hid_acts_dict[k]['layer_name'])
    #         hid_acts = hid_acts_dict[k]['2d_acts']
    #         print(np.shape(hid_acts))
    #         sns.distplot(np.ravel(hid_acts))
    #         plt.title(str(hid_acts_dict[k]['layer_name']))
    #         dist_plot_name = "{}_{}_layer_act_distplot.png".format(output_filename, hid_acts_dict[k]['layer_name'])
    #         plt.savefig(os.path.join(plots_dir, layer_act_dist_dir, dist_plot_name))
    #         # plt.show()
    #         plt.close()

    # # dict to get the hid_acts_dict key for each layer based on its name
    get_hid_acts_number_dict = dict()
    for key, value in hid_acts_dict.items():
        hid_acts_layer_name = value['layer_name']
        hid_acts_layer_number = key
        get_hid_acts_number_dict[hid_acts_layer_name] = hid_acts_layer_number

    # # where to save files
    save_plots_name = plots_dir
    if plot_type is "OneVsAll":
        save_plots_name = f'{plots_dir}/{coi_measure}'
    save_plots_dir = lesion_dict['GHA_info']['gha_path']
    save_plots_path = os.path.join(save_plots_dir, save_plots_name)
    if test_run:
        save_plots_path = os.path.join(save_plots_path, 'test')
    if not os.path.exists(save_plots_path):
        os.makedirs(save_plots_path)
    os.chdir(save_plots_path)
    print(f"\ncurrent wd: {os.getcwd()}")

    if layer_act_dist:
        layer_act_dist_path = os.path.join(save_plots_path, 'layer_act_dist')
        if not os.path.exists(layer_act_dist_path):
            os.makedirs(layer_act_dist_path)


    print("\n\n**********************"
          "\nlooping through layers"
          "\n**********************\n")

    for layer_index, (gha_layer_name, lesion_layer_name) in enumerate(link_layers_dict.items()):

        if test_run:
            if layer_index > 2:
                continue

        if type(top_layers) is int:
            if top_layers < n_activation_layers:
                if layer_index > top_layers:
                    continue


        # print(f"\nwhich units?: {selected_units}")
        # if selected_units != 'all':
        if selected_units is not False:
            if gha_layer_name not in selected_units:
                print(f"\nselected_units only, skipping layer {gha_layer_name}")
                continue
            else:
                print(f"\nselected_units only, from {gha_layer_name}")
                # print(f"\t{gha_layer_name} in {list(selected_units.keys())}")
                this_layer_units = selected_units[gha_layer_name]
                print(f"\trunning units: {this_layer_units}")

        gha_layer_number = get_hid_acts_number_dict[gha_layer_name]
        layer_dict = hid_acts_dict[gha_layer_number]

        if gha_layer_name != layer_dict['layer_name']:
            raise TypeError("gha_layer_name (from link_layers_dict) and layer_dict['layer_name'] should match! ")

        # hid_acts_array = layer_dict['2d_acts']
        # hid_acts_df = pd.DataFrame(hid_acts_array, dtype=float)

        with h5py.File(hdf_name, 'r') as gha_data:
            hid_acts_array = gha_data['hid_acts_2d'][gha_layer_name]
            hid_acts_df = pd.DataFrame(hid_acts_array)

        # # visualizing distribution of activations
        if layer_act_dist:
            hid_acts = layer_dict['2d_acts']
            print(f"\nPlotting distribution of activations {np.shape(hid_acts)}")
            sns.distplot(np.ravel(hid_acts))
            plt.title(f"{str(layer_dict['layer_name'])} activation distribution")
            dist_plot_name = "{}_{}_layer_act_distplot.png".format(output_filename, layer_dict['layer_name'])
            plt.savefig(os.path.join(layer_act_dist_path, dist_plot_name))
            if test_run:
                plt.show()
            plt.close()


        # # load item change details
        """# # four possible states
            full model      after_lesion    code
        1.  1 (correct)     0 (wrong)       -1
        2.  0 (wrong)       0 (wrong)       0
        3.  1 (correct)     1 (correct)     1
        4.  0 (wrong)       1 (correct)     2

        """
        item_change_df = pd.read_csv(f"{lesion_path}/{output_filename}_{lesion_layer_name}_item_change.csv",
                                     header=0, dtype=int, index_col=0)

        prop_change_df = pd.read_csv(f'{lesion_path}/{output_filename}_{lesion_layer_name}_prop_change.csv',
                                     header=0,
                                     # dtype=float,
                                     index_col=0)

        if verbose:
            print("\n*******************************************"
                  f"\n{layer_index}. gha layer {gha_layer_number}: {gha_layer_name} \tlesion layer: {lesion_layer_name}"
                  "\n*******************************************")
            # focussed_dict_print(hid_acts_dict[layer_index])
            print(f"\n\thid_acts {gha_layer_name} shape: {hid_acts_df.shape}")
            print(f"\tloaded: {output_filename}_{lesion_layer_name}_item_change.csv: {item_change_df.shape}")

        units_per_layer = len(hid_acts_df.columns)

        print("\n\n\t**** loop through units ****")
        for unit_index, unit in enumerate(hid_acts_df.columns):

            if test_run:
                if unit_index > 2:
                    continue

            # if selected_units != 'all':
            if selected_units is not False:
                if unit not in this_layer_units:
                    # print(f"skipping unit {gha_layer_name} {unit}")
                    continue
                else:
                    print(f"\nrunning unit {gha_layer_name} {unit}")

            # # check unit is in sel_per_unit_dict
            if unit in sel_info[gha_layer_name].keys():
                if verbose:
                    print("found unit in dict")
            else:
                print("unit not in dict\n!!!!!DEAD RELU!!!!!!!!\n...on to the next unit\n")
                continue

            lesion_layer_and_unit = f"{lesion_layer_name}_{unit}"
            output_layer_and_unit = f"{lesion_layer_name}_{unit}"


            print("\n\n*************\n"
                  f"running layer {layer_index} of {n_layers} ({gha_layer_name}): unit {unit} of {units_per_layer}\n"
                  "************")

            # # make new df with just [item, hid_acts*, class, item_change*] *for this unit
            unit_df = item_change_df[["item", "class", lesion_layer_and_unit]].copy()
            # print(hid_acts_df)
            this_unit_hid_acts = hid_acts_df.loc[:, unit]


            # # check for dead relus
            if sum(np.ravel(this_unit_hid_acts)) == 0.0:
                print("\n\n!!!!!DEAD RELU!!!!!!!!...on to the next unit\n")
                continue

            if verbose:
                print(f"\tnot a dead unit, hid acts sum: {sum(np.ravel(this_unit_hid_acts)):.2f}")

            unit_df.insert(loc=1, column='hid_acts', value=this_unit_hid_acts)
            unit_df = unit_df.rename(index=str, columns={lesion_layer_and_unit: 'item_change'})

            if verbose is True:
                print(f"\n\tall items - unit_df: {unit_df.shape}")

            # # remove rows where network failed originally and after lesioning this unit - uninteresting
            old_df_length = len(unit_df)
            unit_df = unit_df.loc[unit_df['item_change'] != 0]
            if verbose is True:
                n_fail_fail = old_df_length - len(unit_df)
                print(f"\n\t{n_fail_fail} fail-fail items removed - new shape unit_df: {unit_df.shape}")

            # # get items per class based on their occurrences in the dataframe.
            # # this includes fail-pass, pass-pass and pass-fail - but not fail-fail
            no_fail_fail_ipc = unit_df['class'].value_counts(sort=False)

            df_ipc = dict()
            for i in range(n_cats):
                # use 0 for any class with no items left after removing fail-fail items
                df_ipc[i] = no_fail_fail_ipc[i] if i in no_fail_fail_ipc else 0

            # # # calculate the proportion of items that failed.
            # # # this is not the same as total_unit_change (which takes into account fail-pass as well as pass-fail)
            # df_ipc_total = sum(df_ipc.values())
            # l_failed_df = unit_df[(unit_df['item_change'] == -1)]
            # l_failed_count = len(l_failed_df)
            #
            # print("\tdf_ipc_total: {}".format(df_ipc_total))
            # print("\tl_failed_count: {}".format(l_failed_count))

            # # getting max_class_drop
            max_class_drop_col = prop_change_df.loc[:, str(unit)]
            total_unit_change = max_class_drop_col['total']
            max_class_drop_col = max_class_drop_col.drop(labels=['total'])
            max_class_drop_val = max_class_drop_col.min()
            max_drop_class = max_class_drop_col.idxmin()
            print(f"\n\tmax_class_drop_val: {max_class_drop_val}\n"
                  f"\tmax_drop_class: {max_drop_class}\n"
                  f"\ttotal_unit_change: {total_unit_change}")

            # # getting best sel measure (max_informed)
            main_sel_name = 'informedness'

            # # includes if statement since some units have no score (dead relu?)
            if old_sel_dict:
                main_sel_val = sel_dict['sel_info'][gha_layer_name][unit]['max']['informed']
                main_sel_class = int(sel_dict['sel_info'][gha_layer_name][unit]['max']['c_informed'])
            else:
                # print(sel_info[gha_layer_name][unit]['max'])
                main_sel_val = sel_info[gha_layer_name][unit]['max']['max_informed']
                main_sel_class = int(sel_info[gha_layer_name][unit]['max']['max_informed_c'])

            print(f"\tmain_sel_val: {main_sel_val}")
            print(f"\tmain_sel_class: {main_sel_class}")

            # # coi stands for Class Of Interest
            # # if doing oneVsAll I need to have a coi measure. (e.g., clas with max informed 'c_informed')
            if plot_type == "OneVsAll":

                # # get coi
                if coi_measure == 'max_class_drop':
                    coi = max_drop_class
                elif coi_measure == 'c_informed':
                    coi = main_sel_class
                else:
                    coi = int(sel_dict['sel_info'][gha_layer_name][unit]['max'][coi_measure])
                print(f"\n\tcoi: {coi}  ({coi_measure})")

                # # get new class labels based on coi, OneVsAll
                all_classes_col = unit_df['class'].astype(int)

                one_v_all_class_list = [1 if x == coi else 0 for x in all_classes_col]
                print(f"\tall_classes_col: {len(all_classes_col)}  one_v_all_class_list: {len(one_v_all_class_list)}")

                if 'OneVsAll' not in list(unit_df):
                    print("\tadding 'OneVsAll' column")
                    print("\treplacing all classes with 'OneVsAll' class column")
                    unit_df['OneVsAll'] = one_v_all_class_list
                    unit_df['class'] = one_v_all_class_list
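                # # e.g., (illustrative) with coi = 3 and class labels [3, 7, 3, 1],
                # # one_v_all_class_list becomes [1, 0, 1, 0]: the coi is class 1
                # # and every other class is lumped together as class 0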


            min_act = unit_df['hid_acts'].min()

            if normed_acts:
                if min_act >= 0.0:
                    print("\nnormalising activations")
                    this_unit_normed_acts = np.divide(unit_df['hid_acts'], unit_df['hid_acts'].max())
                    unit_df['normed'] = this_unit_normed_acts
                    print(unit_df.head())
                else:
                    print("\ncan not do normed acts on this unit")
                    normed_acts = False


            # # # did any items fail that were previously at zero
            print(f"\n\tsmallest activation on this layer was {min_act}")
            l_failed_df = unit_df[(unit_df['item_change'] == -1)]
            l_failed_df = l_failed_df.sort_values(by=['hid_acts'])

            min_failed_act = l_failed_df['hid_acts'].min()
            print(f"\n\tsmallest activation of items that failed after lesioning was {min_failed_act}")
            if min_failed_act == 0.0:
                fail_zero_df = l_failed_df.loc[l_failed_df['hid_acts'] == 0.0]
                fail_zero_count = len(fail_zero_df.index)
                print(f"\n\tfail_zero_df: {fail_zero_count} items\n\t{fail_zero_df.head()}")
                fail_zero_df.to_csv(f"{output_filename}_{gha_layer_name}_{unit}_fail_zero_df.csv", index=False)


            # # make plot of class changes
            # if plot_fails is True:
            if plot_class_change:
                class_prop_change = prop_change_df.iloc[:-1, unit].to_list()
                print(f"\n\tclass_prop_change: {class_prop_change}")

                # change scale if there are big changes
                class_change_x_min = -.5
                if min(class_prop_change) < class_change_x_min:
                    class_change_x_min = min(class_prop_change)

                class_change_x_max = .1
                if max(class_prop_change) > class_change_x_max:
                    class_change_x_max = max(class_prop_change)

                class_change_curve = sns.barplot(x=class_prop_change, y=class_labels, orient='h')
                class_change_curve.set_xlim([class_change_x_min, class_change_x_max])
                class_change_curve.axvline(0, color="k", clip_on=False)
                plt.subplots_adjust(left=0.15)  # just to fit the label 'automobile' on

                print(f'\nclass num: {class_prop_change.index(min(class_prop_change))}, '
                      f'class label: {class_labels[class_prop_change.index(min(class_prop_change))]}, '
                      f'class_val: {min(class_prop_change):.2f}'
                      )

                plt.title(f"{lesion_layer_and_unit}\n"
                          f"total change: {total_unit_change:.2f} "
                          f"max_class ({class_labels[class_prop_change.index(min(class_prop_change))]}): "
                          f"{min(class_prop_change):.2f}")
                plt.savefig(f"{output_filename}_{output_layer_and_unit}_class_prop_change.png")

                if test_run:
                    plt.show()

                plt.close()



            # # # # # # # # # # # #
            # # raincloud plots # #
            # # # # # # # # # # # #

            # # # plot title
            if plot_fails:
                title = f"Layer: {gha_layer_name} Unit: {unit}\nmax_class_drop: {max_class_drop_val:.2f} " \
                        f"({max_drop_class}), total change: {total_unit_change:.2f}\n" \
                        f"{main_sel_name}: {main_sel_val:.2f} ({main_sel_class})"

                if plot_type == "OneVsAll":
                    title = f"Layer: {gha_layer_name} Unit: {unit} class: {coi}\n" \
                            f"max_class_drop: {max_class_drop_val:.2f} ({max_drop_class}), " \
                            f"total change: {total_unit_change:.2f}" \
                            "\n{main_sel_name}: {main_sel_val:.2f} ({main_sel_class})"
            else:
                title = f"Layer: {gha_layer_name} Unit: {unit}\n" \
                        f"{main_sel_name}: {main_sel_val:.2f} ({main_sel_class})"

                if plot_type == "OneVsAll":
                    title = f"Layer: {gha_layer_name} Unit: {unit} class: {coi}\n" \
                            f"{main_sel_name}: {main_sel_val:.2f} ({main_sel_class})"
            print(f"\ntitle:\n{title}")

            # # # load main dataframe
            raincloud_data = unit_df
            # print(raincloud_data.head())

            plot_y_vals = "class"
            # use_this_ipc = items_per_cat
            use_this_ipc = df_ipc

            if plot_type == "OneVsAll":
                print("\t\n\n\nUSE OneVsAll mode")
                n_cats = 2
                items_per_coi = use_this_ipc[coi]
                other_items = sum(df_ipc.values()) - items_per_coi
                use_this_ipc = {0: other_items, 1: items_per_coi}
                print(f"\tcoi {coi}, items_per_cat {items_per_cat}")

            # # # choose colours
            use_colours = 'tab10'
            if 10 < n_cats < 21:
                use_colours = 'tab20'
            elif n_cats > 20:
                print("\tERROR - more classes than colours!?!?!?")
            sns.set_palette(palette=use_colours, n_colors=n_cats)

            # Make MULTI plot
            fig = plt.figure(figsize=(10, 5))
            gs = gridspec.GridSpec(1, 2, width_ratios=[1, 4])
            zeros_axis = plt.subplot(gs[0])
            rain_axis = plt.subplot(gs[1])

            # # # # # # # # # # # #
            # # make zeros plot # #
            # # # # # # # # # # # #

            # 1. get biggest class size (for max val of plot)
            max_class_size = max(use_this_ipc.values())
            print(f"\tmax_class_size: {max_class_size}")

            # 2. get list or dict of zeros per class
            zeros_dict = {}
            for k in range(n_cats):
                if plot_type == "OneVsAll":
                    plot_names = ["all_others", f"class_{coi}"]
                    this_name = plot_names[k]
                    this_class = unit_df.loc[unit_df['OneVsAll'] == k]
                    zero_count = 0 - (this_class['hid_acts'] == 0).sum()
                    zeros_dict[this_name] = zero_count
                else:
                    this_class = unit_df.loc[unit_df['class'] == k]
                    zero_count = 0 - (this_class['hid_acts'] == 0).sum()
                    zeros_dict[k] = zero_count

            # # class names for the zeros plot: use the OneVsAll names when relevant
            if plot_type == "OneVsAll":
                zd_classes = list(zeros_dict.keys())
            else:
                zd_classes = class_labels
            zd_zero_count = list(zeros_dict.values())

            if verbose:
                print(f"\n\tzeros_dict:{zeros_dict.values()}, use_this_ipc:{use_this_ipc.values()}")

            zd_zero_perc = [x / y * 100 if y else 0 for x, y in zip(zeros_dict.values(), use_this_ipc.values())]

            zd_data = {"class": zd_classes, "zero_count": zd_zero_count, "zero_perc": zd_zero_perc}

            zeros_dict_df = pd.DataFrame.from_dict(data=zd_data)

            # zero_plot
            sns.catplot(x="zero_perc", y="class", data=zeros_dict_df, kind="bar", orient='h', ax=zeros_axis)

            zeros_axis.set_xlabel("% at zero (height reflects n items)")

            zeros_axis.set_xlim([-100, 0])

            # # set width of bar to reflect class size
            new_heights = [x / max_class_size for x in use_this_ipc.values()]
            print(f"\tuse_this_ipc: {use_this_ipc}\n\tnew_heights: {new_heights}")

            # def change_height(zeros_axis, new_value):
            patch_count = 0
            for patch in zeros_axis.patches:
                current_height = patch.get_height()
                make_new_height = current_height * new_heights[patch_count]
                diff = current_height - make_new_height

                if new_heights[patch_count] < 1.0:
                    # print("{}. current_height {}, new_height: {}".format(patch, current_height, make_new_height))

                    # # change the bar height
                    patch.set_height(make_new_height)

                    # # recenter the bar
                    patch.set_y(patch.get_y() + diff * .65)

                patch_count = patch_count + 1


            zeros_axis.set_xticklabels(['100', '50', ''])
            # zeros_axis.xaxis.set_major_locator(plt.MaxNLocator(1))
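            # # note: sns.catplot is a figure-level function, so the call above also
            # # seems to spawn its own (blank) figure; the plt.close() below appears to
            # # discard that extra figure before the raincloud is drawn on rain_axis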
            plt.close()

            # # # # # # # # #
            # # raincloud # #
            # # # # # # # # #

            data_values = "hid_acts"  # float
            if normed_acts:
                data_values = 'normed'
            data_class = plot_y_vals  # class
            orientation = "h"  # orientation

            # cloud_plot
            pt.half_violinplot(data=raincloud_data, bw=.1, linewidth=.5, cut=0., width=1, inner=None,
                               orient=orientation, x=data_values, y=data_class, scale="count")  # scale="area"

            """# # rain_drops - plot 3 separate plots so that they are interesting items are ontop of pass-pass
            # # zorder is order in which items are printed
            # # item_change: 1 ('grey') passed before and after lesioning
            # # -1 ('red') passed in full model but failed when lesioned
            # # 2 ('green') failed in full model but passed in lesioning"""
            fail_palette = {1: "silver", -1: "red", 2: "green", 0: "orange"}


            # # separate rain drops for pass pass,
            pass_pass_df = unit_df[(unit_df['item_change'] == 1)]
            pass_pass_drops = sns.stripplot(data=pass_pass_df, x=data_values, y=data_class, jitter=1, zorder=1,
                                            size=2, orient=orientation)  # , hue='item_change', palette=fail_palette)

            if plot_fails is True:

                '''I'm not using this atm, but if I want to plot items that originally failed and later passed'''
                # # separate raindrop for fail pass
                # fail_pass_df = unit_df[(unit_df['item_change'] == 2)]
                # if not fail_pass_df.empty:
                #     fail_pass_drops = sns.stripplot(data=fail_pass_df, x=data_values, y=data_class, jitter=1,
                #                                     zorder=3, size=4, orient=orientation, hue='item_change',
                #                                     palette=fail_palette, edgecolor='gray', linewidth=.4, marker='s',
                #                                     label='')

                # # separate raindrops for pass fail
                if not l_failed_df.empty:
                    # pass_fail_drops
                    sns.stripplot(data=l_failed_df, x=data_values, y=data_class, jitter=1, zorder=4, size=4,
                                  orient=orientation, hue='item_change', palette=fail_palette, edgecolor='white',
                                  linewidth=.4, marker='s')

            # box_plot
            sns.boxplot(data=raincloud_data, color="gray", orient=orientation, width=.15, x=data_values,
                        y=data_class, zorder=2, showbox=False,
                        # boxprops={'facecolor': 'none', "zorder": 2},
                        showfliers=False, showcaps=False,
                        whiskerprops={'linewidth': .01, "zorder": 2}, saturation=1,
                        # showwhiskers=False,
                        medianprops={'linewidth': .01, "zorder": 2},
                        showmeans=True,
                        meanprops={"marker": "*", "markerfacecolor": "white", "markeredgecolor": "black"}
                        )

            # # Finalize the figure
            rain_axis.set_xlabel("Unit activations")
            if normed_acts:
                rain_axis.set_xlabel("Unit activations (normalised)")

            # new_legend_text = ['l_passed', 'l_failed']
            new_legend_text = ['l_failed']

            leg = pass_pass_drops.axes.get_legend()
            if leg:
                # in here because leg is None if no items changed when this unit was lesioned
                for t, l in zip(leg.texts, new_legend_text):
                    t.set_text(l)

            # # hide ticks and labels from rainplot
            plt.setp(rain_axis.get_yticklabels(), visible=False)
            rain_axis.axes.get_yaxis().set_visible(False)

            # # put plots together
            max_activation = max(this_unit_hid_acts)
            min_activation = min(this_unit_hid_acts)
            if normed_acts:
                max_activation = max(this_unit_normed_acts)
                min_activation = min(this_unit_normed_acts)

            max_x_val = max_activation * 1.05
            layer_act_func = None
            for k, v in lesion_dict['model_info']['layers']['hid_layers'].items():
                if v['name'] == gha_layer_name:
                    layer_act_func = v['act_func']
                    break
            if layer_act_func in ['relu', 'Relu', 'ReLu']:
                min_x_val = 0
            elif min_activation > 0.0:
                min_x_val = 0
            else:
                min_x_val = min_activation

            rain_axis.set_xlim([min_x_val, max_x_val])
            rain_axis.get_shared_y_axes().join(zeros_axis, rain_axis)
            fig.subplots_adjust(wspace=0)

            fig.suptitle(title, fontsize=12).set_position([.5, 1.0])  # .set_bbox([])  #

            # # add y axis back onto rainplot
            plt.axvline(x=min_x_val, linestyle="-", color='black', )

            # # add marker for max informedness
            if 'info' in coi_measure:
                if old_sel_dict:
                    normed_info_thr = sel_dict['sel_info'][gha_layer_name][unit]['max']['thr_informed']
                else:
                    print(sel_info[gha_layer_name][unit]['max'])
                    normed_info_thr = sel_info[gha_layer_name][unit]['max']['max_info_thr']

                if normed_acts:
                    best_info_thr = normed_info_thr
                else:
                    # unnormalise it
                    best_info_thr = normed_info_thr * max(this_unit_hid_acts)
                print(f"\tbest_info_thr: {best_info_thr}")
                plt.axvline(x=best_info_thr, linestyle="--", color='grey')

            # sns.despine(right=True)

            if plot_type == "OneVsAll":
                plt.savefig(f"{output_filename}_{gha_layer_name}_{unit}_cat{coi}_raincloud.png")

            else:
                plt.savefig(f"{output_filename}_{gha_layer_name}_{unit}_raincloud.png")

            if test_run:
                plt.show()

            print("\n\tplot finished\n")

            # # clear for next round
            plt.close()

    # # plt.show()
    print("End of script")
def simple_plot_rnn(gha_dict_path,
                    plot_what='all',
                    measure='b_sel',
                    letter_sel=False,
                    correct_items_only=True,
                    verbose=False,
                    test_run=False,
                    show_plots=False):
    """
    
    :param gha_dict_path: or gha_dict
    :param plot_what: 'all' or 'highlights' or dict[layer_names][units][timesteps] 
    :param measure: selectivity measure to focus on if hl_dict provided
    :param letter_sel: focus on level of words or letters
    :param correct_items_only: remove items that were incorrect
    :param verbose: how much to print to screen
    :param test_run: if True, just do 3 plots
    :param show_plots: if True, display plots on screen as well as saving them

    :return: returns nothing, just saves the plots
    """

    print("\n**** running simple_plot_rnn() ****")

    if os.path.isfile(gha_dict_path):
        # # use gha-dict_path to get exp_cond_gha_path, gha_dict_name,
        exp_cond_gha_path, gha_dict_name = os.path.split(gha_dict_path)
        os.chdir(exp_cond_gha_path)

        # # part 1. load dict from study (should run with sim, GHA or sel dict)
        gha_dict = load_dict(gha_dict_path)

    elif type(gha_dict_path) is dict:
        gha_dict = gha_dict_path
        exp_cond_gha_path = os.getcwd()

    else:
        raise FileNotFoundError(gha_dict_path)

    if verbose:
        focussed_dict_print(gha_dict, 'gha_dict')

    # get topic_info from dict
    output_filename = gha_dict["topic_info"]["output_filename"]
    if letter_sel:
        output_filename = f"{output_filename}_lett"

    # # where to save files
    plots_folder = 'plots'
    cond_name = gha_dict['topic_info']['output_filename']
    condition_path = find_path_to_dir(long_path=exp_cond_gha_path,
                                      target_dir=cond_name)
    plots_path = os.path.join(condition_path, plots_folder)
    if not os.path.exists(plots_path):
        os.makedirs(plots_path)
    # os.chdir(plots_path)

    if verbose:
        print(f"\noutput_filename: {output_filename}")
        print(f"plots_path (to save): {plots_path}")
        print(f"os.getcwd(): {os.getcwd()}")

    # # get data info from dict
    n_words = gha_dict["data_info"]["n_cats"]
    n_letters = gha_dict["data_info"]["X_size"]
    if verbose:
        print(f"the are {n_words} word classes")

    if letter_sel:
        n_letters = gha_dict['data_info']["X_size"]
        n_words = n_letters
        print(
            f"there are {n_letters} letter classes\nn_words now set as n_letters"
        )

        letter_id_dict = load_dict(
            os.path.join(gha_dict['data_info']['data_path'],
                         'letter_id_dict.txt'))
        print(f"\nletter_id_dict:\n{letter_id_dict}")

    # # get model info from dict
    model_dict = gha_dict['model_info']['config']
    if verbose:
        focussed_dict_print(model_dict, 'model_dict')

    timesteps = gha_dict['model_info']["overview"]["timesteps"]
    vocab_dict = load_dict(
        os.path.join(gha_dict['data_info']["data_path"],
                     gha_dict['data_info']["vocab_dict"]))
    '''Part 2 - load y, sort out incorrect responses'''
    print("\n\nPart 2: loading labels")
    # # load y_labels to go with hid_acts and item_correct for sequences
    if 'seq_corr_list' in gha_dict['GHA_info']['scores_dict']:
        n_seqs = gha_dict['GHA_info']['scores_dict']['n_seqs']
        n_seq_corr = gha_dict['GHA_info']['scores_dict']['n_seq_corr']
        n_incorrect = n_seqs - n_seq_corr

        test_label_seq_name = gha_dict['GHA_info']['y_data_path']
        seqs_corr = gha_dict['GHA_info']['scores_dict']['seq_corr_list']

        test_label_seqs = np.load(f"{test_label_seq_name}labels.npy")

        if verbose:
            print(f"test_label_seqs: {np.shape(test_label_seqs)}")
            print(f"seqs_corr: {np.shape(seqs_corr)}")
            print(f"n_seq_corr: {n_seq_corr}")

        if letter_sel:
            # # get 1hot item vectors for 'words' and 3 hot for letters
            '''Always use serial_recall True. as I want a separate 1hot vector for each item.
            Always use x_data_type 'local_letter_X' as I want 3hot vectors'''
            y_letters = []
            y_words = []
            for this_seq in test_label_seqs:
                get_letters, get_words = get_X_and_Y_data_from_seq(
                    vocab_dict=vocab_dict,
                    seq_line=this_seq,
                    serial_recall=True,
                    end_seq_cue=False,
                    x_data_type='local_letter_X')
                y_letters.append(get_letters)
                y_words.append(get_words)

            y_letters = np.array(y_letters)
            y_words = np.array(y_words)
            if verbose:
                print(f"\ny_letters: {type(y_letters)}  {np.shape(y_letters)}")
                print(f"y_words: {type(y_words)}  {np.shape(y_words)}")

        y_df_headers = [f"ts{i}" for i in range(timesteps)]
        y_scores_df = pd.DataFrame(data=test_label_seqs, columns=y_df_headers)
        y_scores_df['full_model'] = seqs_corr
        if verbose:
            print(f"\ny_scores_df: {y_scores_df.shape}\n{y_scores_df.head()}")

    # # if not sequence data, load y_labels to go with hid_acts and item_correct for items
    elif 'item_correct_name' in gha_dict['GHA_info']['scores_dict']:
        # # load item_correct (y_data)
        item_correct_name = gha_dict['GHA_info']['scores_dict'][
            'item_correct_name']
        # y_df = pd.read_csv(item_correct_name)
        y_scores_df = nick_read_csv(item_correct_name)
    """# # get rid of incorrect items if required"""
    print("\n\nRemoving incorrect responses")
    # # # get values for correct/incorrect items (1/0 or True/False)
    item_correct_list = y_scores_df['full_model'].tolist()
    full_model_values = list(set(item_correct_list))

    correct_symbol = 1
    if len(full_model_values) != 2:
        # e.g., every item may have been correct; warn rather than stop the analysis
        print(f"WARNING: unexpected scores/acc values for items: {full_model_values}")
    if 1 not in full_model_values:
        if True in full_model_values:
            correct_symbol = True
        else:
            raise TypeError(
                f"TYPE_ERROR!: what are the scores/acc for items? {full_model_values}"
            )

    print(f"len(full_model_values): {len(full_model_values)}")
    print(f"correct_symbol: {correct_symbol}")

    # # i need to check whether this analysis should include incorrect items (True/False)
    gha_incorrect = gha_dict['GHA_info']['gha_incorrect']

    # # get item indices for correct and incorrect items
    item_index = list(range(n_seq_corr))

    incorrect_items = []
    correct_items = []
    for index in range(len(item_correct_list)):
        if item_correct_list[index] == 0:
            incorrect_items.append(index)
        else:
            correct_items.append(index)
    if correct_items_only:
        item_index = correct_items

    if gha_incorrect:
        if correct_items_only:
            if verbose:
                print(
                    "\ngha_incorrect: True (I have incorrect responses)\n"
                    "correct_items_only: True (I only want correct responses)")
                print(
                    f"remove {n_incorrect} incorrect from hid_acts & output using y_scores_df."
                )
                print("use y_correct for y_df")

            y_correct_df = y_scores_df.loc[y_scores_df['full_model'] ==
                                           correct_symbol]
            y_df = y_correct_df

            mask = np.ones(shape=len(seqs_corr), dtype=bool)
            mask[incorrect_items] = False
            test_label_seqs = test_label_seqs[mask]

            if letter_sel:
                y_letters = y_letters[mask]

        else:
            if verbose:
                print("\ngha_incorrect: True (I have incorrect responses)\n"
                      "correct_items_only: False (I want incorrect responses)")
                print(
                    "no changes needed - don't remove anything from hid_acts, output and "
                    "use y scores as y_df")
    else:
        if correct_items_only:
            if verbose:
                print(
                    "\ngha_incorrect: False (I only have correct responses)\n"
                    "correct_items_only: True (I only want correct responses)")
                print(
                    "no changes needed - don't remove anything from hid_acts or output.  "
                    "Use y_correct as y_df")
            y_correct_df = y_scores_df.loc[y_scores_df['full_model'] ==
                                           correct_symbol]
            y_df = y_correct_df
        else:
            if verbose:
                print(
                    "\ngha_incorrect: False (I only have correct responses)\n"
                    "correct_items_only: False (I want incorrect responses)")
            raise TypeError(
                "I can not complete this as desired. "
                "Change correct_items_only to True for this analysis - "
                "don't remove anything from hid_acts or output and "
                "use y_scores as y_df")

            # correct_items_only = True

    if verbose is True:
        print(f"\ny_df: {y_df.shape}\n{y_df.head()}")
        print(f"\ntest_label_seqs: {np.shape(test_label_seqs)}"
              )  # \n{test_label_seqs}")
        # if letter_sel:
        #     y_letters = np.asarray(y_letters)
        #     print(f"y_letters: {np.shape(y_letters)}")  # \n{test_label_seqs}")

    n_correct, timesteps = np.shape(test_label_seqs)
    corr_test_seq_name = f"{output_filename}_{n_correct}_corr_test_label_seqs.npy"
    np.save(corr_test_seq_name, test_label_seqs)
    corr_test_letters_name = 'not_processed_yet'
    if letter_sel:
        corr_test_letters_name = f"{output_filename}_{n_correct}_corr_test_letter_seqs.npy"
        np.save(corr_test_letters_name, y_letters)

    # # get items per class
    IPC_dict = seq_items_per_class(label_seqs=test_label_seqs,
                                   vocab_dict=vocab_dict)
    focussed_dict_print(IPC_dict, 'IPC_dict')
    corr_test_IPC_name = f"{output_filename}_{n_correct}_corr_test_IPC.pickle"
    with open(corr_test_IPC_name, "wb") as pickle_out:
        pickle.dump(IPC_dict, pickle_out, protocol=pickle.HIGHEST_PROTOCOL)

    # # how many times is each item represented at each timestep.
    word_p_class_p_ts = IPC_dict['word_p_class_p_ts']
    letter_p_class_p_ts = IPC_dict['letter_p_class_p_ts']

    for i in range(timesteps):
        n_words_p_ts = len(word_p_class_p_ts[f"ts{i}"].keys())
        n_letters_p_ts = len(letter_p_class_p_ts[f"ts{i}"].keys())

        print(
            f"ts{i}) words:{n_words_p_ts}/{n_words}\tletters: {n_letters_p_ts}/{n_letters}"
        )
        # print(word_p_class_p_ts[f"ts{i}"].keys())

    # # sort plot_what
    print(f"\nplotting: {plot_what}")

    if type(plot_what) is str:
        if plot_what == 'all':
            hl_dict = dict()

            # # add model full model structure to hl_dict
            if letter_sel:
                sel_per_unit_dict_path = f'{exp_cond_gha_path}/{cond_name}_lett_sel_per_unit.pickle'
            else:
                sel_per_unit_dict_path = f'{exp_cond_gha_path}/{cond_name}_sel_per_unit.pickle'

            if os.path.isfile(sel_per_unit_dict_path):
                sel_per_unit_dict = load_dict(sel_per_unit_dict_path)

                for layer in list(sel_per_unit_dict.keys()):
                    hl_dict[layer] = dict()
                    for unit in sel_per_unit_dict[layer].keys():
                        hl_dict[layer][unit] = dict()
                        for ts in sel_per_unit_dict[layer][unit].keys():
                            if measure in sel_per_unit_dict[layer][unit][ts]:
                                class_sel_dict = sel_per_unit_dict[layer][
                                    unit][ts][measure]
                                key_max = max(class_sel_dict,
                                              key=class_sel_dict.get)
                                val_max = class_sel_dict[key_max]
                                hl_entry = (measure, val_max, key_max,
                                            'rank_1')
                                hl_dict[layer][unit][ts] = list()
                                hl_dict[layer][unit][ts].append(hl_entry)
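                # # the resulting hl_dict then has (illustrative layer/unit/values) the shape:
                # # {'hid0': {0: {'ts0': [('b_sel', 0.87, 2, 'rank_1')]},
                # #           1: {'ts0': [('b_sel', 0.55, 7, 'rank_1')]}}}
                # # i.e., [layer][unit][timestep] -> list of (measure, value, class, rank)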

        elif os.path.isfile(plot_what):
            hl_dict = load_dict(plot_what)
        else:
            raise ValueError("plot_what should be:\n"
                             "i. 'all'\n"
                             "ii. path to highlights dict\n"
                             "iii. highlights_dict\n"
                             "iv. dict with structure [layers][units][timesteps]")

    elif type(plot_what) is dict:
        hl_dict = plot_what
    else:
        raise ValueError("plot_what should be\n"
                         "i. 'all'\n"
                         "ii. path to highlights dict\n"
                         "iii. highlights_dict\n"
                         "iv. dict with structure [layers][units][timesteps]")

    if hl_dict:
        focussed_dict_print(hl_dict, 'hl_dict')
    '''save results
    either make a new empty place to save.
    or load previous version and get the units I have already completed'''
    os.chdir(plots_path)
    '''
    part 3   - get gha for each unit
    '''
    loop_gha = loop_thru_acts(gha_dict_path=gha_dict_path,
                              correct_items_only=correct_items_only,
                              letter_sel=letter_sel,
                              verbose=verbose,
                              test_run=test_run)

    test_run_counter = 0
    for index, unit_gha in enumerate(loop_gha):

        print(f"\nindex: {index}")

        # print(f"\n\n{index}:\n{unit_gha}\n")
        sequence_data = unit_gha["sequence_data"]
        y_1hot = unit_gha["y_1hot"]
        layer_name = unit_gha["layer_name"]
        unit_index = unit_gha["unit_index"]
        timestep = unit_gha["timestep"]
        ts_name = f"ts{timestep}"
        item_act_label_array = unit_gha["item_act_label_array"]

        # # only plot units of interest according to hl dict
        if hl_dict:
            if layer_name not in hl_dict:
                print(f"{layer_name} not in hl_dict")
                continue
            if unit_index not in hl_dict[layer_name]:
                print(f"unit {unit_index} not in hl_dict[{layer_name}]")
                continue
            if ts_name not in hl_dict[layer_name][unit_index]:
                print(f"{ts_name} not in hl_dict[{layer_name}][{unit_index}]")
                continue

            # # list comp version fails so use for loop
            # unit_hl_info = [x for x in hl_dict[layer_name][unit_index][ts_name]
            #                 if x[0] == measure]
            unit_hl_info = []
            for x in hl_dict[layer_name][unit_index][ts_name]:
                print(x)
                if x[0] == measure:
                    unit_hl_info.append(x)

            if len(unit_hl_info) == 0:
                print(
                    f"{measure} not in hl_dict[{layer_name}][{unit_index}][{ts_name}]"
                )
                continue

            if 'ts_invar' in hl_dict[layer_name][unit_index]:
                if measure not in hl_dict[layer_name][unit_index]['ts_invar']:
                    print(
                        f"{measure} not in hl_dict[{layer_name}][{unit_index}]['ts_invar']"
                    )
                    continue

            if test_run:
                if test_run_counter == 3:
                    break
                test_run_counter += 1

            unit_hl_info = list(unit_hl_info[0])

            print(f"plotting {layer_name} {unit_index} {ts_name} "
                  f"{unit_hl_info}")

            print(f"\nsequence_data: {sequence_data}")
            print(f"y_1hot: {y_1hot}")
            print(f"unit_index: {unit_index}")
            print(f"timestep: {timestep}")
            print(f"ts_name: {ts_name}")

            # # selective_for_what
            sel_idx = unit_hl_info[2]
            if letter_sel:
                sel_for = 'letter'
                sel_item = letter_id_dict[sel_idx]
            else:
                sel_for = 'word'
                sel_item = vocab_dict[sel_idx]['word']

            # # add in sel item
            unit_hl_info.insert(3, sel_item)

            # # change rank to int
            rank_str = unit_hl_info[4]
            unit_hl_info[4] = int(rank_str[5:])
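            # # e.g., a rank string such as 'rank_1' becomes the int 1 ('rank_1'[5:] -> '1')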

            # hl_text = f'measure\tvalue\tclass\t{sel_for}\trank\n'
            hl_keys = [
                'measure: ', 'value: ', 'label: ', f'{sel_for}: ', 'rank: '
            ]
            hl_text = ''
            for idx, info in enumerate(unit_hl_info):
                key = hl_keys[idx]
                str_info = str(info)
                # hl_text = ''.join([hl_text, str_info[1:-1], '\n'])
                hl_text = ''.join([hl_text, key, str_info, '\n'])

            print(f"\nhl_text: {hl_text}")

        else:
            print("no hl_dict")

        # #  make df
        this_unit_acts = pd.DataFrame(data=item_act_label_array,
                                      columns=['item', 'activation', 'label'])
        this_unit_acts_df = this_unit_acts.astype({
            'item': 'int32',
            'activation': 'float',
            'label': 'int32'
        })

        if letter_sel:
            y_letters_1ts = np.array(y_letters[:, timestep])
            print(f"y_letters_1ts: {np.shape(y_letters_1ts)}")
            # print(f"y_letters_1ts: {y_letters_1ts}")

        # if test_run:
        # # get word ids to check results more easily.
        unit_ts_labels = this_unit_acts_df['label'].tolist()
        # print(f"unit_ts_labels:\n{unit_ts_labels}")

        seq_words_df = spell_label_seqs(
            test_label_seqs=np.asarray(unit_ts_labels),
            vocab_dict=vocab_dict,
            save_csv=False)
        seq_words_list = seq_words_df.iloc[:, 0].tolist()
        # print(f"seq_words_df:\n{seq_words_df}")
        this_unit_acts_df['words'] = seq_words_list
        # print(f"this_unit_acts_df:\n{this_unit_acts_df.head()}")

        # # get labels for selective item (only available when hl_dict gave us sel_idx/sel_item)
        if hl_dict:
            if letter_sel:
                sel_item_list = y_letters_1ts[:, sel_idx]
            else:
                sel_item_list = [1 if x == sel_item else 0 for x in seq_words_list]
            this_unit_acts_df['sel_item'] = sel_item_list

        # sort by ascending word labels
        this_unit_acts_df = this_unit_acts_df.sort_values(by='words',
                                                          ascending=True)

        if verbose is True:
            print(f"\nthis_unit_acts_df: {this_unit_acts_df.shape}\n")
            print(f"this_unit_acts_df:\n{this_unit_acts_df.head()}")

        # # make simple plot
        title = f"{layer_name} unit{unit_index} {ts_name} (of {timesteps})"

        print(f"title: {title}")

        if hl_dict:
            gridkw = dict(width_ratios=[2, 1])
            fig, (spotty_axis, text_box) = plt.subplots(1,
                                                        2,
                                                        gridspec_kw=gridkw)
            sns.catplot(x='activation',
                        y="words",
                        hue='sel_item',
                        data=this_unit_acts_df,
                        ax=spotty_axis,
                        orient='h',
                        kind="strip",
                        jitter=1,
                        dodge=True,
                        linewidth=.5,
                        palette="Set2",
                        marker="D",
                        edgecolor="gray")  # , alpha=.25)
            text_box.text(0.0, -0.01, hl_text, fontsize=10, clip_on=False)
            text_box.axes.get_yaxis().set_visible(False)
            text_box.axes.get_xaxis().set_visible(False)
            text_box.patch.set_visible(False)
            text_box.axis('off')
            spotty_axis.get_legend().set_visible(False)
            spotty_axis.set_xlabel("Unit activations")
            fig.suptitle(title)
            plt.close()  # extra blank plot
        else:
            sns.catplot(x='activation',
                        y="words",
                        data=this_unit_acts_df,
                        orient='h',
                        kind="strip",
                        jitter=1,
                        dodge=True,
                        linewidth=.5,
                        palette="Set2",
                        marker="D",
                        edgecolor="gray")  # , alpha=.25)
            plt.xlabel("Unit activations")
            plt.suptitle(title)
            plt.tight_layout(rect=[0, 0.03, 1, 0.90])

        if letter_sel:
            save_name = f"{plots_path}/" \
                        f"{output_filename}_{layer_name}_{unit_index}_{ts_name}" \
                        f"_{measure}_lett.png"
        else:
            save_name = f"{plots_path}/" \
                        f"{output_filename}_{layer_name}_{unit_index}_{ts_name}" \
                        f"_{measure}_word.png"
        plt.savefig(save_name)
        if show_plots:
            plt.show()
        plt.close()

    print("\nend of simple_plot_rnn script")
def ff_gha(sim_dict_path,
           get_classes=("Conv2D", "Dense", "Activation"),
           gha_incorrect=True,
           use_dataset='train_set',
           save_2d_layers=True,
           save_4d_layers=False,
           exp_root='/home/nm13850/Documents/PhD/python_v2/experiments/',
           verbose=False,
           test_run=False):
    """
    gets activations from hidden units.

    1. load simulation dict (with data info) (*_load_dict.pickle)
        sim_dict can be fed in from sim script, or loaded separately
    2. load model - get structure and details
    3. run dataset through once, recording accuracy per item/class
    4. run on 2nd model to get hid acts

    :param sim_dict_path: path to the dictionary for this experiment condition
    :param get_classes: which types of layer are we interested in?
    :param gha_incorrect: GHA for ALL items (True) or just correct items (False)
    :param use_dataset: GHA for train/test data
    :param save_2d_layers: get 1 value per kernel for conv/pool layers
    :param save_4d_layers: keep original shape of conv/pool layers (for other analysis maybe?)
    :param exp_root: root to save experiments
    :param verbose:
    :param test_run: Set test_run=True to just do one unit per layer


    :return: dict with hid acts per layer.  saved as dict so different shaped arrays don't matter too much
    """

    print('**** ff_gha GHA() ****')

    # # # PART 1 # # #
    # # load details from dict
    if type(sim_dict_path) is dict:
        sim_dict = sim_dict_path
        full_exp_cond_path = sim_dict['topic_info']['exp_cond_path']

    elif os.path.isfile(sim_dict_path):
        print(f"sim_dict_path: {sim_dict_path}")
        sim_dict = load_dict(sim_dict_path)
        full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)

    elif os.path.isfile(os.path.join(exp_root, sim_dict_path)):
        sim_dict_path = os.path.join(exp_root, sim_dict_path)
        print(f"sim_dict_path: {sim_dict_path}")
        sim_dict = load_dict(sim_dict_path)
        full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)
    else:
        raise FileNotFoundError(sim_dict_path)

    os.chdir(full_exp_cond_path)
    print(f"set_path to full_exp_cond_path: {full_exp_cond_path}")

    # exp_dir, _ = os.path.split(exp_cond_path)

    focussed_dict_print(sim_dict, 'sim_dict')

    # # # load datasets

    # # check for training data
    if use_dataset in sim_dict['data_info']:
        x_data_path = os.path.join(
            sim_dict['data_info']['data_path'],
            sim_dict['data_info'][use_dataset]['X_data'])
        y_data_path = os.path.join(
            sim_dict['data_info']['data_path'],
            sim_dict['data_info'][use_dataset]['Y_labels'])
        print(
            f"\nloading {use_dataset}\nx_data_path: {x_data_path}\ny_data_path: {y_data_path}"
        )
    elif use_dataset == 'train_set':
        x_data_path = os.path.join(sim_dict['data_info']['data_path'],
                                   sim_dict['data_info']['X_data'])
        y_data_path = os.path.join(sim_dict['data_info']['data_path'],
                                   sim_dict['data_info']['Y_labels'])
        print(f"\nloading {use_dataset} (only dset available):")
    else:
        print(f"\nERROR! requested dataset ({use_dataset}) not found in dict:")
        focussed_dict_print(sim_dict['data_info'], "sim_dict['data_info']")
        if 'X_data' in sim_dict['data_info']:
            print("\nloading only dset available:")
            x_data_path = os.path.join(sim_dict['data_info']['data_path'],
                                       sim_dict['data_info']['X_data'])
            y_data_path = os.path.join(sim_dict['data_info']['data_path'],
                                       sim_dict['data_info']['Y_labels'])
        else:
            raise FileNotFoundError(
                f"requested dataset ({use_dataset}) not found and no default "
                f"'X_data' in sim_dict['data_info']")

    x_data = load_x_data(x_data_path)
    y_df, y_label_list = load_y_data(y_data_path)

    n_cats = sim_dict['data_info']["n_cats"]

    # # data preprocessing
    # # if network is cnn but data is 2d (e.g., MNIST)
    # # old version
    # if len(np.shape(x_data)) != 4:
    #     if sim_dict['model_info']['overview']['model_type'] == 'cnn':
    #         width, height = sim_dict['data_info']['image_dim']
    #         x_data = x_data.reshape(x_data.shape[0], width, height, 1)
    #         print(f"\nRESHAPING x_data to: {np.shape(x_data)}")

    # new version
    print(f"\ninput shape: {np.shape(x_data)}")
    if len(np.shape(x_data)) == 4:
        image_dim = sim_dict['data_info']['image_dim']
        n_items, width, height, channels = np.shape(x_data)
    else:
        # # this is just for MNIST
        if sim_dict['model_info']['overview']['model_type'] in ['cnn', 'cnns']:
            print("reshaping mnist for cnn")
            width, height = sim_dict['data_info']['image_dim']
            x_data = x_data.reshape(x_data.shape[0], width, height, 1)
            print(f"\nRESHAPING x_data to: {np.shape(x_data)}")

        if sim_dict['model_info']['overview']['model_type'] == 'mlps':
            if len(np.shape(x_data)) > 2:
                print(
                    f"reshaping image data from {len(np.shape(x_data))}d to 2d for mlp"
                )
                x_data = np.reshape(
                    x_data,
                    (x_data.shape[0], x_data.shape[1] * x_data.shape[2]))

                print(f"\nNEW input shape: {np.shape(x_data)}")
                x_size = np.shape(x_data)[1]
                print(f"NEW x_size: {x_size}")

    # Preprocess the data (these are NumPy arrays)
    if x_data.dtype == "uint8":
        print(f"converting input data from {x_data.dtype} to float32")
        max_x = np.amax(x_data)
        x_data = x_data.astype("float32") / max_x
        print(f"x_data.dtype: {x_data.dtype}")

    # Output files
    output_filename = sim_dict["topic_info"]["output_filename"]
    print(f"\nOutput file: {output_filename}")

    # # # # PART 2 # # #
    print("\n**** THE MODEL ****")
    model_name = sim_dict['model_info']['overview']['trained_model']
    model_path = os.path.join(full_exp_cond_path, model_name)
    loaded_model = load_model(model_path)
    model_details = loaded_model.get_config()
    print_nested_round_floats(model_details, 'model_details')

    n_layers = len(model_details['layers'])
    model_dict = dict()

    # # turn off "trainable" and get useful info
    for layer in range(n_layers):
        # set to not train
        model_details['layers'][layer]['config']['trainable'] = False

        if verbose:
            print(f"Model layer {layer}: {model_details['layers'][layer]}")

        # # get useful info
        layer_dict = {
            'layer': layer,
            'name': model_details['layers'][layer]['config']['name'],
            'class': model_details['layers'][layer]['class_name']
        }

        if 'units' in model_details['layers'][layer]['config']:
            layer_dict['units'] = model_details['layers'][layer]['config'][
                'units']
        if 'activation' in model_details['layers'][layer]['config']:
            layer_dict['act_func'] = model_details['layers'][layer]['config'][
                'activation']
        if 'filters' in model_details['layers'][layer]['config']:
            layer_dict['filters'] = model_details['layers'][layer]['config'][
                'filters']
        if 'kernel_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'kernel_size'][0]
        if 'pool_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'pool_size'][0]
        if 'strides' in model_details['layers'][layer]['config']:
            layer_dict['strides'] = model_details['layers'][layer]['config'][
                'strides'][0]
        if 'rate' in model_details['layers'][layer]['config']:
            layer_dict["rate"] = model_details['layers'][layer]['config'][
                'rate']

        # # set and save layer details
        model_dict[layer] = layer_dict

    # # my model summary
    model_df = pd.DataFrame.from_dict(
        data=model_dict,
        orient='index',
        columns=[
            'layer', 'name', 'class', 'act_func', 'units', 'filters', 'size',
            'strides', 'rate'
        ],
    )

    # # just classes of layers specified in get_classes
    key_layers_df = model_df.loc[model_df['class'].isin(get_classes)]
    key_layers_df.reset_index(inplace=True)
    del key_layers_df['index']
    key_layers_df.index.name = 'index'
    key_layers_df = key_layers_df.drop(columns=['size', 'strides', 'rate'])

    # # add column ('n_units_filts') to say how many things need gha per layer (number of units or filters)
    # # add zeros to rows with no units or filters
    key_layers_df.loc[:, 'n_units_filts'] = key_layers_df.units.fillna(
        0) + key_layers_df.filters.fillna(0)

    print(f"\nkey_layers_df:\n{key_layers_df}")

    key_layers_df.loc[:,
                      "n_units_filts"] = key_layers_df["n_units_filts"].astype(
                          int)

    # # get the total number of units or filters in key layers of the network
    key_n_units_fils = sum(key_layers_df['n_units_filts'])

    print(f"\nkey_layers_df:\n{key_layers_df.head()}")
    # print("\nkey_layers_df:\n{}".format(key_layers_df))
    print(f"key_n_units_fils: {key_n_units_fils}")
    '''I currently get the output layer too - keep it in so I can do class correlation'''

    # # # set dir to save gha stuff # # #
    hid_act_items = 'all'
    if not gha_incorrect:
        hid_act_items = 'correct'

    gha_folder = f'{hid_act_items}_{use_dataset}_gha'

    if test_run:
        gha_folder = os.path.join(gha_folder, 'test')
    gha_path = os.path.join(full_exp_cond_path, gha_folder)

    if not os.path.exists(gha_path):
        os.makedirs(gha_path)
    os.chdir(gha_path)
    print(f"saving hid_acts to: {gha_path}")

    # # # PART 3 get_scores() # # #
    predicted_outputs = loaded_model.predict(x_data)

    item_correct_df, scores_dict, incorrect_items = get_scores(
        predicted_outputs,
        y_df,
        output_filename,
        save_all_csvs=True,
        verbose=True)

    if verbose:
        focussed_dict_print(scores_dict, 'Scores_dict')

    # # # PART 4 # # #
    print("\n**** REMOVE INCORRECT FROM X DATA ****")
    mask = np.ones(len(x_data), dtype=bool)
    mask[incorrect_items] = False
    x_correct = copy.copy(x_data[mask])
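    # # note: boolean-mask indexing keeps only the rows where mask is True, e.g. with
    # # incorrect_items = [1, 3] the mask is [True, False, True, False, True, ...] and
    # # x_data[mask] drops items 1 and 3 while keeping the remaining items in order.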

    gha_items = x_correct
    if gha_incorrect:  # If I want ALL items including those classified incorrectly
        gha_items = x_data
    print(
        f"gha_items: (incorrect items={gha_incorrect}) {np.shape(gha_items)}")

    # # PART 5
    print("\n**** Get Hidden unit activations ****")
    # # hid_act_2d_dict: 2d hid acts (e.g., means from 4d conv layers)
    # # hid_act_any_d_dict: all hid acts in their original shape (both 2d and 4d layers)
    hid_act_2d_dict = dict()
    hid_act_any_d_dict = dict()

    # # loop through key layers df
    gha_key_layers = []
    for index, row in key_layers_df.iterrows():
        if test_run:
            if index > 3:
                continue

        layer_number, layer_name, layer_class = row['layer'], row['name'], row[
            'class']
        print(f"{layer_number}. name {layer_name} class {layer_class}")

        if layer_class not in get_classes:  # skip layers/classes not in list
            continue
        else:
            print('getting layer')
            converted_to_2d = False  # set to True if 4d acts have been converted to 2d
            model = loaded_model
            gha_key_layers.append(layer_name)

            # model to record hid acts
            intermediate_layer_model = Model(
                inputs=model.input, outputs=model.get_layer(layer_name).output)
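            # Model(inputs=..., outputs=...) builds a sub-model sharing the trained weights,
            # so predict() below returns the activations of `layer_name` for every item in gha_items.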
            intermediate_output = intermediate_layer_model.predict(gha_items,
                                                                   verbose=1)
            layer_acts_shape = np.shape(intermediate_output)

            if save_2d_layers:
                if len(layer_acts_shape) == 2:
                    acts_2d = intermediate_output

                elif len(layer_acts_shape) == 4:  # # call mean_act_conv
                    acts_2d = kernel_to_2d(intermediate_output, verbose=True)
                    layer_acts_shape = np.shape(acts_2d)
                    converted_to_2d = True

                else:
                    print(
                        "\n\n\n\nSHAPE ERROR - UNEXPECTED DIMENSIONS\n\n\n\n")
                    acts_2d = 'SHAPE_ERROR'
                    layer_acts_shape = 'NONE'

                hid_act_2d_dict[index] = {
                    'layer_name': layer_name,
                    'layer_class': layer_class,
                    "layer_shape": layer_acts_shape,
                    '2d_acts': acts_2d
                }

                if converted_to_2d:
                    hid_act_2d_dict[index]['converted_to_2d'] = True

                print(f"\nlayer{index}. hid_act_2d_dict: {layer_acts_shape}\n")

                # # save distplot for sanity check
                sns.distplot(np.ravel(acts_2d))
                plt.title(str(layer_name))
                # plt.savefig(f"layer_act_dist/{output_filename}_{layer_name}_layer_act_distplot.png")
                plt.savefig(
                    f"{output_filename}_{layer_name}_layer_act_distplot.png")

                plt.close()

    print("\n**** saving info to summary page and dictionary ****")

    hid_act_filenames = {'2d': None, 'any_d': None}
    if save_2d_layers:
        dict_2d_save_name = f'{output_filename}_hid_act_2d.pickle'
        with open(dict_2d_save_name, "wb") as pkl:  # 'wb' means write the file in binary mode
            pickle.dump(hid_act_2d_dict, pkl)
        # np.save(dict_2d_save_name, hid_act_2d_dict)
        hid_act_filenames['2d'] = dict_2d_save_name

    if save_4d_layers:
        dict_4dsave_name = f'{output_filename}_hid_act_any_d.pickle'
        with open(dict_4dsave_name, "wb") as pkl:  # 'wb' means write the file in binary mode
            pickle.dump(hid_act_any_d_dict, pkl)
        # np.save(dict_4dsave_name, hid_act_any_d_dict)
        hid_act_filenames['any_d'] = dict_4dsave_name

    cond = sim_dict["topic_info"]["cond"]
    run = sim_dict["topic_info"]["run"]

    hid_units = sim_dict['model_info']['layers']['hid_layers']['hid_totals'][
        'analysable']

    trained_for = sim_dict["training_info"]["trained_for"]
    end_accuracy = sim_dict["training_info"]["acc"]
    dataset = sim_dict["data_info"]["dataset"]
    gha_date = int(datetime.datetime.now().strftime("%y%m%d"))
    gha_time = int(datetime.datetime.now().strftime("%H%M"))

    gha_acc = scores_dict['gha_acc']
    n_cats_correct = scores_dict['n_cats_correct']

    # # GHA_info_dict
    gha_dict_name = f"{output_filename}_GHA_dict.pickle"
    gha_dict_path = os.path.join(gha_path, gha_dict_name)

    gha_dict = {
        "topic_info": sim_dict['topic_info'],
        "data_info": sim_dict['data_info'],
        "model_info": sim_dict['model_info'],
        "training_info": sim_dict['training_info'],
        "GHA_info": {
            "use_dataset": use_dataset,
            'x_data_path': x_data_path,
            'y_data_path': y_data_path,
            'gha_path': gha_path,
            'gha_dict_path': gha_dict_path,
            "gha_incorrect": gha_incorrect,
            "hid_act_files": hid_act_filenames,
            'gha_key_layers': gha_key_layers,
            'key_n_units_fils': key_n_units_fils,
            "gha_date": gha_date,
            "gha_time": gha_time,
            "scores_dict": scores_dict,
            "model_dict": model_dict
        }
    }

    # pickle_out = open(gha_dict_name, "wb")
    # pickle.dump(gha_dict, pickle_out)
    # pickle_out.close()

    with open(gha_dict_name, "wb") as pickle_out:
        pickle.dump(gha_dict, pickle_out)

    if verbose:
        focussed_dict_print(gha_dict, 'gha_dict', ['GHA_info', "scores_dict"])

    # make a list of dict names to do sel on
    if not os.path.isfile(f"{output_filename}_dict_list_for_sel.csv"):
        dict_list = open(f"{output_filename}_dict_list_for_sel.csv", 'w')
        mywriter = csv.writer(dict_list)
    else:
        dict_list = open(f"{output_filename}_dict_list_for_sel.csv", 'a')
        mywriter = csv.writer(dict_list)

    mywriter.writerow([gha_dict_name[:-7]])
    dict_list.close()

    print(f"\nadded to list for selectivity analysis: {gha_dict_name[:-7]}")

    # # # spare variables to make analysis easier
    # if 'chanProp' in output_filename:
    #     var_one = 'chanProp'
    # elif 'chanDist' in output_filename:
    #     var_one = 'chanDist'
    # elif 'cont' in output_filename:
    #     var_one = 'cont'
    # elif 'bin' in output_filename:
    #     var_one = 'bin'
    # else:
    #     raise ValueError("dset_type not found (v1)")
    #
    # if 'pro_sm' in output_filename:
    #     var_two = 'pro_sm'
    # elif 'pro_med' in output_filename:
    #     var_two = 'pro_med'
    # # elif 'LB' in output_filename:
    # #     var_two = 'LB'
    # else:
    #     raise ValueError("between not found (v2)")
    #
    # if 'v1' in output_filename:
    #     var_three = 'v1'
    # elif 'v2' in output_filename:
    #     var_three = 'v2'
    # elif 'v3' in output_filename:
    #     var_three = 'v3'
    # else:
    #     raise ValueError("within not found (v3)")
    #
    # var_four = var_two + var_three
    #
    # if 'ReLu' in output_filename:
    #     var_five = 'relu'
    # elif 'relu' in output_filename:
    #     var_five = 'relu'
    # elif 'sigm' in output_filename:
    #     var_five = 'sigm'
    # else:
    #     raise ValueError("act_func not found (v4)")
    #
    # if '10' in output_filename:
    #     var_six = 10
    # elif '25' in output_filename:
    #     var_six = 25
    # elif '50' in output_filename:
    #     var_six = 50
    # elif '100' in output_filename:
    #     var_six = 100
    # elif '500' in output_filename:
    #     var_six = 500
    # else:
    #     raise ValueError("hid_units not found in output_filename (var6)")

    # print(f"\n{output_filename}: {var_one} {var_two} {var_three} {var_four} {var_five} {var_six}")

    gha_info = [
        cond,
        run,
        output_filename,
        n_layers,
        hid_units,
        dataset,
        use_dataset,
        gha_incorrect,
        n_cats,
        trained_for,
        end_accuracy,
        gha_acc,
        n_cats_correct,
        # var_one, var_two, var_three, var_four, var_five, var_six
    ]

    # # check if gha_summary.csv exists
    # # save summary file in exp folder (grandparent dir to gha folder: exp/cond/gha)
    # to move up to parent just use '..' rather than '../..'

    # exp_name = exp_dir.strip('/')
    exp_name = sim_dict['topic_info']['exp_name']

    os.chdir('../..')
    exp_path = os.getcwd()

    if not os.path.isfile(exp_name + "_GHA_summary.csv"):
        gha_summary = open(exp_name + "_GHA_summary.csv", 'w')
        mywriter = csv.writer(gha_summary)
        summary_headers = [
            "cond",
            "run",
            'filename',
            "n_layers",
            "hid_units",
            "dataset",
            "GHA_on",
            'incorrect',
            "n_cats",
            "trained_for",
            "train_acc",
            "gha_acc",
            'n_cats_correct',
            # 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'
        ]

        mywriter.writerow(summary_headers)
        print(f"creating summary csv at: {exp_path}")

    else:
        gha_summary = open(exp_name + "_GHA_summary.csv", 'a')
        mywriter = csv.writer(gha_summary)
        print(f"appending to summary csv at: {exp_path}")

    mywriter.writerow(gha_info)
    gha_summary.close()

    print("\nend of ff_gha")

    return gha_info, gha_dict
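
# A minimal usage sketch (not from the original script): the sim dict path below is
# hypothetical; in practice it would be written out by the training script for this condition.
if __name__ == '__main__':
    example_sim_dict_path = 'example_exp/example_cond/example_cond_sim_dict.pickle'  # hypothetical path
    gha_info, gha_dict = ff_gha(example_sim_dict_path,
                                use_dataset='train_set',
                                gha_incorrect=True,
                                verbose=True,
                                test_run=True)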
Example #9
def loop_thru_acts(gha_dict_path,
                   correct_items_only=True,
                   acts_saved_as='pickle',
                   letter_sel=False,
                   already_completed={},
                   verbose=False, test_run=False):
    """To use hidden unit activations for sel, (lesioning?) visualisation.
        1. load dict from study (GHA dict) - get variables from dict
    2. load y, sort out incorrect resonses
    3. find where to load gha from: pickle, hdf5, shelve.
        for now write assuming pickle
        load hidden activations
    4. loop through all layers:
        loop through all units:
            (loop through timesteps?)

    This is a generator-iterator, not a plain function, as I want it to yield hid_acts
    and details one unit at a time.

    :param gha_dict_path: path of the gha dict
    :param correct_items_only: Whether to skip test items that the model got incorrect.
    :param letter_sel: if False, test sel for words (class-labels).
            If True, test for letters (parts) using 'local_word_X' for each word when looping through classes
    :param already_completed: None, or dict with layer_names as keys,
                            values are either 'all' or the number of the last completed unit.
    :param acts_saved_as: file format used to save gha

    :param verbose: how much to print to screen
    :param test_run: if True, only do subset, e.g., 3 units from 3 layers

    :return: hid_acts_array (activation values for all items at this unit (all timesteps?))
    :return: data_details (item_numbers, class_labels, whether_correct)
    :return: unit-details (layer_name, layer_class, act_func, unit_number)
    """

    if verbose:
        print("\n**** running loop_thru_units() ****")

    # # check already completed dict
    if already_completed is not None:
        if type(already_completed) is not dict:
            raise TypeError("already_completed should be a dict")
        else:
            for value in already_completed.values():
                if value != 'all':
                    if type(value) is not int:
                        raise ValueError("already_completed dict values should be the int of the "
                                         "last completed unit, or 'all'")

    # # part 1. load dict from study (should run with sim, GHA or sel dict)
    gha_dict = load_dict(gha_dict_path)
    focussed_dict_print(gha_dict, 'gha_dict')

    # # use gha-dict_path to get exp_cond_gha_path, gha_dict_name,
    exp_cond_gha_path = gha_dict['GHA_info']['gha_path']
    # gha_dict_name = gha_dict['GHA_info']['hid_act_files']['2d']
    os.chdir(exp_cond_gha_path)
    current_wd = os.getcwd()

    # get topic_info from dict
    output_filename = gha_dict["topic_info"]["output_filename"]
    if letter_sel:
        output_filename = f"{output_filename}_lett"

    if verbose:
        print(f"\ncurrent_wd: {current_wd}")
        print(f"output_filename: {output_filename}")


    # # get model info from dict
    units_per_layer = gha_dict['model_info']["overview"]["units_per_layer"]

    if 'n_layers' in gha_dict['model_info']['overview']:
        n_layers = gha_dict['model_info']['overview']['n_layers']
    elif 'hid_layers' in gha_dict['model_info']['overview']:
        n_layers = gha_dict['model_info']['overview']['hid_layers']
    elif 'act_layers' in gha_dict['model_info']['layers']['hid_layers']['hid_totals']:
        n_layers = gha_dict['model_info']['layers']['hid_layers']['hid_totals']['act_layers']
    else:
        raise ValueError("How many layers? ln 690 Network")
    model_dict = gha_dict['model_info']['config']
    if verbose:
        focussed_dict_print(model_dict, 'model_dict')



    # # check for sequences/rnn
    sequence_data = False
    y_1hot = True

    if 'timesteps' in gha_dict['model_info']['overview']:
        timesteps = gha_dict['model_info']["overview"]["timesteps"]
        if timesteps > 1:
            sequence_data = True
            if 'serial_recall' in gha_dict['model_info']['overview']:
                serial_recall = gha_dict['model_info']["overview"]["serial_recall"]
                y_1hot = serial_recall
            if 'y_1hot' in gha_dict['model_info']['overview']:
                y_1hot = gha_dict['model_info']["overview"]["y_1hot"]


    # # I can't do class correlations for letters, as it is the equivalent of
    # # having a dist output for letters
    if letter_sel:
        y_1hot = False

    # # get gha info from dict
    hid_acts_filename = gha_dict["GHA_info"]["hid_act_files"]['2d']


    '''Part 2 - load y, sort out incorrect responses'''
    print("\n\nPart 2: loading labels")
    # # load y_labels to go with hid_acts and item_correct for sequences
    if 'seq_corr_list' in gha_dict['GHA_info']['scores_dict']:
        n_seqs = gha_dict['GHA_info']['scores_dict']['n_seqs']
        n_seq_corr = gha_dict['GHA_info']['scores_dict']['n_seq_corr']
        n_incorrect = n_seqs - n_seq_corr

        test_label_seq_name = gha_dict['GHA_info']['y_data_path']
        seqs_corr = gha_dict['GHA_info']['scores_dict']['seq_corr_list']

        test_label_seqs = np.load(f"{test_label_seq_name}labels.npy")

        if verbose:
            print(f"test_label_seqs: {np.shape(test_label_seqs)}")
            print(f"seqs_corr: {np.shape(seqs_corr)}")
            print(f"n_seq_corr: {n_seq_corr}")

        """get 1hot item vectors for 'words' and 3 hot for letters"""
        # '''Always use serial_recall True. as I want a separate 1hot vector for each item.
        # Always use x_data_type 'local_letter_X' as I want 3hot vectors'''
        # y_letters = []
        # y_words = []
        # for this_seq in test_label_seqs:
        #     get_letters, get_words = get_X_and_Y_data_from_seq(vocab_dict=vocab_dict,
        #                                                        seq_line=this_seq,
        #                                                        serial_recall=True,
        #                                                        end_seq_cue=False,
        #                                                        x_data_type='local_letter_X')
        #     y_letters.append(get_letters)
        #     y_words.append(get_words)
        #
        # y_letters = np.array(y_letters)
        # y_words = np.array(y_words)
        # if verbose:
        #     print(f"\ny_letters: {type(y_letters)}  {np.shape(y_letters)}")
        #     print(f"\ny_words: {type(y_words)}  {np.shape(y_words)}")
        #     print(f"\ntest_label_seqs[0]: {test_label_seqs[0]}")
        #     if test_run:
        #         print(f"y_letters[0]:\n{y_letters[0]}")
        #         print(f"y_words[0]:\n{y_words[0]}")

        y_df_headers = [f"ts{i}" for i in range(timesteps)]
        y_scores_df = pd.DataFrame(data=test_label_seqs, columns=y_df_headers)
        y_scores_df['full_model'] = seqs_corr
        if verbose:
            print(f"\ny_scores_df: {y_scores_df.shape}\n{y_scores_df.head()}")



    # # if not sequence data, load y_labels to go with hid_acts and item_correct for items
    elif 'item_correct_name' in gha_dict['GHA_info']['scores_dict']:
        n_correct = gha_dict['GHA_info']['scores_dict']['n_correct']
        n_incorrect = gha_dict['GHA_info']['scores_dict']['n_items'] - n_correct
        n_seq_corr = n_correct
        # # load item_correct (y_data)
        item_correct_name = gha_dict['GHA_info']['scores_dict']['item_correct_name']
        # y_df = pd.read_csv(item_correct_name)
        y_scores_df = nick_read_csv(item_correct_name)
        seqs_corr = y_scores_df['full_model'].to_list()
        test_label_seqs = np.array(y_scores_df['full_model'].to_list())

        if verbose:
            print(f"\nVERBOSE\n\n"
                  f"y_scores_df:\n{y_scores_df.head()}\n\n"
                  f"seqs_corr:\n{seqs_corr}\n"
                  f"test_label_seqs:\n{test_label_seqs}\n"
                  f"")


    # else:




    """# # sort incorrect item data"""
    print("\n\nRemoving incorrect responses")
    # # # get values for correct/incorrect items (1/0 or True/False)
    item_correct_list = y_scores_df['full_model'].tolist()
    full_model_values = list(set(item_correct_list))

    correct_symbol = 1
    if len(full_model_values) != 2:
        raise TypeError(f"TYPE_ERROR!: what are the scores/acc for items? {full_model_values}")
    if 1 not in full_model_values:
        if True in full_model_values:
            correct_symbol = True
        else:
            raise TypeError(f"TYPE_ERROR!: what are the scores/acc for items? {full_model_values}")

    # # i need to check whether this analysis should include incorrect items
    gha_incorrect = gha_dict['GHA_info']['gha_incorrect']

    # # get item indeces for correct and incorrect items
    item_index = list(range(n_seq_corr))

    incorrect_items = []
    correct_items = []
    for index in range(len(item_correct_list)):
        if item_correct_list[index] == 0:
            incorrect_items.append(index)
        else:
            correct_items.append(index)
    if correct_items_only:
        item_index = correct_items

    else:
        '''tbh I'm not quite sure what the implications of this are, 
        just a hack to make it work for untrained model'''
        print("\n\nWARNING\nitem_index == what_shape\njust doing this for untrained model!")

        what_shape = list(range(960))
        print(f'what_shape: {np.shape(what_shape)}\n{what_shape}\n')
        item_index = what_shape
        print(f'item_index: {np.shape(item_index)}\n{item_index}\n')

    print(f"incorrect_items: {np.shape(incorrect_items)}\n{incorrect_items}")
    print(f'item_index: {np.shape(item_index)}\n{item_index}\n')
    print(f'correct_items_only: {correct_items_only}')

    if gha_incorrect:
        if correct_items_only:
            if verbose:
                print("\ngha_incorrect: True (I have incorrect responses)\n"
                      "correct_items_only: True (I only want correct responses)")
                print(f"remove {n_incorrect} incorrect from hid_acts & output using y_scores_df.")
                print("use y_correct for y_df")

            y_correct_df = y_scores_df.loc[y_scores_df['full_model'] == correct_symbol]
            y_df = y_correct_df

            mask = np.ones(shape=len(seqs_corr), dtype=bool)
            mask[incorrect_items] = False
            test_label_seqs = test_label_seqs[mask]

        else:
            if verbose:
                print("\ngha_incorrect: True (I have incorrect responses)\n"
                      "correct_items_only: False (I want incorrect responses)")
                print("no changes needed - don't remove anything from hid_acts, output and "
                      "use y scores as y_df")

            y_df = y_scores_df
    else:
        if correct_items_only:
            if verbose:
                print("\ngha_incorrect: False (I only have correct responses)\n"
                      "correct_items_only: True (I only want correct responses)")
                print("no changes needed - don't remove anything from hid_acts or output.  "
                      "Use y_correct as y_df")
            y_correct_df = y_scores_df.loc[y_scores_df['full_model'] == correct_symbol]
            y_df = y_correct_df
        else:
            if verbose:
                print("\ngha_incorrect: False (I only have correct responses)\n"
                      "correct_items_only: False (I want incorrect responses)")
            raise TypeError("I can not complete this as desired. "
                            "Change correct_items_only to True for analysis - "
                            "don't remove anything from hid_acts or output, and "
                            "use y scores as y_df")


    if verbose is True:
        print(f"\ny_df: {y_df.shape}\n{y_df.head()}")
        print(f"\ntest_label_seqs: {np.shape(test_label_seqs)}")  # \n{test_label_seqs}")



    # # Part 3 - where to load hid_acts from
    if acts_saved_as == 'pickle':
        with open(hid_acts_filename, 'rb') as pkl:
            hid_acts_dict = pickle.load(pkl)

        hid_acts_keys_list = list(hid_acts_dict.keys())

        if verbose:
            print(f"\n**** pickle opening {hid_acts_filename} ****")
            print(f"hid_acts_keys_list: {hid_acts_keys_list}")
            print(f"first layer keys: {list(hid_acts_dict[0].keys())}")
            # focussed_dict_print(hid_acts_dict, 'hid_acts_dict')

        last_hid_act_number = hid_acts_keys_list[-1]
        last_layer_name = hid_acts_dict[last_hid_act_number]['layer_name']
        
    elif acts_saved_as == 'h5':
        with h5py.File(hid_acts_filename, 'r') as hid_acts_dict:
            hid_acts_keys_list = list(hid_acts_dict.keys())

            if verbose:
                print(f"\n**** h5py opening {hid_acts_filename} ****")
                print(f"hid_acts_keys_list: {hid_acts_keys_list}")

            last_hid_act_number = hid_acts_keys_list[-1]
            last_layer_name = hid_acts_dict[last_hid_act_number]['layer_name']



    '''part 4 loop through layers and units'''
    # # loop through dict/layers
    if verbose:
        print("\n*** looping through layers ***")
    if test_run:
        layer_counter = 0

    layer_number = -1

    # # loop through layer numbers in list of dict keys
    for hid_act_number in hid_acts_keys_list:
        layer_number = layer_number + 1

        # # don't run sel on output layer
        if hid_act_number == last_hid_act_number:
            if verbose:
                print(f"\nskip output layer! (layer: {hid_act_number})")
            continue

        if test_run is True:
            layer_counter = layer_counter + 1
            if layer_counter > 3:
                if verbose:
                    print(f"\tskip this layer!: test_run, only running subset of layers")
                continue


        '''could add something to check which layers/units have been done 
        already and start from there?'''

        # # Once I've decided to run this unit
        if acts_saved_as == 'pickle':
            with open(hid_acts_filename, 'rb') as pkl:
                hid_acts_dict = pickle.load(pkl)
                layer_dict = hid_acts_dict[hid_act_number]
                
        elif acts_saved_as == 'h5':
            with h5py.File(hid_acts_filename, 'r') as hid_acts_dict:
                layer_dict = hid_acts_dict[hid_act_number]


        layer_name = layer_dict['layer_name']
        found_name = gha_dict['model_info']['layers']['hid_layers'][layer_number]['name']
        print(layer_name, found_name)
        if found_name == layer_name:
            print('match\n')
        else:
            print("wrong\n")
            while found_name != layer_name:
                # print("still wrong")
                found_name = gha_dict['model_info']['layers']['hid_layers'][layer_number]['name']
                if found_name == layer_name:
                    print(layer_name, found_name, 'match\n')
                    break
                layer_number += 1

        act_func = gha_dict['model_info']['layers']['hid_layers'][layer_number]['act_func']
        print(f"\nlayer_name: {layer_name}\n"
              f"found_name: {found_name}\n"
              f"layer_number: {layer_number}\n"
              f"act_func: {act_func}\n")


        partially_completed_layer = False
        if layer_name in already_completed:
            if already_completed[layer_name] == 'all':
                print("already completed analysis on this layer")
                continue
            else:
                partially_completed_layer = True

        if verbose:
            print(f"\nrunning layer {hid_act_number}: {layer_name}")

        if 'hid_acts' in layer_dict:
            hid_acts_array = layer_dict['hid_acts']
        elif '2d_acts' in layer_dict:
            hid_acts_array = layer_dict['2d_acts']
        else:
            raise ValueError(f"how are hid_acts labelled in layer_dict?: {layer_dict.keys()}")


        if verbose:
            if sequence_data:
                print(f"np.shape(hid_acts_array) (n_seqs, timesteps, units_per_layer): "
                      f"{np.shape(hid_acts_array)}")
            else:
                print(f"np.shape(hid_acts_array) (n_items, units_per_layer): "
                      f"{np.shape(hid_acts_array)}")




        # # remove incorrect responses from np array
        if correct_items_only:
            if gha_incorrect:
                if verbose:
                    print(f"\nremoving {n_incorrect} incorrect responses from "
                          f"hid_acts_array: {np.shape(hid_acts_array)}")

                hid_acts_array = hid_acts_array[mask]
                # these_labels = np.array(list(range(_n_items)))[mask]
                if verbose:
                    print(f"(cleaned) np.shape(hid_acts_array) (n_seqs_corr, timesteps, units_per_layer): "
                          f"{np.shape(hid_acts_array)}"
                          f"\ntest_label_seqs: {np.shape(test_label_seqs)}")

        # get units per layer
        if len(np.shape(hid_acts_array)) == 3:
            _n_seqs, _timesteps, units_per_layer = np.shape(hid_acts_array)
        elif len(np.shape(hid_acts_array)) == 2:
            _n_items, units_per_layer = np.shape(hid_acts_array)
        else:
            raise ValueError(f"hid_acts array should be 2d or 3d, not {np.shape(hid_acts_array)}")


        
        
        '''loop through units'''
        if verbose:
            print("\n**** loop through units ****")

        unit_counter = 0
        # for unit_index, unit in enumerate(hid_acts_df.columns):
        for unit_index in range(units_per_layer):

            if partially_completed_layer:
                if unit_index <= already_completed[layer_name]:
                    print("already run this unit")
                    continue

            if test_run is True:
                unit_counter += 1
                if unit_counter > 3:
                    continue

            print(f"\n****\nrunning layer {hid_act_number} of {n_layers} "
                  f"({layer_name}): unit {unit_index} of {units_per_layer}\n****")
            
            if sequence_data:
                
                one_unit_all_timesteps = hid_acts_array[:, :, unit_index]

                if np.sum(one_unit_all_timesteps) == 0:
                    dead_unit = True
                    if verbose:
                        print("dead unit")
                else:
                    if verbose:
                        print(f"\nnp.shape(one_unit_all_timesteps) (seqs, timesteps): "
                              f"{np.shape(one_unit_all_timesteps)}")
                    
                    # get hid acts for each timestep
                    for timestep in range(timesteps):
                        print("\n\tunit {} timestep {} (of {})".format(unit_index, timestep, timesteps))

                        one_unit_one_timestep = one_unit_all_timesteps[:, timestep]

                        # y_labels_one_timestep_float = combo_data[:, timestep]
                        y_labels_one_timestep_float = test_label_seqs[:, timestep]
                        y_labels_one_timestep = [int(q) for q in y_labels_one_timestep_float]


                        these_acts = one_unit_one_timestep
                        these_labels = y_labels_one_timestep

                        if verbose:
                            print(f'item_index: {np.shape(item_index)}')
                            print(f'these_acts: {np.shape(these_acts)}')
                            print(f'these_labels: {np.shape(these_labels)}')

                        # insert act values in middle of labels (item, act, cat)
                        item_act_label_array = np.vstack((item_index, these_acts, these_labels)).T

                        if verbose:
                            print(f"\t - one_unit_one_timestep shape: (n seqs) {np.shape(one_unit_one_timestep)}")
                            print(f"\t - y_labels_one_timestep shape: {np.shape(y_labels_one_timestep)}")

                            print(f"\t - item_act_label_array shape: (item_idx, hid_acts, y_label) "
                                  f"{np.shape(item_act_label_array)}")
                        # print(f"\titem_act_label_array: \n\t{item_act_label_array}")

                        loop_dict = {
                                     "sequence_data": sequence_data,
                                     "y_1hot": y_1hot,
                                     "act_func": act_func,
                                     "hid_act_number": hid_act_number, "layer_name": layer_name,
                                     "unit_index": unit_index,
                                     "timestep": timestep,
                                     'item_act_label_array': item_act_label_array,
                                     }

                        yield loop_dict

                    # return hid act, data info, unit info and timestep
            else:
                # if not sequences, just items
                this_unit_just_acts = hid_acts_array[:, unit_index]
                
                if np.sum(this_unit_just_acts) == 0:
                    dead_unit = True
                    if verbose:
                        print("dead unit")
                else:
                    if verbose:
                        print(f"\nnp.shape(this_unit_just_acts) (items, ): "
                              f"{np.shape(this_unit_just_acts)}")

                    these_acts = this_unit_just_acts
                    these_labels = y_df['class'].tolist()
                    #
                    # if verbose:
                    #     print(f"\nidiot check\n"
                    #           f"item_index: {np.shape(item_index)}\n"
                    #           f"these_acts: {np.shape(these_acts)}\n"
                    #           f"these_labels: {np.shape(these_labels)}\n"
                    #           f"these_labels: {these_labels}")

                    # insert act values in middle of labels (item, act, cat)
                    item_act_label_array = np.vstack((item_index, these_acts, these_labels)).T

                    if verbose:
                        print(" - item_act_label_array shape: {}".format(np.shape(item_act_label_array)))
                        print(f"item_act_label_array: (item_idx, hid_acts, y_label)\n{item_act_label_array}")

                    timestep = None

                    loop_dict = {
                        "sequence_data": sequence_data,
                        "y_1hot": y_1hot,
                        "act_func": act_func,
                        "hid_act_number": hid_act_number, "layer_name": layer_name,
                        "unit_index": unit_index,
                        "timestep": timestep,
                        'item_act_label_array': item_act_label_array,
                    }

                    yield loop_dict
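
# A minimal usage sketch (not from the original script): the GHA dict path is hypothetical;
# it would normally be produced by the GHA scripts above. The generator yields one unit
# (and, for sequence data, one timestep) at a time.
if __name__ == '__main__':
    example_gha_dict_path = 'example_exp/example_cond/all_train_set_gha/example_GHA_dict.pickle'
    for unit_info in loop_thru_acts(example_gha_dict_path,
                                    correct_items_only=True,
                                    acts_saved_as='pickle',
                                    verbose=True,
                                    test_run=True):
        # each yielded dict holds one unit's (item, activation, label) array plus its context
        print(unit_info['layer_name'], unit_info['unit_index'], unit_info['timestep'])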
Example #10
def rnn_gha(sim_dict_path,
            gha_incorrect=True,
            use_dataset='train_set',
            get_layer_list=None,
            exp_root='/home/nm13850/Documents/PhD/python_v2/experiments/',
            verbose=False,
            test_run=False):
    """
    gets activations from hidden units.

    1. load simulation dict (with data info) (*_load_dict.pickle)
        sim_dict can be fed in from sim script, or loaded separately
    2. load model - get structure and details
    3. run dataset through once, recording accuracy per item/class
    4. run on 2nd model to get hid acts

    :param sim_dict_path: path to the dictionary for this experiment condition
    :param gha_incorrect: GHA for ALL items (True) or just correct items (False)
    :param use_dataset: GHA for train/test data
    :param get_layer_list: if None, gha all layers, else list of layer names to gha
    :param exp_root: root to save experiments
    :param verbose: how much to print to screen
    :param test_run: if True, only process a small subset (e.g., the first few layers) as a quick check

    :return: dict with hid acts per layer.  saved as dict so different shaped arrays don't matter too much
    """

    print('**** rnn_gha() ****')

    # # # PART 1 # # #
    # # load details from dict
    if type(sim_dict_path) is str:
        if os.path.isfile(sim_dict_path):
            print(f"sim_dict_path: {sim_dict_path}")
            sim_dict = load_dict(sim_dict_path)
            full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)

        elif os.path.isfile(os.path.join(exp_root, sim_dict_path)):
            sim_dict_path = os.path.join(exp_root, sim_dict_path)
            print(f"sim_dict_path: {sim_dict_path}")
            sim_dict = load_dict(sim_dict_path)
            full_exp_cond_path, sim_dict_name = os.path.split(sim_dict_path)

        else:
            raise FileNotFoundError(sim_dict_path)

    elif type(sim_dict_path) is dict:
        sim_dict = sim_dict_path
        sim_dict_path = sim_dict['training_info']['sim_dict_path']
        full_exp_cond_path = sim_dict['topic_info']['exp_cond_path']

    else:
        raise FileNotFoundError(sim_dict_path)

    os.chdir(full_exp_cond_path)
    print(f"set_path to full_exp_cond_path: {full_exp_cond_path}")

    focussed_dict_print(sim_dict, 'sim_dict')

    # # # load datasets
    data_dict = sim_dict['data_info']
    if use_dataset == 'generator':
        vocab_dict = load_dict(
            os.path.join(data_dict["data_path"], data_dict["vocab_dict"]))
        n_cats = data_dict["n_cats"]
        x_data_path = sim_dict['training_info']['x_data_path']
        # y_data_path = sim_dict['training_info']['y_data_path']
        # n_items = 'unknown'

    else:
        # load data from somewhere
        # n_items = data_dict["n_items"]
        n_cats = data_dict["n_cats"]
        hdf5_path = sim_dict['topic_info']["dataset_path"]

        x_data_path = hdf5_path
        # y_data_path = '/home/nm13850/Documents/PhD/python_v2/datasets/' \
        #               'objects/ILSVRC2012/imagenet_hdf5/y_df.csv'

        seq_data = pd.read_csv(data_dict["seqs"],
                               header=None,
                               names=['seq1', 'seq2', 'seq3'])
        print(f"\nseq_data: {seq_data.shape}\n{seq_data.head()}")

        x_data = np.load(data_dict["x_data"])
        print("\nshape of x_data: {}".format(np.shape(x_data)))

        y_labels = np.loadtxt(data_dict["y_labels"],
                              delimiter=',').astype('int8')
        print(f"\ny_labels:\n{y_labels}")
        print(np.shape(y_labels))

        y_data = to_categorical(y_labels, num_classes=30)
        print(f"\ny_data:\n{y_data}")
        print(np.shape(y_data))

    # # # data preprocessing
    # # # if network is cnn but data is 2d (e.g., MNIST)
    # if len(np.shape(x_data)) != 4:
    #     if sim_dict['model_info']['overview']['model_type'] == 'cnn':
    #         width, height = sim_dict['data_info']['image_dim']
    #         x_data = x_data.reshape(x_data.shape[0], width, height, 1)
    #         print(f"\nRESHAPING x_data to: {np.shape(x_data)}")

    # # other details
    # hid_units = sim_dict['model_info']['layers']['hid_layers']['hid_totals']["analysable"]
    optimizer = sim_dict['model_info']["overview"]["optimizer"]
    loss_func = sim_dict['model_info']["overview"]["loss_func"]
    batch_size = sim_dict['model_info']["overview"]["batch_size"]
    timesteps = sim_dict['model_info']["overview"]["timesteps"]
    serial_recall = sim_dict['model_info']["overview"]["serial_recall"]
    x_data_type = sim_dict['model_info']["overview"]["x_data_type"]
    end_seq_cue = sim_dict['model_info']["overview"]["end_seq_cue"]
    act_func = sim_dict['model_info']["overview"]["act_func"]

    # input_dim = data_dict["X_size"]
    # output_dim = data_dict["n_cats"]

    # Output files
    output_filename = sim_dict["topic_info"]["output_filename"]
    print(f"\nOutput file: {output_filename}")

    # # # # PART 2 # # #
    print("\n**** THE MODEL ****")
    model_name = sim_dict['model_info']['overview']['trained_model']

    if os.path.isfile(model_name):
        loaded_model = load_model(model_name)
    else:
        training_dir, sim_dict_name = os.path.split(sim_dict_path)
        print(f"training_dir: {training_dir}\n"
              f"sim_dict_name: {sim_dict_name}")
        if os.path.isfile(os.path.join(training_dir, model_name)):
            loaded_model = load_model(os.path.join(training_dir, model_name))
        else:
            raise FileNotFoundError(f"trained model not found: {model_name}")

    loaded_model.trainable = False

    model_details = loaded_model.get_config()
    # print_nested_round_floats(model_details)
    focussed_dict_print(model_details, 'model_details')

    n_layers = len(model_details['layers'])
    model_dict = dict()

    # # turn off "trainable" and get useful info

    for layer in range(n_layers):
        # set to not train
        # model_details['layers'][layer]['config']['trainable'] = 'False'

        if verbose:
            print(f"Model layer {layer}: {model_details['layers'][layer]}")

        # # get useful info
        layer_dict = {
            'layer': layer,
            'name': model_details['layers'][layer]['config']['name'],
            'class': model_details['layers'][layer]['class_name']
        }

        if 'units' in model_details['layers'][layer]['config']:
            layer_dict['units'] = model_details['layers'][layer]['config'][
                'units']
        if 'activation' in model_details['layers'][layer]['config']:
            layer_dict['act_func'] = model_details['layers'][layer]['config'][
                'activation']
        if 'filters' in model_details['layers'][layer]['config']:
            layer_dict['filters'] = model_details['layers'][layer]['config'][
                'filters']
        if 'kernel_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'kernel_size'][0]
        if 'pool_size' in model_details['layers'][layer]['config']:
            layer_dict['size'] = model_details['layers'][layer]['config'][
                'pool_size'][0]
        if 'strides' in model_details['layers'][layer]['config']:
            layer_dict['strides'] = model_details['layers'][layer]['config'][
                'strides'][0]
        if 'rate' in model_details['layers'][layer]['config']:
            layer_dict["rate"] = model_details['layers'][layer]['config'][
                'rate']

        # # set and save layer details
        model_dict[layer] = layer_dict

    # # my model summary
    model_df = pd.DataFrame.from_dict(
        data=model_dict,
        orient='index',
        columns=[
            'layer', 'name', 'class', 'act_func', 'units', 'filters', 'size',
            'strides', 'rate'
        ],
    )

    print(f"\nmodel_df\n{model_df}")

    # # make new df with just layers of interest
    if get_layer_list is None:
        get_layer_list = model_df['name'].tolist()

    key_layers_df = model_df.loc[model_df['name'].isin(get_layer_list)]

    key_layers_df.reset_index(inplace=True)
    del key_layers_df['index']
    key_layers_df.index.name = 'index'
    key_layers_df = key_layers_df.drop(columns=['size', 'strides', 'rate'])

    # # add column ('n_units_filts') to say how many things need gha per layer (number of units or filters)
    # # add zeros to rows with no units or filters
    key_layers_df.loc[:, 'n_units_filts'] = key_layers_df.units.fillna(
        0) + key_layers_df.filters.fillna(0)

    # print(f"\nkey_layers_df:\n{key_layers_df}")

    key_layers_df.loc[:,
                      "n_units_filts"] = key_layers_df["n_units_filts"].astype(
                          int)

    # # get the total number of units or filters in key layers of the network
    key_n_units_fils = sum(key_layers_df['n_units_filts'])

    print(f"\nkey_layers_df:\n{key_layers_df.head()}")
    print(f"key_n_units_fils: {key_n_units_fils}")
    '''I currently get the output layer too - keep it in so I can do class correlation'''

    # # # set dir to save gha stuff # # #
    hid_act_items = 'all'
    if not gha_incorrect:
        hid_act_items = 'correct'
    gha_folder = f'{hid_act_items}_{use_dataset}_gha'
    if test_run:
        gha_folder = os.path.join(gha_folder, 'test')

    cond_name = sim_dict['topic_info']['output_filename']
    condition_path = find_path_to_dir(long_path=full_exp_cond_path,
                                      target_dir=cond_name)
    gha_path = os.path.join(condition_path, gha_folder)

    if not os.path.exists(gha_path):
        os.makedirs(gha_path)
    os.chdir(gha_path)
    print(f"\nsaving hid_acts to: {gha_path}")

    # # # get hid acts for each timestep even if output is free-recall
    # print("\nchanging layer attribute: return_sequnces")
    # for layer in loaded_model.layers:
    #     # set to return sequences = True
    #     # model_details['layers'][layer]['config']['return_sequences'] = True
    #     if hasattr(layer, 'return_sequences'):
    #         layer.return_sequences = True
    #         print(layer.name, layer.return_sequences)
    #
    #     if verbose:
    #         print(f"Model layer {layer}: {model_details['layers'][layer]}")

    # # sort optimizers
    # # I don't think the choice of optimizer should actually matter since I am not training.
    sgd = SGD(momentum=.9)  # decay=sgd_lr / max_epochs)
    this_optimizer = sgd

    if optimizer == 'SGD_no_momentum':
        this_optimizer = SGD(momentum=0.0,
                             nesterov=False)  # decay=sgd_lr / max_epochs)
    elif optimizer == 'SGD_Nesterov':
        this_optimizer = SGD(momentum=.1,
                             nesterov=True)  # decay=sgd_lr / max_epochs)
    elif optimizer == 'SGD_mom_clip':
        this_optimizer = SGD(momentum=.9,
                             clipnorm=1.)  # decay=sgd_lr / max_epochs)
    elif optimizer == 'dougs':
        print("I haven't added the code for doug's momentum to GHA script yet")
        this_optimizer = None
        # this_optimizer = dougsMomentum(momentum=.9)

    elif optimizer == 'adam':
        this_optimizer = Adam(amsgrad=False)
    elif optimizer == 'adam_amsgrad':
        # simulations run prior to 05122019 did not have this option, and may have used amsgrad under the name 'adam'
        this_optimizer = Adam(amsgrad=True)

    elif optimizer == 'RMSprop':
        this_optimizer = RMSprop()
    elif optimizer == 'Adagrad':
        this_optimizer = Adagrad()
    elif optimizer == 'Adadelta':
        this_optimizer = Adadelta()
    elif optimizer == 'Adamax':
        this_optimizer = Adamax()
    elif optimizer == 'Nadam':
        this_optimizer = Nadam()

    # # # PART 3 get_scores() # # #
    loaded_model.compile(loss=loss_func,
                         optimizer=this_optimizer,
                         metrics=['accuracy'])

    # # load test_seqs if they are there, else generate some
    data_path = sim_dict['data_info']['data_path']

    if not os.path.exists(data_path):
        if os.path.exists(switch_home_dirs(data_path)):
            data_path = switch_home_dirs(data_path)
        else:
            raise FileNotFoundError(f'data path not found: {data_path}')

    print(f'data_path: {data_path}')

    # test_filename = f'seq{timesteps}_v{n_cats}_960_test_seq_labels.npy'
    test_filename = f'seq{timesteps}_v{n_cats}_1per_ts_test_seq_labels.npy'
    test_seq_path = os.path.join(data_path, test_filename)
    test_label_seqs = np.load(test_seq_path)

    print(f'test_label_seqs: {np.shape(test_label_seqs)}\n{test_label_seqs}\n')

    test_label_name = os.path.join(data_path, test_filename[:-10])

    seq_words_df = pd.read_csv(f"{test_label_name}words.csv")

    # test_IPC_name = os.path.join(data_path, f"seq{timesteps}_v{n_cats}_960_test_IPC.pickle")
    test_IPC_name = os.path.join(
        data_path, f"seq{timesteps}_v{n_cats}_1per_ts_test_IPC.pickle")

    IPC_dict = load_dict(test_IPC_name)

    # else:
    #
    #     n_seqs = 30*batch_size
    #
    #     test_label_seqs = get_label_seqs(n_labels=n_cats, seq_len=timesteps,
    #                                      repetitions=serial_recall, n_seqs=n_seqs)
    #     test_label_name = f"{output_filename}_{np.shape(test_label_seqs)[0]}_test_seq_"
    #
    #
    #     # print(f"test_label_name: {test_label_name}")
    #     np.save(f"{test_label_name}labels.npy", test_label_seqs)
    #
    #     seq_words_df = spell_label_seqs(test_label_seqs=test_label_seqs,
    #                                     test_label_name=f"{test_label_name}words.csv",
    #                                     vocab_dict=vocab_dict, save_csv=True)
    if verbose:
        print(seq_words_df.head())

    scores_dict = get_test_scores(
        model=loaded_model,
        data_dict=data_dict,
        test_label_seqs=test_label_seqs,
        serial_recall=serial_recall,
        x_data_type=x_data_type,
        # output_type=output_type,
        end_seq_cue=end_seq_cue,
        batch_size=batch_size,
        verbose=verbose)

    mean_IoU = scores_dict['mean_IoU']
    prop_seq_corr = scores_dict['prop_seq_corr']

    # IPC_dict = seq_items_per_class(label_seqs=test_label_seqs, vocab_dict=vocab_dict)
    # test_IPC_name = f"{output_filename}_{n_seqs}_test_IPC.pickle"
    # with open(test_IPC_name, "wb") as pickle_out:
    #     pickle.dump(IPC_dict, pickle_out, protocol=pickle.HIGHEST_PROTOCOL)

    # # PART 5
    print("\n**** Get Hidden unit activations ****")
    hid_acts_dict = dict()

    # # loop through key layers df
    gha_key_layers = []
    for index, row in key_layers_df.iterrows():
        if test_run:
            if index > 3:
                continue

        layer_number, layer_name, layer_class = row['layer'], row['name'], row[
            'class']
        print(f"\n{layer_number}. name: {layer_name}; class: {layer_class}")

        # if layer_class not in get_classes:  # no longer using this - skip class types not in list
        if layer_name not in get_layer_list:  # skip layers/classes not in list
            continue

        else:
            gha_key_layers.append(layer_name)
            # record hid acts
            layer_activations = get_layer_acts(model=loaded_model,
                                               layer_name=layer_name,
                                               data_dict=data_dict,
                                               test_label_seqs=test_label_seqs,
                                               serial_recall=serial_recall,
                                               end_seq_cue=end_seq_cue,
                                               batch_size=batch_size,
                                               verbose=verbose)

            layer_acts_shape = np.shape(layer_activations)

            print(f"\nlen(layer_acts_shape): {len(layer_acts_shape)}")

            converted_to_2d = False  # set to True if 4d acts have been converted to 2d
            if len(layer_acts_shape) == 2:
                hid_acts = layer_activations

            elif len(layer_acts_shape) == 3:
                # if not serial_recall:
                #     ValueError(f"layer_acts_shape: {layer_acts_shape}"
                #                f"\n3d expected only for serial recall")
                # else:
                hid_acts = layer_activations

            # elif len(layer_acts_shape) == 4:  # # call mean_act_conv
            #     hid_acts = kernel_to_2d(layer_activations, verbose=True)
            #     layer_acts_shape = np.shape(hid_acts)
            #     converted_to_2d = True

            else:
                raise ValueError(
                    f"Unexpected number of dimensions for layer activations {layer_acts_shape}"
                )

            hid_acts_dict[index] = {
                'layer_name': layer_name,
                'layer_class': layer_class,
                "layer_shape": layer_acts_shape,
                'hid_acts': hid_acts
            }

            if converted_to_2d:
                hid_acts_dict[index]['converted_to_2d'] = True

            print(f"\nlayer {index}. layer_acts_shape: {layer_acts_shape}\n")

            # # save distplot for sanity check
            sns.distplot(np.ravel(hid_acts))
            plt.title(str(layer_name))
            plt.savefig(f"{layer_name}_act_distplot.png")
            plt.close()

        print("\n**** saving info to summary page and dictionary ****")

        hid_act_filenames = {'2d': None, 'any_d': None}
        dict_2d_save_name = f'{output_filename}_hid_act.pickle'
        with open(dict_2d_save_name, "wb") as pkl:  # 'wb' means 'w'rite the file in 'b'inary mode
            pickle.dump(hid_acts_dict, pkl)
        # np.save(dict_2d_save_name, hid_acts_dict)
        hid_act_filenames['2d'] = dict_2d_save_name

    cond = sim_dict["topic_info"]["cond"]
    run = sim_dict["topic_info"]["run"]
    if test_run:
        run = 'test'

    hid_units = sim_dict['model_info']['layers']['hid_layers']['hid_totals'][
        'analysable']

    trained_for = sim_dict["training_info"]["trained_for"]
    end_accuracy = sim_dict["training_info"]["acc"]
    dataset = sim_dict["data_info"]["dataset"]
    gha_date = int(datetime.datetime.now().strftime("%y%m%d"))
    gha_time = int(datetime.datetime.now().strftime("%H%M"))
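    # e.g., 14 March 2021 at 09:35 gives gha_date=210314 and gha_time=935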

    # gha_acc = scores_dict['gha_acc']
    # n_cats_correct = scores_dict['n_cats_correct']

    # # GHA_info_dict
    gha_dict_name = f"{output_filename}_GHA_dict.pickle"
    gha_dict_path = os.path.join(gha_path, gha_dict_name)

    gha_dict = {
        "topic_info": sim_dict['topic_info'],
        "data_info": sim_dict['data_info'],
        "model_info": sim_dict['model_info'],
        "training_info": sim_dict['training_info'],
        "GHA_info": {
            "use_dataset": use_dataset,
            'x_data_path': x_data_path,
            'y_data_path': test_label_name,
            'IPC_dict_path': test_IPC_name,
            'gha_path': gha_path,
            'gha_dict_path': gha_dict_path,
            "gha_incorrect": gha_incorrect,
            "hid_act_files": hid_act_filenames,
            'gha_key_layers': gha_key_layers,
            'key_n_units_fils': key_n_units_fils,
            "gha_date": gha_date,
            "gha_time": gha_time,
            "scores_dict": scores_dict,
        }
    }

    with open(gha_dict_name, "wb") as pickle_out:
        pickle.dump(gha_dict, pickle_out)

    if verbose:
        focussed_dict_print(gha_dict, 'gha_dict', ['GHA_info'])

    gha_info = [
        cond, run, output_filename,
        sim_dict['model_info']['overview']['model_name'], n_layers, hid_units,
        dataset, use_dataset, gha_incorrect, n_cats, timesteps, x_data_type,
        act_func, serial_recall, trained_for, end_accuracy, mean_IoU,
        prop_seq_corr, test_run, gha_date, gha_time
    ]

    # # check if gha_summary.csv exists

    # # save GHA summary in exp folder not condition folder
    exp_name = sim_dict['topic_info']['exp_name']
    exp_path = find_path_to_dir(long_path=gha_path, target_dir=exp_name)
    os.chdir(exp_path)

    if not os.path.isfile(exp_name + "_GHA_summary.csv"):
        gha_summary = open(exp_name + "_GHA_summary.csv", 'w')
        mywriter = csv.writer(gha_summary)
        summary_headers = [
            "cond", "run", 'filename', 'model', "n_layers", "hid_units",
            "dataset", "GHA_on", 'incorrect', "n_cats", "timesteps",
            "x_data_type", "act_func", "serial_recall", "trained_for",
            "train_acc", "mean_IoU", "prop_seq_corr", "test_run", "gha_date",
            "gha_time"
        ]

        mywriter.writerow(summary_headers)
        print(f"creating summary csv at: {exp_path}")

    else:
        gha_summary = open(exp_name + "_GHA_summary.csv", 'a')
        mywriter = csv.writer(gha_summary)
        print(f"appending to summary csv at: {exp_path}")

    mywriter.writerow(gha_info)
    gha_summary.close()

    # make a list of dict names to do sel on
    sel_list_name = f"{exp_name}_dict_list_for_sel.csv"
    write_mode = 'a' if os.path.isfile(sel_list_name) else 'w'
    dict_list = open(sel_list_name, write_mode)
    mywriter = csv.writer(dict_list)

    mywriter.writerow([gha_dict_name[:-7]])
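    # (the [:-7] above strips the '.pickle' extension from the dict name)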
    dict_list.close()

    print(f"\nadded to list for selectivity analysis: {gha_dict_name[:-7]}")

    print("\nend of ff_gha")

    return gha_dict


def train_model(
    exp_name,
    data_dict_path,
    model_path,
    cond_name=None,
    cond=None,
    run=None,
    max_epochs=100,
    use_optimizer='adam',
    loss_target=0.01,
    min_loss_change=0.0001,
    batch_size=32,
    lr=0.001,
    n_layers=1,
    units_per_layer=200,
    act_func='relu',
    use_bias=True,
    y_1hot=True,
    output_act='softmax',
    weight_init='GlorotUniform',
    augmentation=True,
    grey_image=False,
    use_batch_norm=False,
    use_dropout=0.0,
    use_val_data=True,
    timesteps=1,
    exp_root='/home/nm13850/Documents/PhD/python_v2/experiments/',
    verbose=False,
    test_run=False,
):
    """
    script to train a feedforward neural network (mlp or cnn) on a classification task


    1. get details - dset_name, model, other key variables
    2. load datasets
    3. compile model
    4. fit/evaluate model
    5. make plots
    6. output:
        training plots
        model accuracy etc.

    this model will stop training if
    a. loss does not improve for [patience] epochs by [min_loss_change]
    b. accuracy reaches 100%
    c. when max_epochs is reached.


    :param exp_name: name of this experiment or set of model/dataset pairs
    :param data_dict_path: path to data dict
    :param model_path: dir for model
    :param cond_name: name of this condition
    :param cond: number for this condition
    :param run: number for this run (e.g., multiple runs of the same condition from different initializations)
    :param max_epochs: Stop training after this many epochs
    :param use_optimizer: Optimizer to use
    :param loss_target: stop training when this target is reached
    :param min_loss_change: stop training if loss does not improve by at least this much
    :param batch_size: number of items loaded in at once
    :param augmentation: whether data aug is used (for images)
    :param grey_image: whether the images are grey (if false, they are colour)
    :param use_batch_norm: use batch normalization
    :param use_dropout: use dropout
    :param use_val_data: use validation set (either separate set, or train/val split)
    :param timesteps: if RNN length of sequence
    :param exp_root: root directory for saving experiments

    :param verbose: if 0, not verbose; if 1, print basics; if 2, print all

    :return: training_info csv
    :return: sim_dict with dataset info, model info and training info

    """

    dset_dir, data_dict_name = os.path.split(data_dict_path)
    dset_dir, dset_name = os.path.split(dset_dir)
    model_dir, model_name = os.path.split(model_path)

    print(f"dset_dir: {dset_dir}\ndset_name: {dset_name}")
    print(f"model_dir: {model_dir}\nmodel_name: {model_name}")

    # Output files
    if not cond_name:
        output_filename = f"{exp_name}_{model_name}_{dset_name}"
    else:
        output_filename = f"{exp_name}_{cond_name}"

    print(f"\noutput_filename: {output_filename}")

    print(data_dict_path)
    # # get info from dict
    if os.path.isfile(data_dict_path):
        data_dict = load_dict(data_dict_path)
    elif os.path.isfile(
            os.path.join('/home/nm13850/Documents/PhD/python_v2/datasets/',
                         data_dict_path)):
        data_dict_path = os.path.join(
            '/home/nm13850/Documents/PhD/python_v2/datasets/', data_dict_path)
        data_dict = load_dict(data_dict_path)
    elif os.path.isfile(switch_home_dirs(data_dict_path)):
        data_dict_path = switch_home_dirs(data_dict_path)
        data_dict = load_dict(data_dict_path)
    else:
        raise FileNotFoundError(data_dict_path)

    if verbose:
        # # print basic details
        print("\n**** STUDY DETAILS ****")
        print(
            f"output_filename: {output_filename}\ndset_name: {dset_name}\nmodel: {model_name}\n"
            f"max_epochs: {max_epochs}\nuse_optimizer: {use_optimizer}\n"
            f"loss_target: {loss_target}\nmin_loss_change: {min_loss_change}\n"
            f"batch_norm: {use_batch_norm}\nval_data: {use_val_data}\naugemntation: {augmentation}\n"
        )
        focussed_dict_print(data_dict, 'data_dict')

    # # check for training data
    if 'train_set' in data_dict:
        if os.path.isfile(
                os.path.join(data_dict['data_path'],
                             data_dict['train_set']['X_data'])):
            x_data_path = os.path.join(data_dict['data_path'],
                                       data_dict['train_set']['X_data'])
            y_data_path = os.path.join(data_dict['data_path'],
                                       data_dict['train_set']['Y_labels'])
        elif os.path.isfile(
                switch_home_dirs(
                    os.path.join(data_dict['data_path'],
                                 data_dict['train_set']['X_data']))):
            x_data_path = switch_home_dirs(
                os.path.join(data_dict['data_path'],
                             data_dict['train_set']['X_data']))
            y_data_path = switch_home_dirs(
                os.path.join(data_dict['data_path'],
                             data_dict['train_set']['Y_labels']))
        else:
            raise FileNotFoundError(
                f"training data not found\n"
                f"{os.path.join(data_dict['data_path'], data_dict['train_set']['X_data'])}"
            )
    else:
        # # if no training set
        if os.path.isfile(
                os.path.join(data_dict['data_path'], data_dict['X_data'])):
            x_data_path = os.path.join(data_dict['data_path'],
                                       data_dict['X_data'])
            y_data_path = os.path.join(data_dict['data_path'],
                                       data_dict['Y_labels'])
        else:
            data_path = switch_home_dirs(data_dict['data_path'])
            if os.path.isfile(os.path.join(data_path, data_dict['X_data'])):
                x_data_path = os.path.join(data_path, data_dict['X_data'])
                y_data_path = os.path.join(data_path, data_dict['Y_labels'])
                data_dict['data_path'] = data_path
            else:
                raise FileNotFoundError(
                    f"can't find x data at\n"
                    f'{os.path.join(data_path, data_dict["X_data"])}')

    x_data = load_x_data(x_data_path)
    y_df, y_label_list = load_y_data(y_data_path)

    n_cats = data_dict["n_cats"]
    y_data = to_categorical(y_label_list, num_classes=n_cats)
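    # e.g., with n_cats=4, label 2 becomes the one-hot vector [0., 0., 1., 0.]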

    # # val data
    if use_val_data:
        print("\n**** Loading validation data ****")
        if "val_set" in data_dict:
            x_val = load_x_data(
                os.path.join(data_dict['data_path'],
                             data_dict['val_set']['X_data']))
            y_val_df, y_val_label_list = load_y_data(
                os.path.join(data_dict['data_path'],
                             data_dict['val_set']['Y_labels']))
            y_val = to_categorical(y_val_label_list, num_classes=n_cats)

            # # do I need these?
            # x_train = x_data
            # y_train = y_data
        else:
            print("validation data not found - performing split")
            x_train, x_val, y_train_label_list, y_val_label_list = train_test_split(
                x_data, y_label_list, test_size=0.2, random_state=1)
            print(
                f"y_train_label_list: {np.shape(y_train_label_list)}.  "
                f"Count: {np.unique(y_train_label_list, return_counts=True)[1]}\n"
                f"y_val_label_list: {np.shape(y_val_label_list)}.  "
                f"count {np.unique(y_val_label_list, return_counts=True)[1]}")
            y_train = to_categorical(y_train_label_list, num_classes=n_cats)
            y_val = to_categorical(y_val_label_list, num_classes=n_cats)
    else:
        x_train = x_data
        y_train = y_data

    n_items = len(y_train)

    # fix random seed for reproducibility during development - not for simulations
    if test_run:
        seed = 7
        np.random.seed(seed)

    # # data preprocessing
    x_size = data_dict["X_size"]

    # check data shape
    print(f"\ninput shape: {np.shape(x_train)}")
    if len(np.shape(x_train)) == 4:
        image_dim = data_dict['image_dim']
        n_items, width, height, channels = np.shape(x_train)
    else:
        # # this is just for MNIST
        if model_dir in ['cnn', 'cnns']:
            print("reshaping mnist for cnn")
            width, height = data_dict['image_dim']
            x_train = x_train.reshape(x_train.shape[0], width, height, 1)
            x_data = x_data.reshape(x_data.shape[0], width, height, 1)

            print(f"\nRESHAPING x_train to: {np.shape(x_train)}")
            if use_val_data:
                x_val = x_val.reshape(x_val.shape[0], width, height, 1)
        if model_dir == 'mlps':
            if len(np.shape(x_train)) > 2:
                print(
                    f"reshaping image data from {len(np.shape(x_train))}d to 2d for mlp"
                )
                x_train = np.reshape(
                    x_train,
                    (x_train.shape[0], x_train.shape[1] * x_train.shape[2]))
                x_data = np.reshape(
                    x_data,
                    (x_data.shape[0], x_data.shape[1] * x_data.shape[2]))

                if use_val_data:
                    # x_val = x_val.transpose(2,0,1).reshape(3,-1)
                    x_val = np.reshape(
                        x_val,
                        (x_val.shape[0], x_val.shape[1] * x_val.shape[2]))

                print(f"\nNEW input shape: {np.shape(x_train)}")
                x_size = np.shape(x_train)[1]
                print(f"NEW x_size: {x_size}")

    print(f"x_data.dtype: {x_data.dtype}")
    print(f"x_train.dtype: {x_train.dtype}")

    # Preprocess the data (these are NumPy arrays)
    if x_train.dtype == "uint8":
        print(f"converting input data from {x_train.dtype} to float32")
        max_x = np.amax(x_data)
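        # divide by the dataset-wide max (255 for 8-bit images) so train/val/full data share the same [0, 1] range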
        x_train = x_train.astype("float32") / max_x
        x_data = x_data.astype("float32") / max_x
        # x_test = x_test.astype("float32") / 255
        # y_train = y_train.astype("float32")
        # y_test = y_test.astype("float32")
        if use_val_data:
            x_val = x_val.astype("float32") / max_x
        print(f"x_data.dtype: {x_data.dtype}")
        print(f"x_train.dtype: {x_train.dtype}")
    """# # # proprocessing here - put data in range 0-1
    # if len(np.shape(x_train)) == 2:
    #     scaler = MinMaxScaler()
    #     x_train = scaler.fit_transform(x_train)
    #     if use_val_data is True:
    #         x_val = scaler.fit_transform(x_val)"""

    # # save path
    exp_cond_path = os.path.join(exp_root, exp_name, output_filename)
    if test_run:
        exp_cond_path = os.path.join(exp_cond_path, 'test')

    if not os.path.exists(exp_cond_path):
        os.makedirs(exp_cond_path)
    os.chdir(exp_cond_path)
    print(f"\nsaving to: {exp_cond_path}")

    # # The Model
    if model_dir in ['mlp', 'mlps']:
        print("\nloading an mlp model")
        augmentation = False
        model_dict = {'mlp': mlp, 'fc1': fc1, 'fc2': fc2, 'fc4': fc4}

        model = model_dict[model_name].build(features=x_size,
                                             classes=n_cats,
                                             n_layers=n_layers,
                                             units_per_layer=units_per_layer,
                                             act_func=act_func,
                                             use_bias=use_bias,
                                             y_1hot=y_1hot,
                                             output_act=output_act,
                                             batch_size=batch_size,
                                             weight_init=weight_init,
                                             batch_norm=use_batch_norm,
                                             dropout=use_dropout)
        # augmentation = False
        # units_per_layer = 32
        # # models[model_name].build(...)
        # elif model_name == 'fc4':
        #     build_model = fc4.build(classes=n_cats, units_per_layer=units_per_layer,
        #                             batch_norm=use_batch_norm, dropout=use_dropout)
        # elif model_name == 'fc2':
        #     build_model = fc2.build(classes=n_cats, units_per_layer=units_per_layer,
        #                             batch_norm=use_batch_norm, dropout=use_dropout)
        # elif model_name == 'fc1':
        #     build_model = fc1.build(classes=n_cats, units_per_layer=units_per_layer,
        #                             batch_norm=use_batch_norm, dropout=use_dropout)

    elif model_dir in ['cnn', 'cnns']:
        print("loading a cnn model")

        model_dict = {
            'con6_pool3_fc1': con6_pool3_fc1,
            'con4_pool2_fc1': con4_pool2_fc1,
            'con2_pool2_fc1': con2_pool2_fc1,
            'con4_pool2_fc1_reluconv': con4_pool2_fc1_reluconv,
            'con4_pool2_fc1_noise_layer': con4_pool2_fc1_noise_layer,
            'con2_pool2_fc1_reluconv': con2_pool2_fc1_reluconv,
            'conv1_pool1_fc1_reluconv': conv1_pool1_fc1_reluconv
        }

        units_per_layer = None
        width, height = data_dict['image_dim']
        depth = 3
        if grey_image:
            depth = 1

        model = model_dict[model_name].build(width=width,
                                             height=height,
                                             depth=depth,
                                             classes=n_cats,
                                             batch_norm=use_batch_norm,
                                             dropout=use_dropout)

    elif 'rnn' in model_dir:
        raise ValueError(
            "trying to load an rnn model - use train_STM_RNN for recurrent models instead"
        )

        # print("loading a recurrent model")
        # augmentation = False
        # model_dict = {'Bowers14rnn': Bowers14rnn,
        #               'SimpleRNNn': SimpleRNNn,
        #               'GRUn': GRUn,
        #               'LSTMn': LSTMn,
        #               'Seq2Seq': Seq2Seq}
        #
        #
        # model = model_dict[model_name].build(features=x_size, classes=n_cats, timesteps=timesteps,
        #                                      batch_size=batch_size, n_layers=n_layers,
        #                                      serial_recall=serial_recall,
        #                                      units_per_layer=units_per_layer, act_func=act_func,
        #                                      y_1hot=serial_recall,
        #                                      dropout=use_dropout)
    else:
        print("model_dir not recognised")

    # model = build_model

    # loss
    loss_func = 'categorical_crossentropy'
    if n_cats == 2:
        loss_func = 'binary_crossentropy'
    if not y_1hot:
        loss_func = 'binary_crossentropy'
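    # with multi-label (non-1hot) targets, binary_crossentropy scores each output unit as an
    # independent binary decision rather than a single softmax choice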

    # optimizer
    if use_optimizer in ['sgd', 'SGD']:
        this_optimizer = SGD(lr=lr)
    elif use_optimizer in ['sgd_decay', 'SGD_decay']:
        this_optimizer = SGD(lr=lr, decay=lr / max_epochs)
    elif use_optimizer == 'adam':
        this_optimizer = Adam(lr=lr)
    elif use_optimizer == 'rmsprop':
        this_optimizer = RMSprop(lr=lr, decay=1e-6)
    else:
        raise ValueError(f'use_optimizer not recognized: {use_optimizer}')

    # # compile model
    model.compile(loss=loss_func, optimizer=this_optimizer, metrics=['acc'])

    # # get model dict
    model_info = get_model_dict(model)  # , verbose=True)
    # print("\nmodel_info:")
    print_nested_round_floats(model_info, 'model_info')
    tf.compat.v1.keras.utils.plot_model(model,
                                        to_file=f'{model_name}_diag.png',
                                        show_shapes=True)

    # # callbacks and training parameters
    checkpoint_path = f'{output_filename}_model.hdf5'

    checkpoint_mon = 'loss'
    if use_val_data:
        checkpoint_mon = 'val_loss'

    # checkpointing: save model and weights with the best (val_)loss.
    checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                      monitor=checkpoint_mon,
                                                      verbose=1,
                                                      save_best_only=True,
                                                      save_weights_only=False,
                                                      mode='auto')

    # patience_for_loss_change: wait this long to see if loss improves
    patience_for_loss_change = max(5, int(max_epochs / 50))
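    # e.g., max_epochs=500 gives a patience of 10 epochs; anything below 250 epochs uses the floor of 5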

    # early_stop_plateau - stop if there is no improvement
    early_stop_plateau = tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        min_delta=min_loss_change,
        patience=patience_for_loss_change,
        verbose=1,
        mode='min')

    val_early_stop_plateau = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=min_loss_change,
        patience=patience_for_loss_change,
        verbose=verbose,
        mode='min')

    # # # early stop acc
    # # # should stop when acc reaches 1.0 (e.g., will not carry on training)
    # early_stop_acc = tf.keras.callbacks.EarlyStopping(monitor='acc', baseline=1.0, patience=0)
    #
    # val_early_stop_acc = tf.keras.callbacks.EarlyStopping(monitor='val_acc', baseline=1.0, patience=0)

    date_n_time = int(datetime.datetime.now().strftime("%Y%m%d%H%M"))
    tensorboard_path = os.path.join(exp_cond_path, 'tb', str(date_n_time))

    tensorboard = TensorBoard(
        log_dir=tensorboard_path,
        # histogram_freq=1,
        # batch_size=batch_size,
        # write_graph=True,
        # # write_grads=False,
        # # write_images=False,
        # # embeddings_freq=0,
        # # embeddings_layer_names=None,
        # # embeddings_metadata=None,
        # # embeddings_data=None,
        # update_freq='epoch',
        # profile_batch=2
    )

    print('\n\nto access tensorboard, in terminal use\n'
          f'tensorboard --logdir={tensorboard_path}'
          '\nthen click link')

    callbacks_list = [early_stop_plateau, checkpointer, tensorboard]
    val_callbacks_list = [val_early_stop_plateau, checkpointer, tensorboard]

    ############################
    # # train model
    print("\n**** TRAINING ****")
    if augmentation:
        # # construct the image generator for data augmentation
        aug = ImageDataGenerator(
            rotation_range=10,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            shear_range=0.1,  # set range for random shear (tilt image)
            zoom_range=0.1,  # set range for random zoom
            # horizontal_flip=True,
            fill_mode="nearest")

        if use_val_data:
            fit_model = model.fit_generator(
                aug.flow(x_train, y_train, batch_size=batch_size),
                validation_data=(x_val, y_val),
                # steps_per_epoch=len(x_train) // batch_size,
                epochs=max_epochs,
                verbose=1,
                callbacks=val_callbacks_list)
        else:
            fit_model = model.fit_generator(
                aug.flow(x_train, y_train, batch_size=batch_size),
                # steps_per_epoch=len(x_train) // batch_size,
                epochs=max_epochs,
                verbose=1,
                callbacks=callbacks_list)

    else:
        if use_val_data:
            fit_model = model.fit(x_train,
                                  y_train,
                                  validation_data=(x_val, y_val),
                                  epochs=max_epochs,
                                  batch_size=batch_size,
                                  verbose=1,
                                  callbacks=val_callbacks_list)
        else:
            fit_model = model.fit(x_train,
                                  y_train,
                                  epochs=max_epochs,
                                  batch_size=batch_size,
                                  verbose=1,
                                  callbacks=callbacks_list)

    ############################################

    print("\n**** TRAINING COMPLETE ****")
    # # plot the training loss and accuracy
    fig, (ax1, ax2) = plt.subplots(2, sharex=True)
    ax1.plot(fit_model.history['acc'])
    if use_val_data:
        ax1.plot(fit_model.history['val_acc'])
    ax1.set_title('model accuracy (top); loss (bottom)')
    ax1.set_ylabel('accuracy')
    ax1.set_xlabel('epoch')
    ax2.plot(fit_model.history['loss'])
    if use_val_data:
        ax2.plot(fit_model.history['val_loss'])
    ax2.set_ylabel('loss')
    ax2.set_xlabel('epoch')
    fig.legend(['train', 'val'], loc='upper left')
    plt.savefig(str(output_filename) + '_training.png')
    plt.close()

    # # Training info
    print(f"Model name: {checkpoint_path}")

    # # get best epoch
    if use_val_data:
        print(fit_model.history['val_loss'])
        trained_for = int(fit_model.history['val_loss'].index(
            min(fit_model.history['val_loss'])))
        end_val_loss = float(fit_model.history['val_loss'][trained_for])
        end_val_acc = float(fit_model.history['val_acc'][trained_for])
    else:
        print(fit_model.history['loss'])
        trained_for = int(fit_model.history['loss'].index(
            min(fit_model.history['loss'])))
        end_val_loss = np.nan
        end_val_acc = np.nan
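    # trained_for is the 0-based index of the epoch with the lowest (val_)loss,
    # i.e., the epoch whose weights the checkpointer saved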

    end_loss = float(fit_model.history['loss'][trained_for])
    end_acc = float(fit_model.history['acc'][trained_for])
    print(f'\nTraining Info\nbest loss after {trained_for} epochs\n'
          f'end loss: {end_loss}\nend acc: {end_acc}\n'
          f'end val loss: {end_val_loss}\nend val acc: {end_val_acc}')

    # # # PART 3 get_scores() # # #
    # # these three lines are to re-shape MNIST
    print(f"len(np.shape(x_data)): {len(np.shape(x_data))}")
    if len(np.shape(x_data)) != 4:
        if model_dir in ['cnn', 'cnns']:
            x_data = x_data.reshape(x_data.shape[0], width, height, 1)
    print(f"len(np.shape(x_data)): {len(np.shape(x_data))}")
    print(f"{type(x_data)}")
    print(f"{x_data.dtype}")

    predicted_outputs = model.predict(
        x_data)  # use x_data NOT x_train to fit shape of y_df
    item_correct_df, scores_dict, incorrect_items = get_scores(
        predicted_outputs,
        y_df,
        output_filename,
        y_1hot=y_1hot,
        verbose=True,
        save_all_csvs=True)

    if verbose:
        focussed_dict_print(scores_dict, 'Scores_dict')

    trained_date = int(datetime.datetime.now().strftime("%y%m%d"))
    trained_time = int(datetime.datetime.now().strftime("%H%M"))
    model_info['overview'] = {
        'model_type': model_dir,
        'model_name': model_name,
        "trained_model": checkpoint_path,
        "n_layers": n_layers,
        "units_per_layer": units_per_layer,
        "act_func": act_func,
        "optimizer": use_optimizer,
        "use_bias": use_bias,
        "weight_init": weight_init,
        "y_1hot": y_1hot,
        "output_act": output_act,
        "lr": lr,
        "max_epochs": max_epochs,
        "loss_func": loss_func,
        "batch_size": batch_size,
        "use_batch_norm": use_batch_norm,
        "use_dropout": use_dropout,
        "use_val_data": use_val_data,
        "augmentation": augmentation,
        "grey_image": grey_image,
        "loss_target": loss_target,
        "min_loss_change": min_loss_change,
        'timesteps': timesteps
    }

    git_repository = '/home/nm13850/Documents/PhD/code/library'
    if os.path.isdir('/Users/nickmartin/Documents/PhD/code/library'):
        git_repository = '/Users/nickmartin/Documents/PhD/code/library'

    repo = git.Repo(git_repository)

    sim_dict_name = f"{output_filename}_sim_dict.txt"

    # # simulation_info_dict
    sim_dict = {
        "topic_info": {
            "output_filename": output_filename,
            "cond": cond,
            "run": run,
            "data_dict_path": data_dict_path,
            "model_path": model_path,
            "exp_cond_path": exp_cond_path,
            'exp_name': exp_name,
            'cond_name': cond_name
        },
        "data_info": data_dict,
        "model_info": model_info,
        'scores': scores_dict,
        "training_info": {
            "sim_dict_name": sim_dict_name,
            "trained_for": trained_for,
            "loss": end_loss,
            "acc": end_acc,
            'use_val_data': use_val_data,
            "end_val_acc": end_val_acc,
            "end_val_loss": end_val_loss,
            "trained_date": trained_date,
            "trained_time": trained_time,
            'x_data_path': x_data_path,
            'y_data_path': y_data_path,
            'tensorboard_path': tensorboard_path,
            'commit': repo.head.object.hexsha,
        }
    }

    focussed_dict_print(sim_dict, 'sim_dict')

    if not use_val_data:
        sim_dict['training_info']['end_val_acc'] = 'NaN'
        sim_dict['training_info']['end_val_loss'] = 'NaN'

    with open(sim_dict_name, 'w') as fp:
        json.dump(sim_dict, fp, indent=4, separators=(',', ':'))
    """converts lists of units per layer [32, 64, 128] to str "32-64-128".
    Convert these strings back to lists of ints with:
    back_to_ints = [int(i) for i in str_upl.split(sep='-')]
    """
    str_upl = "-".join(
        map(str, model_info['layers']['hid_layers']['hid_totals']['UPL']))
    str_fpl = "-".join(
        map(str, model_info['layers']['hid_layers']['hid_totals']['FPL']))

    # # # spare variables to make analysis easier
    # if 'chanProp' in cond_name:
    #     var_one = 'chanProp'
    # elif 'chanDist' in cond_name:
    #     var_one = 'chanDist'
    # elif 'cont' in cond_name:
    #     var_one = 'cont'
    # elif 'bin' in cond_name:
    #     var_one = 'bin'
    # else:
    #     raise ValueError("dset_type not found (v1)")
    #
    # if 'pro_sm' in cond_name:
    #     var_two = 'pro_sm'
    # elif 'pro_med' in cond_name:
    #     var_two = 'pro_med'
    # # elif 'LB' in cond_name:
    # #     var_two = 'LB'
    # else:
    #     raise ValueError("between not found (v2)")
    #
    # if 'v1' in cond_name:
    #     var_three = 'v1'
    # elif 'v2' in cond_name:
    #     var_three = 'v2'
    # elif 'v3' in cond_name:
    #     var_three = 'v3'
    # else:
    #     raise ValueError("within not found (v3)")
    #
    # var_four = var_two + var_three
    #
    # if 'ReLu' in cond_name:
    #     var_five = 'relu'
    # elif 'relu' in cond_name:
    #     var_five = 'relu'
    # elif 'sigm' in cond_name:
    #     var_five = 'sigm'
    # else:
    #     raise ValueError("act_func not found (v4)")
    #
    # if '10' in cond_name:
    #     var_six = 10
    # elif '25' in cond_name:
    #     var_six = 25
    # elif '50' in cond_name:
    #     var_six = 50
    # elif '100' in cond_name:
    #     var_six = 100
    # elif '500' in cond_name:
    #     var_six = 500
    # else:
    #     raise ValueError("hid_units not found in cond_name (var6)")

    # print(f"\n{cond_name}: {var_one} {var_two} {var_three} {var_four} {var_five} {var_six}")

    # record training info for comparisons
    training_info = [
        output_filename,
        cond,
        run,
        dset_name,
        x_size,
        n_cats,
        timesteps,
        n_items,
        model_dir,
        model_name,
        act_func,
        model_info['layers']['totals']['all_layers'],
        model_info['layers']['totals']['hid_layers'],
        model_info['layers']['hid_layers']['hid_totals']['act_layers'],
        model_info['layers']['hid_layers']['hid_totals']['dense_layers'],
        str_upl,
        model_info['layers']['hid_layers']['hid_totals']['conv_layers'],
        str_fpl,
        model_info['layers']['hid_layers']['hid_totals']['analysable'],
        use_optimizer,
        use_batch_norm,
        use_dropout,
        batch_size,
        augmentation,
        grey_image,
        use_val_data,
        loss_target,
        min_loss_change,
        max_epochs,
        trained_for,
        end_acc,
        end_loss,
        end_val_acc,
        end_val_loss,
        checkpoint_path,
        trained_date,
        trained_time,
        # var_one, var_two, var_three, var_four, var_five, var_six
    ]

    exp_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
    os.chdir(exp_path)
    print(f"save_summaries: {exp_path}")

    # check if training_info.csv exists
    if not os.path.isfile(f"{exp_name}_training_summary.csv"):

        headers = [
            "file",
            "cond",
            "run",
            "dataset",
            "x_size",
            "n_cats",
            'timesteps',
            "n_items",
            "model_type",
            "model",
            "act_func",
            "all_layers",
            'hid_layers',
            "act_layers",
            "dense_layers",
            "UPL",
            "conv_layers",
            "FPL",
            "analysable",
            "optimizer",
            "batch_norm",
            "dropout",
            "batch_size",
            "aug",
            "grey_image",
            "val_data",
            "loss_target",
            "min_loss_change",
            "max_epochs",
            "trained_for",
            "end_acc",
            "end_loss",
            "end_val_acc",
            "end_val_loss",
            "model_file",
            "date",
            "time",
            # 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'
        ]

        training_overview = open(f"{exp_name}_training_summary.csv", 'w')
        mywriter = csv.writer(training_overview)
        mywriter.writerow(headers)
    else:
        training_overview = open(f"{exp_name}_training_summary.csv", 'a')
        mywriter = csv.writer(training_overview)

    mywriter.writerow(training_info)
    training_overview.close()

    if verbose:
        focussed_dict_print(sim_dict, 'sim_dict')

    print('\n\nto access tensorboard, in terminal use\n'
          f'tensorboard --logdir={tensorboard_path}'
          '\nthen click link')

    print("\nff_sim finished")

    return training_info, sim_dict
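

# A minimal usage sketch (the experiment name, data dict path and model path below
# are hypothetical, not from the original experiments):
#
# training_info, sim_dict = train_model(
#     exp_name='demo_exp',
#     data_dict_path='datasets/demo/demo_data_load_dict.pickle',
#     model_path='mlps/fc2',
#     max_epochs=10,
#     use_val_data=False,
#     verbose=True,
#     test_run=True)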