# Example #1
0
def main(save_fn, gpu_id=None):
    """Run supervised learning training over all tasks, testing after each one.

    Builds the placeholders and a single `Model`, then iterates over
    `par['n_tasks']` tasks. For every task it trains for
    `par['n_train_batches']` batches, updates the stabilization state
    (`par['stabilization']` is either 'pathint' or 'EWC'), resets the
    optimizer, and tests the network on every task trained so far.

    Args:
        save_fn: File name appended to `par['save_dir']` when the results are
            pickled (only used if `par['save_analysis']` is truthy).
        gpu_id: Value written to `CUDA_VISIBLE_DEVICES`. When `None`, the
            model is built on '/cpu:0', otherwise on '/gpu:0'.

    Raises:
        ValueError: If `par['task']` or `par['stabilization']` is unsupported.

    Side effects:
        Prints progress, optionally pickles the results, and appends the mean
        accuracy recorded after the first task to 'accs_200_squared.txt'.
    """

    # Update all dependencies in parameters
    update_dependencies()

    # Fail fast: an unknown method would otherwise only surface as an
    # unbound `loss` at the first progress print.
    if par['stabilization'] not in ('pathint', 'EWC'):
        raise ValueError("Unsupported par['stabilization']: {!r}".format(
            par['stabilization']))

    # Isolate requested GPU (environment values must be strings)
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # If desired, train the convolutional layers with the CIFAR datasets
    # Otherwise, the network will load convolutional weights from the saved file
    image_tasks = ('cifar', 'imagenet', 'colored_mnist')
    if par['task'] in image_tasks and par['train_convolutional_layers']:
        convolutional_layers.ConvolutionalLayers()

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()

    # Create placeholders for the model
    batch_size = par['batch_size']
    n_out = par['layer_dims'][-1]
    if par['task'] == 'mnist':
        x = tf.placeholder(tf.float32, [batch_size, par['layer_dims'][0]],
                           'stim')
    elif par['task'] in image_tasks:
        # All image tasks are fed as 32x32x3 inputs
        x = tf.placeholder(tf.float32, [batch_size, 32, 32, 3], 'stim')
    else:
        raise ValueError("Unsupported par['task']: {!r}".format(par['task']))
    y = tf.placeholder(tf.float32, [batch_size, n_out], 'out')
    mask = tf.placeholder(tf.float32, [batch_size, n_out], 'mask')
    rule = tf.placeholder(tf.float32, [batch_size, par['n_tasks']], 'rulecue')
    gating = [
        tf.placeholder(tf.float32, [par['layer_dims'][n + 1]], 'gating')
        for n in range(par['n_layers'] - 1)
    ]
    dropout_keep_pct = tf.placeholder(tf.float32, [], 'dropout')
    input_dropout_keep_pct = tf.placeholder(tf.float32, [], 'input_dropout')

    # Set up stimulus
    if par['task'] == 'colored_mnist':
        stim = stimulus.Stimulus(labels_per_task=par['labels_per_task'],
                                 sep=par['separability'])
    else:
        stim = stimulus.Stimulus(labels_per_task=par['labels_per_task'])

    # Initialize accuracy records
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))

    # Enter TensorFlow session
    with tf.Session() as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, y, gating, mask, dropout_keep_pct,
                          input_dropout_keep_pct, rule)

        # Initialize variables
        sess.run(tf.global_variables_initializer())
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        for task in range(par['n_tasks']):

            # Gating signals applied to each hidden layer for this task
            gating_dict = dict(zip(gating, par['gating'][task]))

            # Create rule cue vector for this task
            rule_cue = np.zeros([batch_size, par['n_tasks']])
            rule_cue[:, task] = 1

            # Iterate over batches
            for i in range(par['n_train_batches']):

                # Make batch of training data
                stim_in, y_hat, mk = stim.make_batch(task, test=False)
                train_feed = {
                    x: stim_in,
                    y: y_hat,
                    **gating_dict,
                    mask: mk,
                    dropout_keep_pct: par['drop_keep_pct'],
                    input_dropout_keep_pct: par['input_drop_keep_pct'],
                    rule: rule_cue,
                }

                # Run the model using one of the available stabilization methods
                if par['stabilization'] == 'pathint':
                    _, _, loss, AL = sess.run([
                        model.train_op, model.update_small_omega,
                        model.task_loss, model.aux_loss
                    ], feed_dict=train_feed)
                else:  # 'EWC' (validated above)
                    _, loss, AL = sess.run(
                        [model.train_op, model.task_loss, model.aux_loss],
                        feed_dict=train_feed)

                # Display network performance
                if i % 500 == 0:
                    print('Iter: ', i, 'Loss: ', loss, 'Aux Loss: ', AL)

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                sess.run(model.update_big_omega)
            else:  # 'EWC'
                for _ in range(par['EWC_batch_divisor'] *
                               par['EWC_fisher_num_batches']):
                    stim_in, _, mk = stim.make_batch(task, test=False)
                    # No targets are fed for this update
                    sess.run(
                        [model.update_big_omega],
                        feed_dict={
                            x: stim_in,
                            **gating_dict,
                            mask: mk,
                            dropout_keep_pct: par['drop_keep_pct'],
                            input_dropout_keep_pct:
                            par['input_drop_keep_pct'],
                            rule: rule_cue,
                        })

            # Reset the Adam Optimizer, and set the prev_weight values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

            # Test the networks on all trained tasks
            num_test_reps = 10
            accuracy = np.zeros((task + 1))
            for test_task in range(task + 1):

                # Use appropriate gating and rule cues
                test_gating_dict = dict(zip(gating, par['gating'][test_task]))
                test_rule_cue = np.zeros([batch_size, par['n_tasks']])
                test_rule_cue[:, test_task] = 1

                # Repeat the test as desired; dropout is disabled (keep = 1.0)
                for _ in range(num_test_reps):
                    stim_in, y_hat, mk = stim.make_batch(test_task, test=True)
                    acc = sess.run(
                        model.accuracy,
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            **test_gating_dict,
                            mask: mk,
                            dropout_keep_pct: 1.0,
                            input_dropout_keep_pct: 1.0,
                            rule: test_rule_cue,
                        }) / num_test_reps
                    # Each repetition contributes 1/num_test_reps of the mean
                    accuracy_grid[task, test_task] += acc
                    accuracy[test_task] += acc

            # Display network performance after testing is complete
            print('Task ', task, ' Mean ', np.mean(accuracy), ' First ',
                  accuracy[0], ' Last ', accuracy[-1])
            accuracy_full.append(np.mean(accuracy))

            # Reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

        # Save model performance and parameters if desired
        if par['save_analysis']:
            save_results = {
                'task': task,
                'accuracy': accuracy,
                'accuracy_full': accuracy_full,
                'accuracy_grid': accuracy_grid,
                'par': par,
            }
            # Context manager guarantees the file is flushed and closed
            with open(par['save_dir'] + save_fn, 'wb') as f:
                pickle.dump(save_results, f)

    print('\nModel execution complete.')

    # Append the mean accuracy recorded after the first task to a text log
    with open('accs_200_squared.txt', 'a') as f:
        f.write(str(accuracy_full[0]) + "\n")
# Example #2
0
def main(save_fn, gpu_id=None):
    """Train a context-vector model sequentially on all tasks and test it.

    Builds the placeholders and a single `Model` that receives a one-hot
    context vector per task. For every task it trains for
    `par['n_train_batches']` batches, updates the stabilization state
    (`par['stabilization']` is either 'pathint' or 'EWC'), resets the
    optimizer, and tests the network on every task trained so far.

    Args:
        save_fn: File name appended to `par['save_dir']` when the results are
            pickled (only used if `par['save_analysis']` is truthy).
        gpu_id: Value written to `CUDA_VISIBLE_DEVICES`. When `None`, the
            model is built without an explicit device scope, otherwise under
            '/gpu:0'.

    Raises:
        ValueError: If `par['task']` or `par['stabilization']` is unsupported.
    """

    # Fail fast: an unknown method would otherwise only surface as an
    # unbound `loss` at the first progress print.
    if par['stabilization'] not in ('pathint', 'EWC'):
        raise ValueError("Unsupported par['stabilization']: {!r}".format(
            par['stabilization']))

    # Isolate requested GPU (environment values must be strings)
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # train the convolutional layers with the CIFAR-10 dataset
    # otherwise, it will load the convolutional weights from the saved file
    if par['task'] in ('cifar', 'imagenet') and par['train_convolutional_layers']:
        convolutional_layers.ConvolutionalLayers()

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()

    # Create placeholders for the model
    # input_data, target_data, gating, mask, dropout keep pct hidden layers, dropout keep pct input layers
    batch_size = par['batch_size']
    n_out = par['layer_dims'][-1]
    if par['task'] == 'mnist':
        x = tf.placeholder(tf.float32, [batch_size, par['layer_dims'][0]],
                           'stim')
    elif par['task'] in ('cifar', 'imagenet'):
        x = tf.placeholder(tf.float32, [batch_size, 32, 32, 3], 'stim')
    else:
        raise ValueError("Unsupported par['task']: {!r}".format(par['task']))
    y = tf.placeholder(tf.float32, [batch_size, n_out], 'out')
    mask = tf.placeholder(tf.float32, [batch_size, n_out], 'mask')
    dropout_keep_pct = tf.placeholder(tf.float32, [], 'dropout')
    input_dropout_keep_pct = tf.placeholder(tf.float32, [], 'input_dropout')
    gating = [
        tf.placeholder(tf.float32, [par['layer_dims'][n + 1]], 'gating')
        for n in range(par['n_layers'] - 1)
    ]
    context_vector = tf.placeholder(tf.float32, [1, par['n_tasks']],
                                    'context_vector')

    stim = stimulus.Stimulus(labels_per_task=par['labels_per_task'])
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))

    # Defaults for values that only one stabilization branch produces, so the
    # shared reporting/saving code below never hits an unbound name.
    # 'pathint' does not fetch the entropy loss or hidden activity; it is then
    # reported as None and the activity summary prints nothing.
    entropy_loss = None
    h = []
    big_omegas = None

    with tf.Session() as sess:

        if gpu_id is None:
            model = Model(x, y, gating, mask, dropout_keep_pct,
                          input_dropout_keep_pct, context_vector)
        else:
            with tf.device("/gpu:0"):
                model = Model(x, y, gating, mask, dropout_keep_pct,
                              input_dropout_keep_pct, context_vector)
        sess.run(tf.global_variables_initializer())
        sess.run(model.reset_prev_vars)

        for task in range(par['n_tasks']):

            # One-hot context vector identifying the current task
            cont_vect = np.zeros((1, par['n_tasks']), dtype=np.float32)
            cont_vect[0, task] = 1.

            # create dictionary of gating signals applied to each hidden layer for this task
            gating_dict = dict(zip(gating, par['gating'][task]))

            for i in range(par['n_train_batches']):

                # make batch of training data
                stim_in, y_hat, mk = stim.make_batch(task, test=False)
                train_feed = {
                    x: stim_in,
                    y: y_hat,
                    **gating_dict,
                    mask: mk,
                    dropout_keep_pct: par['drop_keep_pct'],
                    input_dropout_keep_pct: par['input_drop_keep_pct'],
                    context_vector: cont_vect,
                }

                if par['stabilization'] == 'pathint':
                    _, _, loss, AL, gl = sess.run([
                        model.train_op, model.update_small_omega,
                        model.task_loss, model.aux_loss, model.gate_loss
                    ], feed_dict=train_feed)
                else:  # 'EWC' (validated above)
                    # weight_grads is fetched as in the original but unused
                    _, loss, AL, gl, _, h, entropy_loss = sess.run([
                        model.train_op, model.task_loss, model.aux_loss,
                        model.gate_loss, model.weight_grads, model.h,
                        model.entropy_loss
                    ], feed_dict=train_feed)

                if i % 500 == 0:
                    print('Iter: ', i, 'Loss: ', loss, 'Aux Loss: ', AL,
                          'gate loss ', gl, 'entropy loss', entropy_loss)

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])
            else:  # 'EWC'
                for _ in range(par['EWC_fisher_num_batches']):
                    stim_in, y_hat, mk = stim.make_batch(task, test=False)
                    # Dropout is disabled (keep = 1.0) for this update
                    sess.run(
                        [model.update_big_omega, model.big_omega_var],
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            **gating_dict,
                            mask: mk,
                            dropout_keep_pct: 1.0,
                            input_dropout_keep_pct: 1.0,
                            context_vector: cont_vect,
                        })
                    big_omegas = sess.run([model.big_omega_var])

                sess.run([model.reset_shunted_weights])

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

            # Test the networks on all trained tasks
            num_test_reps = 10
            accuracy = np.zeros((task + 1))
            for test_task in range(task + 1):
                test_cont_vect = np.zeros((1, par['n_tasks']),
                                          dtype=np.float32)
                test_cont_vect[0, test_task] = 1
                test_gating_dict = dict(zip(gating, par['gating'][test_task]))
                for _ in range(num_test_reps):
                    stim_in, y_hat, mk = stim.make_batch(test_task, test=True)
                    acc = sess.run(
                        model.accuracy,
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            **test_gating_dict,
                            mask: mk,
                            dropout_keep_pct: 1.0,
                            input_dropout_keep_pct: 1.0,
                            context_vector: test_cont_vect,
                        }) / num_test_reps
                    # Each repetition contributes 1/num_test_reps of the mean
                    accuracy_grid[task, test_task] += acc
                    accuracy[test_task] += acc

            print('Task ', task, ' Mean ', np.mean(accuracy), ' First ',
                  accuracy[0], ' Last ', accuracy[-1])
            accuracy_full.append(np.mean(accuracy))

            # reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

            # For each entry of the last fetched hidden activity, report the
            # fraction of columns whose sum over axis 0 exceeds 1e-16.
            for layer_h in h:
                above_zero = np.float32(
                    np.sum(layer_h, axis=0, keepdims=True) > 1e-16)
                print('mean h above zero ', np.mean(above_zero))

        if par['save_analysis']:
            save_results = {
                'task': task,
                'accuracy': accuracy,
                'accuracy_full': accuracy_full,
                'accuracy_grid': accuracy_grid,
                'big_omegas': big_omegas,
                'par': par,
            }
            # Context manager guarantees the file is flushed and closed
            with open(par['save_dir'] + save_fn, 'wb') as f:
                pickle.dump(save_results, f)

    print('\nModel execution complete.')
def main(save_fn, gpu_id=None):
    """Run supervised learning training, optionally with 'iXdG' gating.

    Builds the placeholders and a single `Model`, then iterates over
    `par['n_tasks']` tasks. When `par['gating_type'] == 'iXdG'`, the gates for
    each task are produced by the model (`model.update_importance` followed by
    `model.update_gates`) and stored in `model.gates[task]`; otherwise the
    precomputed `par['gating'][task]` is used. After training on a task, the
    stabilization state is updated and the network is tested on every task
    trained so far.

    Args:
        save_fn: File name appended to `par['save_dir']` when the results are
            pickled (only used if `par['save_analysis']` is truthy).
        gpu_id: Value written to `CUDA_VISIBLE_DEVICES`. When `None`, the
            model is built on '/cpu:0', otherwise on '/gpu:0'.

    Raises:
        ValueError: If `par['task']` or `par['stabilization']` is unsupported.
        AssertionError: If the 'iXdG' gate self-checks fail.
    """

    # Update all dependencies in parameters
    update_dependencies()

    # Fail fast: an unknown method would otherwise only surface as an
    # unbound `loss` at the first progress print.
    if par['stabilization'] not in ('pathint', 'EWC'):
        raise ValueError("Unsupported par['stabilization']: {!r}".format(
            par['stabilization']))

    # Isolate requested GPU (environment values must be strings)
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # If desired, train the convolutional layers with the CIFAR datasets
    # Otherwise, the network will load convolutional weights from the saved file
    if par['task'] in ('cifar', 'imagenet') and par['train_convolutional_layers']:
        convolutional_layers.ConvolutionalLayers()

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()

    # Create placeholders for the model
    batch_size = par['batch_size']
    n_out = par['layer_dims'][-1]
    if par['task'] == 'mnist':
        x = tf.placeholder(tf.float32, [batch_size, par['layer_dims'][0]],
                           'stim')
    elif par['task'] in ('cifar', 'imagenet'):
        x = tf.placeholder(tf.float32, [batch_size, 32, 32, 3], 'stim')
    else:
        raise ValueError("Unsupported par['task']: {!r}".format(par['task']))
    y = tf.placeholder(tf.float32, [batch_size, n_out], 'out')
    mask = tf.placeholder(tf.float32, [batch_size, n_out], 'mask')
    rule = tf.placeholder(tf.float32, [batch_size, par['n_tasks']], 'rulecue')
    gating = [
        tf.placeholder(tf.float32, [par['layer_dims'][n + 1]], 'gating')
        for n in range(par['n_layers'] - 1)
    ]
    dropout_keep_pct = tf.placeholder(tf.float32, [], 'dropout')
    input_dropout_keep_pct = tf.placeholder(tf.float32, [], 'input_dropout')

    # Set up stimulus
    stim = stimulus.Stimulus(labels_per_task=par['labels_per_task'])

    # Initialize accuracy records
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))

    # Compare strings by value: `is` tests object identity and only works by
    # accident when both strings happen to be interned.
    importance_gating = par['gating_type'] == 'iXdG'

    # Enter TensorFlow session
    with tf.Session() as sess:
        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, y, gating, mask, dropout_keep_pct,
                          input_dropout_keep_pct, rule)

        # Initialize variables
        sess.run(tf.global_variables_initializer())
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        for task in range(par['n_tasks']):

            if importance_gating:
                # Snapshot the importance values before the update so we can
                # report below whether the update changed them. Values are
                # fetched as NumPy arrays instead of being copied into fresh
                # tf.Variables, so the graph does not grow with every task.
                prev_imp = sess.run(model.importance[1])

                # Update the importance of each unit
                sess.run(model.update_importance)

                # Create gates by importance for each task
                sess.run(model.update_gates)
                curr_task_gate = {
                    layer: sess.run(model.curr_gate[layer])
                    for layer in range(1, len(par['layer_dims']) - 1)
                }
                model.gates[task] = curr_task_gate

                # --- Debugging self-checks on the generated gates ---
                if task == 0:
                    for var in model.variables_stabilization:
                        print(var.op.name)
                        print(var.get_shape())

                gate_now = sess.run(model.curr_gate[1])
                # NOTE(review): 400 is hard-coded; presumably it equals
                # (1 - par['gate_pct']) * par['layer_dims'][1] for the
                # configuration this was written for -- confirm before
                # changing either parameter.
                assert np.count_nonzero(gate_now) == 400
                assert (gate_now == model.gates[task][1]).all()

                if task > 0:
                    layer_test = 1
                    print('Gates equal?')
                    print((model.gates[task][layer_test] ==
                           model.gates[task - 1][layer_test]).all())
                    print('Importance vals equal?')
                    print((prev_imp == sess.run(model.importance[1])).all())

                # Gating signals applied to each hidden layer for this task
                gating_dict = dict(
                    zip(gating, list(model.gates[task].values())))
            else:
                gating_dict = dict(zip(gating, par['gating'][task]))

            # Create rule cue vector for this task
            rule_cue = np.zeros([batch_size, par['n_tasks']])
            rule_cue[:, task] = 1

            # Iterate over batches
            for i in range(par['n_train_batches']):

                # Make batch of training data
                stim_in, y_hat, mk = stim.make_batch(task, test=False)
                train_feed = {
                    x: stim_in,
                    y: y_hat,
                    **gating_dict,
                    mask: mk,
                    dropout_keep_pct: par['drop_keep_pct'],
                    input_dropout_keep_pct: par['input_drop_keep_pct'],
                    rule: rule_cue,
                }

                # Run the model using one of the available stabilization methods
                if par['stabilization'] == 'pathint':
                    _, _, loss, AL = sess.run([
                        model.train_op, model.update_small_omega,
                        model.task_loss, model.aux_loss
                    ], feed_dict=train_feed)
                else:  # 'EWC' (validated above)
                    _, loss, AL = sess.run(
                        [model.train_op, model.task_loss, model.aux_loss],
                        feed_dict=train_feed)

                # Display network performance (every 9th batch)
                if i % 9 == 0:
                    print('Iter: ', i, 'Loss: ', loss, 'Aux Loss: ', AL)

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                sess.run(model.update_big_omega)
            else:  # 'EWC'
                for _ in range(par['EWC_batch_divisor'] *
                               par['EWC_fisher_num_batches']):
                    stim_in, _, mk = stim.make_batch(task, test=False)
                    # No targets are fed for this update
                    sess.run(
                        [model.update_big_omega],
                        feed_dict={
                            x: stim_in,
                            **gating_dict,
                            mask: mk,
                            dropout_keep_pct: par['drop_keep_pct'],
                            input_dropout_keep_pct:
                            par['input_drop_keep_pct'],
                            rule: rule_cue,
                        })

            # Reset the Adam Optimizer, and set the prev_weight values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)
            # NOTE: unit importance is updated at the start of each task
            # (above), not here after training.

            # Test the networks on all trained tasks
            num_test_reps = 10
            accuracy = np.zeros((task + 1))
            for test_task in range(task + 1):

                # Use appropriate gating and rule cues
                if importance_gating:
                    test_gates = list(model.gates[test_task].values())
                else:
                    test_gates = par['gating'][test_task]
                test_gating_dict = dict(zip(gating, test_gates))
                test_rule_cue = np.zeros([batch_size, par['n_tasks']])
                test_rule_cue[:, test_task] = 1

                # Repeat the test as desired; dropout is disabled (keep = 1.0)
                for _ in range(num_test_reps):
                    stim_in, y_hat, mk = stim.make_batch(test_task, test=True)
                    acc = sess.run(
                        model.accuracy,
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            **test_gating_dict,
                            mask: mk,
                            dropout_keep_pct: 1.0,
                            input_dropout_keep_pct: 1.0,
                            rule: test_rule_cue,
                        }) / num_test_reps
                    # Each repetition contributes 1/num_test_reps of the mean
                    accuracy_grid[task, test_task] += acc
                    accuracy[test_task] += acc

            # Display network performance after testing is complete
            print('Task ', task, ' Mean ', np.mean(accuracy), ' First ',
                  accuracy[0], ' Last ', accuracy[-1])
            accuracy_full.append(np.mean(accuracy))

            # Reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

        # Save model performance and parameters if desired
        if par['save_analysis']:
            save_results = {
                'task': task,
                'accuracy': accuracy,
                'accuracy_full': accuracy_full,
                'accuracy_grid': accuracy_grid,
                'par': par,
            }
            # Context manager guarantees the file is flushed and closed
            with open(par['save_dir'] + save_fn, 'wb') as f:
                pickle.dump(save_results, f)

    print('\nModel execution complete.')
# Example #4
0
def main(save_fn=None, gpu_id=None):
    """Train a time-series model sequentially on all tasks and test it.

    Builds placeholders shaped `[num_time_steps, batch_size, ...]` and a
    single `Model`. For every task it trains for up to
    `par['n_train_batches']` batches (the 'pathint' branch stops early once
    the task and auxiliary losses fall below fixed thresholds), tests the
    network on every task trained so far, updates the stabilization state,
    and resets the optimizer.

    Args:
        save_fn: File name appended to `par['save_dir']` when the results are
            pickled. Required when `par['save_analysis']` is truthy.
        gpu_id: Value written to `CUDA_VISIBLE_DEVICES`. When `None`, the
            model is built on '/cpu:0', otherwise on '/gpu:0'.

    Raises:
        TypeError: If `par['save_analysis']` is set but `save_fn` is None.
        ValueError: If `par['stabilization']` is unsupported.
    """

    # Fail fast instead of discovering these problems after training:
    # the output path is built as par['save_dir'] + save_fn.
    if par['save_analysis'] and save_fn is None:
        raise TypeError(
            "save_fn must be provided when par['save_analysis'] is set")
    if par['stabilization'] not in ('pathint', 'EWC'):
        raise ValueError("Unsupported par['stabilization']: {!r}".format(
            par['stabilization']))

    # Isolate requested GPU (environment values must be strings)
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # train the convolutional layers with the CIFAR-10 dataset
    # otherwise, it will load the convolutional weights from the saved file
    if par['task'] in ('cifar', 'imagenet') and par['train_convolutional_layers']:
        convolutional_layers.ConvolutionalLayers()

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()

    # Create placeholders for the model
    # input_data, target_data, gating, mask
    x = tf.placeholder(
        tf.float32, [par['num_time_steps'], par['batch_size'], par['n_input']],
        'stim')
    target = tf.placeholder(
        tf.float32,
        [par['num_time_steps'], par['batch_size'], par['n_output']], 'out')
    mask = tf.placeholder(tf.float32,
                          [par['num_time_steps'], par['batch_size']], 'mask')
    gating = tf.placeholder(tf.float32, [par['n_hidden']], 'gating')

    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))

    key_info = [
        'synapse_config', 'spike_cost', 'weight_cost', 'entropy_cost',
        'omega_c', 'omega_xi', 'constrain_input_weights', 'num_sublayers',
        'n_hidden', 'noise_rnn_sd', 'learning_rate', 'gating_type', 'gate_pct'
    ]
    print('Key info')
    for k in key_info:
        print(k, ' ', par[k])

    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True

    # Defaults for values that only the 'pathint' branch produces, so the
    # shared reporting/saving code below never hits an unbound name. Under
    # 'EWC' the spike and entropy losses are not fetched and print as None.
    spike_loss = None
    ent_loss = None
    big_omegas = None

    with tf.Session(config=config) as sess:

        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, target, gating, mask)

        sess.run(tf.global_variables_initializer())
        sess.run(model.reset_prev_vars)

        for task in range(0, par['n_tasks']):

            for i in range(par['n_train_batches']):

                # make batch of training data
                name, stim_in, y_hat, mk, _ = stim.generate_trial(task)
                train_feed = {
                    x: stim_in,
                    target: y_hat,
                    gating: par['gating'][task],
                    mask: mk,
                }

                if par['stabilization'] == 'pathint':
                    _, _, loss, AL, spike_loss, ent_loss, output = sess.run([
                        model.train_op, model.update_small_omega,
                        model.task_loss, model.aux_loss, model.spike_loss,
                        model.entropy_loss, model.output
                    ], feed_dict=train_feed)
                    sess.run([model.reset_rnn_weights])
                    # Early stop once both losses are small enough; the
                    # auxiliary-loss threshold is relaxed for later tasks.
                    if loss < 0.005 and AL < 0.0004 + 0.0002 * task:
                        break
                else:  # 'EWC' (validated above)
                    # Also fetch the output: the progress report below needs
                    # it to compute the accuracy.
                    _, loss, AL, output = sess.run([
                        model.train_op, model.task_loss, model.aux_loss,
                        model.output
                    ], feed_dict=train_feed)

                if i % 100 == 0:
                    acc = get_perf(y_hat, output, mk)
                    print('Iter ', i, 'Task name ', name, ' accuracy', acc,
                          ' loss ', loss, ' aux loss', AL, ' spike loss',
                          spike_loss, ' entropy loss', ent_loss)

            # Test all tasks at the end of each learning session
            num_reps = 10
            for (task_prime, _) in product(range(task + 1), range(num_reps)):

                # make batch of test data
                name, stim_in, y_hat, mk, _ = stim.generate_trial(task_prime)

                # Only the stimulus and gating are fed here; targets and mask
                # are used afterwards by get_perf.
                output, _ = sess.run([model.output, model.syn_x_hist],
                                     feed_dict={
                                         x: stim_in,
                                         gating: par['gating'][task_prime]
                                     })
                acc = get_perf(y_hat, output, mk)
                accuracy_grid[task, task_prime] += acc / num_reps

            print('Accuracy grid after task {}:'.format(task))
            print(accuracy_grid[task, :])
            print('')

            # Per-task accuracies for the tasks trained so far, and their
            # mean; these are the values written to the results file.
            accuracy = accuracy_grid[task, :task + 1].copy()
            accuracy_full.append(np.mean(accuracy))

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])
            else:  # 'EWC'
                for _ in range(par['EWC_fisher_num_batches']):
                    name, stim_in, y_hat, mk, _ = stim.generate_trial(task)
                    big_omegas = sess.run(
                        [model.update_big_omega, model.big_omega_var],
                        feed_dict={
                            x: stim_in,
                            target: y_hat,
                            gating: par['gating'][task],
                            mask: mk,
                        })

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

            # reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

        if par['save_analysis']:
            save_results = {
                'task': task,
                'accuracy': accuracy,
                'accuracy_full': accuracy_full,
                'accuracy_grid': accuracy_grid,
                'big_omegas': big_omegas,
                'par': par,
            }
            # Context manager guarantees the file is flushed and closed
            with open(par['save_dir'] + save_fn, 'wb') as f:
                pickle.dump(save_results, f)

    print('\nModel execution complete.')