def _shuffle_and_truncate(images, labels, limit):
    """Shuffle one task's training split and keep at most ``limit`` examples.

    Draws one permutation with ``np.random.shuffle`` (one draw from the global
    NumPy RNG) and applies it to the images, the labels and a vector of unit
    sample weights.

    Args:
        images: Training inputs of the task, indexed along axis 0.
        labels: Training labels of the task, indexed along axis 0.
        limit: Upper bound of the slice taken after shuffling.

    Returns:
        tuple: ``(images, labels, sample_weights)`` truncated to ``limit`` rows;
        ``sample_weights`` is a float32 vector of ones.
    """
    sample_weights = np.ones([labels.shape[0]], dtype=np.float32)
    perm = np.arange(images.shape[0])
    np.random.shuffle(perm)
    return (images[perm][:limit], labels[perm][:limit],
            sample_weights[perm][:limit])


def train_task_sequence(model, sess, args):
    """
    Train and evaluate an LLL system on permuted MNIST, seeing each training
    example once when ``args.train_single_epoch`` is set.

    Args:
        model:  Model wrapper exposing the placeholders/ops fed and run below
                (``imp_method``, ``num_tasks``, ``x``, ``y_``, ``train`` ...).
        sess:   TensorFlow session used to initialise and run the model.
        args:   Parsed options; reads batch_size, num_runs, random_seed,
                mem_size, examples_per_task, train_single_epoch, train_iters,
                cross_validate_mode, online_cross_val, learning_rate,
                eps_mem_batch and maintain_orthogonality.
    Returns:
        numpy.ndarray   One entry per run. Each entry is ``np.array(evals)``,
                        where ``evals[t]`` is what was recorded after training
                        on task ``t`` (results of ``test_task_sequence``; a
                        stack of per-mini-batch snapshots in single-epoch,
                        non-cross-validation mode).
    """
    # List to store accuracy for each run
    runs = []

    batch_size = args.batch_size

    # Methods that replay or constrain with stored past examples
    use_episodic_memory = (model.imp_method in {'A-GEM', 'MER'}
                           or 'ER-' in model.imp_method)

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        time_start = time.time()
        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)
        print('Data loading time: {}'.format(time.time() - time_start))

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            # NOTE(review): these ints are passed by value to the buffer
            # helpers; episodic_filled_counter only advances once per task
            # (after training), examples_seen_so_far once per stored example.
            episodic_filled_counter = 0
            examples_seen_so_far = 0

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)

        if model.imp_method == 'PNN':
            # `np.bool` was removed in NumPy 1.24; builtin `bool` is the same
            # dtype.
            pnn_train_phase = np.zeros(model.num_tasks, dtype=bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Store the projection matrices for each task
        proj_matrices = generate_projection_matrix(
            model.num_tasks,
            feature_dim=model.subspace_proj.get_shape()[0],
            qr=QR)
        # Check the sanity of the generated matrices
        unit_test_projection_matrices(proj_matrices)

        # Per-task average gradients; only filled when GRAD_CHECK is enabled
        prev_task_grads = []
        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from previous task
            if task > 0 and model.imp_method != 'PNN':
                model.restore(sess)

            if MULTI_TASK:
                if task != 0:
                    # Joint training on every task already happened at task 0
                    continue
                # Build one training set from a shuffled subset of each task
                xs, ys, ws = [], [], []
                for t_ in range(len(datasets)):
                    task_train_images = datasets[t_]['train']['images']
                    task_train_labels = datasets[t_]['train']['labels']
                    x_t, y_t, w_t = _shuffle_and_truncate(
                        task_train_images, task_train_labels,
                        args.examples_per_task)
                    xs.append(x_t)
                    ys.append(y_t)
                    ws.append(w_t)
                # Single concatenation instead of growing arrays pairwise
                train_x = np.concatenate(xs, axis=0)
                train_y = np.concatenate(ys, axis=0)
                task_sample_weights = np.concatenate(ws, axis=0)

                # Interleave the tasks
                perm = np.arange(train_x.shape[0])
                np.random.shuffle(perm)
                train_x = train_x[perm]
                train_y = train_y[perm]
                task_sample_weights = task_sample_weights[perm]
            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']
                # Shuffle, keep examples_per_task, equal weights for all
                train_x, train_y, task_sample_weights = _shuffle_and_truncate(
                    task_train_images, task_train_labels,
                    args.examples_per_task)

            print('Received {} images, {} labels at task {}'.format(
                train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                num_iters = (num_train_examples + batch_size - 1) // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10 or (iters < 100 and iters % 10 == 0)
                            or iters % 100 == 0):
                        # Snapshot the current performance across all tasks
                        # after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets,
                                                    args.online_cross_val,
                                                    proj_matrices)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % num_train_examples
                # Size of the current (possibly short, final) mini-batch
                residual = min(batch_size, num_train_examples - offset)

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_[task]: train_y[offset:offset + batch_size],
                        model.sample_weights:
                            task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.learning_rate: args.learning_rate
                    }
                    # One train-phase flag and one logit mask per column
                    feed_dict.update(zip(model.train_phase, pnn_train_phase))
                    feed_dict.update(zip(model.output_mask, pnn_logit_mask))
                else:
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_: train_y[offset:offset + batch_size],
                        model.sample_weights:
                            task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: args.learning_rate
                    }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run(
                        [model.train[task], model.unweighted_entropy[task]],
                        feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    # After the first task only `train_classifier` is run
                    if task == 0:
                        train_op = model.train
                    else:
                        train_op = model.train_classifier
                    _, loss = sess.run([train_op, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the
                    # initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run(
                        [model.weights_old_ops_grouped, model.train,
                         model.update_small_omega, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PROJ-ANCHOR':
                    if task == 0:
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        # Store the gradients in the orthogonal complement
                        feed_dict[model.subspace_proj] = np.eye(
                            proj_matrices[task].shape[0]) - proj_matrices[task]
                        sess.run([model.store_ref_grads, model.anchor_loss],
                                 feed_dict=feed_dict)
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run(
                            [model.train_subspace_proj, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'PROJ-SUBSPACE-GP':
                    if task == 0:
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        # Store the gradients in the orthogonal complement
                        feed_dict[model.subspace_proj] = np.eye(
                            proj_matrices[task].shape[0]) - proj_matrices[task]
                        sess.run(model.store_ref_grads, feed_dict=feed_dict)
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run(
                            [model.train_gp, model.gp_total_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'SUBSPACE-PROJ':
                    feed_dict[model.subspace_proj] = proj_matrices[task]
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        # Compute and store the reference gradients on the
                        # previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Random subset of the episodic memory, sampled
                            # without replacement so no example repeats
                            mem_sample_mask = np.random.choice(
                                episodic_filled_counter, args.eps_mem_batch,
                                replace=False)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x: episodic_images[mem_sample_mask],
                                     model.y_: episodic_labels[mem_sample_mask],
                                     model.keep_prob: 1.0,
                                     model.output_mask: logit_mask,
                                     model.train_phase: True,
                                     model.learning_rate: args.learning_rate
                                 })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run(
                                [model.violation_count,
                                 model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and
                            # project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)
                    # Put the batch in the ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the
                    # initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few
                    # iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in
                        # riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value
                        # for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.weights_old_ops_grouped,
                         model.train, model.update_small_omega,
                         model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'MER':
                    mem_filled_so_far = min(examples_seen_so_far,
                                            episodic_mem_size)
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far,
                                                          args.eps_mem_batch,
                                                          replace=False)
                    np.random.shuffle(er_mem_indices)
                    mer_episodic_x_batch = episodic_images[er_mem_indices]
                    mer_episodic_y_batch = episodic_labels[er_mem_indices]
                    sess.run(model.store_theta_i_not_w)
                    # One `model.train` step per replayed example (batch of 1)
                    for mer_x, mer_y in zip(mer_episodic_x_batch,
                                            mer_episodic_y_batch):
                        feed_dict = {
                            model.x: np.expand_dims(mer_x, axis=0),
                            model.y_: np.expand_dims(mer_y, axis=0),
                            model.training_iters: num_iters,
                            model.train_step: iters,
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.train_phase: True,
                            model.learning_rate: args.learning_rate
                        }
                        sess.run(model.train, feed_dict=feed_dict)
                    # Current mini-batch with the learning rate scaled by MER_S
                    feed_dict = {
                        model.x: train_x[offset:offset + residual],
                        model.y_: train_y[offset:offset + residual],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: MER_S * args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    sess.run(model.with_in_batch_reptile_update,
                             feed_dict={model.mer_beta: MER_GAMMA})

                    # Store the examples in episodic memory using reservoir
                    # sampling
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        update_reservior(er_x, er_y_, episodic_images,
                                         episodic_labels, episodic_mem_size,
                                         examples_seen_so_far)
                        examples_seen_so_far += 1

                elif model.imp_method == 'ER-Reservoir':
                    mem_filled_so_far = min(examples_seen_so_far,
                                            episodic_mem_size)
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far,
                                                          args.eps_mem_batch,
                                                          replace=False)
                    np.random.shuffle(er_mem_indices)
                    # Replay batch = memory samples followed by current batch
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        update_reservior(er_x, er_y_, episodic_images,
                                         episodic_labels, episodic_mem_size,
                                         examples_seen_so_far)
                        examples_seen_so_far += 1

                elif model.imp_method == 'ER-Ringbuffer':
                    mem_filled_so_far = min(episodic_filled_counter,
                                            episodic_mem_size)
                    if mem_filled_so_far <= args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(
                            mem_filled_so_far, args.eps_mem_batch,
                            replace=False)
                    # Replay batch = memory samples followed by current batch
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'ER-SUBSPACE':
                    # Zero out all the grads
                    sess.run([model.reset_er_subspace_grads])
                    # Accumulate grads for all the tasks
                    if task > 0:
                        # Randomly pick a task to replay
                        tt = np.squeeze(
                            np.random.choice(np.arange(task), 1,
                                             replace=False))
                        # Memory slots written for task `tt`
                        mem_offset = tt * args.mem_size * TOTAL_CLASSES
                        er_mem_indices = np.arange(
                            mem_offset,
                            mem_offset + args.mem_size * TOTAL_CLASSES)
                        np.random.shuffle(er_mem_indices)
                        feed_dict = {
                            model.x: episodic_images[er_mem_indices],
                            model.y_: episodic_labels[er_mem_indices],
                            model.training_iters: num_iters,
                            model.train_step: iters,
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.task_id: task + 1,
                            model.learning_rate: args.learning_rate,
                            model.subspace_proj: proj_matrices[tt]
                        }
                        sess.run(model.accum_er_subspace_grads,
                                 feed_dict=feed_dict)
                    # Train on the current task
                    feed_dict = {
                        model.x: train_x[offset:offset + residual],
                        model.y_: train_y[offset:offset + residual],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.task_id: task + 1,
                        model.learning_rate: args.learning_rate,
                        model.subspace_proj: proj_matrices[task]
                    }
                    if args.maintain_orthogonality:
                        _, loss = sess.run(
                            [model.accum_er_subspace_grads, model.reg_loss],
                            feed_dict=feed_dict)
                        sess.run(model.train_stiefel,
                                 feed_dict={
                                     model.learning_rate: args.learning_rate
                                 })
                    else:
                        _, _, loss = sess.run(
                            [model.train_er_subspace,
                             model.accum_er_subspace_grads, model.reg_loss],
                            feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                if iters % 100 == 0:
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if math.isnan(loss):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 although training
                    # diverged, so wrapper scripts cannot detect the failure.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'SUBSPACE-PROJ' and GRAD_CHECK:
                # Average gradient over the task's training data at the
                # weights reached after training on it
                bbatch_size = 100
                grad_sum = []
                for iiters in range(train_x.shape[0] // bbatch_size):
                    offset = iiters * bbatch_size
                    feed_dict = {
                        model.x: train_x[offset:offset + bbatch_size],
                        model.y_: train_y[offset:offset + bbatch_size],
                        model.keep_prob: 1.0,
                        model.train_phase: True,
                        model.subspace_proj: proj_matrices[task],
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    grad_vars, train_vars = sess.run(
                        [model.reg_gradients_vars, model.trainable_vars],
                        feed_dict=feed_dict)
                    for v in range(len(train_vars)):
                        if iiters == 0:
                            grad_sum.append(grad_vars[v][0])
                        else:
                            # Incremental mean over the iiters + 1 batches
                            # seen so far (0-based counter, hence the +1).
                            grad_sum[v] += (grad_vars[v][0] -
                                            grad_sum[v]) / (iiters + 1)

                prev_task_grads.append(grad_sum)

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            # NOTE(review): this passes the task's full (untruncated) training
            # images; under MULTI_TASK it is the last task's images.
            if task < len(datasets) - 1 or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False,
                                            proj_matrices)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            elif MEASURE_PERF_ON_EPS_MEMORY:
                eps_mem = {
                    'images': episodic_images,
                    'labels': episodic_labels,
                }
                # Measure perf on episodic memory
                ftask = test_task_sequence(model, sess, eps_mem,
                                           args.online_cross_val,
                                           proj_matrices)
            else:
                # List to store accuracy for all the tasks for the current
                # trained model
                ftask = test_task_sequence(model, sess, datasets,
                                           args.online_cross_val,
                                           proj_matrices)
                print('Task: {}, Acc: {}'.format(task, ftask))

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

        runs.append(np.array(evals))

    return np.array(runs)
Exemple #2
0
def train_task_sequence(model, sess, args):
    """
    Train and evaluate an LLL system on permuted MNIST, seeing each training
    example (at most) once per task.

    For every run a fresh set of permuted-MNIST tasks is built, all model
    variables are re-initialised, and the tasks are trained one after another.
    After each task the model is evaluated with `test_task_sequence` and the
    result is recorded.

    Args:
        model: Model object exposing the placeholders/ops used below (`x`,
            `y_`, `train`, `reg_loss`, ...) plus the attributes `imp_method`
            (name of the continual-learning method) and `num_tasks`.
        sess: Session used to run the model's ops via `sess.run`.
        args: Parsed arguments. Fields read here: `batch_size`, `num_runs`,
            `random_seed`, `mem_size`, `examples_per_task`,
            `train_single_epoch`, `train_iters`, `cross_validate_mode`,
            `online_cross_val`, `eps_mem_batch`.

    Returns:
        np.ndarray: `np.array` of one entry per run. Each entry is
        `np.array(evals)`, where `evals` holds the evaluation collected after
        each task. NOTE(review): the exact shape depends on what
        `test_task_sequence` returns -- confirm against that function.
    """
    # List to store accuracy for each run
    runs = []

    batch_size = args.batch_size

    # Only A-GEM and ER maintain an episodic memory buffer
    use_episodic_memory = model.imp_method in ('A-GEM', 'ER')

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:'%(runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed+runid)

        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)

        episodic_mem_size = args.mem_size*model.num_tasks*TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            # Per-class write cursor used by the A-GEM ring buffer
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            # Start of the current task's region in the A-GEM buffer
            episodic_filled_counter = 0
            # Number of streamed examples so far (used by ER's reservoir)
            examples_seen_so_far = 0

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
            # equivalent dtype.
            pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:'%(task))

            # If not the first task then restore weights from previous task
            if(task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            # Extract training images and labels for the current task
            task_train_images = datasets[task]['train']['images']
            task_train_labels = datasets[task]['train']['labels']

            # If multi_task is set the train using datasets of all the tasks
            if MULTI_TASK:
                if task == 0:
                    for t_ in range(1, len(datasets)):
                        task_train_images = np.concatenate((task_train_images, datasets[t_]['train']['images']), axis=0)
                        task_train_labels = np.concatenate((task_train_labels, datasets[t_]['train']['labels']), axis=0)
                else:
                    # Skip training for this task
                    continue

            # Assign equal weights to all the examples
            task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32)
            total_train_examples = task_train_images.shape[0]
            # Randomly shuffle the training examples. `np.random.shuffle`
            # works in place, so it must be given `perm` itself (shuffling a
            # `list(perm)` copy would leave `perm` in its original order).
            perm = np.arange(total_train_examples)
            np.random.shuffle(perm)
            # Keep only the first `examples_per_task` shuffled indices; selecting
            # them before fancy-indexing avoids copying the full task arrays.
            keep = perm[:args.examples_per_task]
            train_x = task_train_images[keep]
            train_y = task_train_labels[keep]
            task_sample_weights = task_sample_weights[keep]

            print('Received {} images, {} labels at task {}'.format(train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                num_iters = num_train_examples // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets, args.online_cross_val)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % (num_train_examples - batch_size)
                residual = batch_size

                if model.imp_method == 'PNN':
                    # Only the column belonging to the current task is trained
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size],
                            model.sample_weights: task_sample_weights[offset:offset+batch_size],
                            model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0}
                    train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)}
                    logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)}
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: train_y[offset:offset+batch_size],
                            model.sample_weights: task_sample_weights[offset:offset+batch_size],
                            model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
                            model.output_mask: logit_mask, model.train_phase: True}

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run([model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    # After the first task only the classifier is trained
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
                    else:
                        _, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega,
                                              model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Sample a random subset from episodic memory buffer
                            mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once
                        # Store the reference gradient
                        sess.run(model.store_ref_grads, feed_dict={model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask],
                            model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True})
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)
                    # Put the batch in the ring buffer. Each task owns a region of
                    # `mem_size * TOTAL_CLASSES` slots starting at
                    # `episodic_filled_counter`; inside it every class has
                    # `mem_size` slots, overwritten cyclically via `count_cls`.
                    for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
                        # Largest non-zero index of the label vector
                        # (assumes one-hot labels -- NOTE(review): confirm)
                        cls = np.unique(np.nonzero(er_y_))[-1]
                        # Write the example at the location pointed by count_cls[cls]
                        cls_to_index_map = cls
                        with_in_task_offset = args.mem_size  * cls_to_index_map
                        mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter
                        episodic_images[mem_index] = er_x
                        episodic_labels[mem_index] = er_y_
                        count_cls[cls] = (count_cls[cls] + 1) % args.mem_size

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'ER':
                    # Number of valid entries currently in the memory buffer
                    mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False)
                    np.random.shuffle(er_mem_indices)
                    # Train on a batch of episodic memory first
                    er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0)
                    er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], train_y[offset:offset+residual]), axis=0)
                    feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch,
                        model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
                        model.output_mask: logit_mask, model.train_phase: True}
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
                    for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
                        update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far)
                        examples_seen_so_far += 1

                if (iters % 100 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 even though training
                    # diverged; callers that check the exit code see success.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!'%(task))

            # Update the episodic memory filled counter (next task's region)
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images, np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!'%(task))

            if args.train_single_epoch and not args.cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                            'images': episodic_images,
                            'labels': episodic_labels,
                            }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem, args.online_cross_val)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets, args.online_cross_val)

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs
Example #3
0
def train_task_sequence(model, sess, cross_validate_mode, train_single_epoch,
                        eval_single_head, do_sampling, is_herding,
                        mem_per_class, train_iters, batch_size, num_runs,
                        online_cross_val, random_seed):
    """
    Train and evaluate LLL system such that we only see a example once
    Args:
    Returns:
        dict    A dictionary containing mean and stds for the experiment
    """
    # List to store accuracy for each run
    runs = []

    # Loop over number of runs to average over
    for runid in range(num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(random_seed + runid)

        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)

        episodic_mem_size = mem_per_class * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        # List to store the classes that we have so far - used at test time
        test_labels = np.arange(TOTAL_CLASSES)

        if model.imp_method == 'S-GEM':
            # List to store the episodic memories of the previous tasks
            task_based_memory = []

        if model.imp_method == 'A-GEM':
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            episodic_filled_counter = 0

        if do_sampling:
            # List to store important samples from the previous tasks
            last_task_x = None
            last_task_y_ = None

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            pnn_train_phase = np.array(np.zeros(model.num_tasks),
                                       dtype=np.bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from previous task
            if (task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            # If sampling flag is set append the previous datasets
            if (do_sampling and task > 0):
                task_train_images, task_train_labels = concatenate_datasets(
                    datasets[task]['train']['images'],
                    datasets[task]['train']['labels'], last_task_x,
                    last_task_y_)
            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']

            # If multi_task is set the train using datasets of all the tasks
            if MULTI_TASK:
                if task == 0:
                    for t_ in range(1, len(datasets)):
                        task_train_images = np.concatenate(
                            (task_train_images,
                             datasets[t_]['train']['images']),
                            axis=0)
                        task_train_labels = np.concatenate(
                            (task_train_labels,
                             datasets[t_]['train']['labels']),
                            axis=0)
                else:
                    # Skip training for this task
                    continue
            print('Received {} images, {} labels at task {}'.format(
                task_train_images.shape[0], task_train_labels.shape[0], task))

            # Declare variables to store sample importance if sampling flag is set
            if do_sampling:
                # Get the sample weighting
                task_sample_weights = get_sample_weights(
                    task_train_labels, test_labels)
            else:
                # Assign equal weights to all the examples
                task_sample_weights = np.ones([task_train_labels.shape[0]],
                                              dtype=np.float32)

            num_train_examples = task_train_images.shape[0]

            # Train a task observing sequence of data
            if train_single_epoch:
                num_iters = num_train_examples // batch_size
            else:
                num_iters = train_iters

            # Randomly suffle the training examples
            perm = np.arange(num_train_examples)
            np.random.shuffle(perm)
            train_x = task_train_images[perm]
            train_y = task_train_labels[perm]
            task_sample_weights = task_sample_weights[perm]

            # Array to store accuracies when training for task T
            ftask = []

            # Training loop for task T
            for iters in range(num_iters):

                if train_single_epoch and not cross_validate_mode:
                    if (iters < 10) or (iters < 100
                                        and iters % 10 == 0) or (iters % 100
                                                                 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(
                            model,
                            sess,
                            datasets,
                            online_cross_val,
                            eval_single_head=eval_single_head)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % (num_train_examples -
                                                 batch_size)

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {
                        model.x:
                        train_x[offset:offset + batch_size],
                        model.y_[task]:
                        train_y[offset:offset + batch_size],
                        model.sample_weights:
                        task_sample_weights[offset:offset + batch_size],
                        model.training_iters:
                        num_iters,
                        model.train_step:
                        iters,
                        model.keep_prob:
                        1.0
                    }
                    train_phase_dict = {
                        m_t: i_t
                        for (m_t,
                             i_t) in zip(model.train_phase, pnn_train_phase)
                    }
                    logit_mask_dict = {
                        m_t: i_t
                        for (m_t,
                             i_t) in zip(model.output_mask, pnn_logit_mask)
                    }
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {
                        model.x:
                        train_x[offset:offset + batch_size],
                        model.y_:
                        train_y[offset:offset + batch_size],
                        model.sample_weights:
                        task_sample_weights[offset:offset + batch_size],
                        model.training_iters:
                        num_iters,
                        model.train_step:
                        iters,
                        model.keep_prob:
                        1.0,
                        model.output_mask:
                        logit_mask,
                        model.train_phase:
                        True
                    }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run(
                        [model.train[task], model.unweighted_entropy[task]],
                        feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ],
                                             feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'S-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        # Randomly sample a task from the previous tasks
                        prev_task = np.random.randint(0, task)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x:
                                     task_based_memory[prev_task]['images'],
                                     model.y_:
                                     task_based_memory[prev_task]['labels'],
                                     model.keep_prob:
                                     1.0,
                                     model.output_mask:
                                     logit_mask,
                                     model.train_phase:
                                     True
                                 })
                        # Compute the gradient for current task and project if need be
                        _, loss = sess.run(
                            [model.train_subseq_tasks, model.agem_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        if KEEP_EPISODIC_MEMORY_FULL:
                            mem_sample_mask = np.random.choice(
                                episodic_mem_size,
                                EPS_MEM_BATCH_SIZE,
                                replace=False
                            )  # Sample without replacement so that we don't sample an example more than once
                        else:
                            if episodic_filled_counter <= EPS_MEM_BATCH_SIZE:
                                mem_sample_mask = np.arange(
                                    episodic_filled_counter)
                            else:
                                # Sample a random subset from episodic memory buffer
                                mem_sample_mask = np.random.choice(
                                    episodic_filled_counter,
                                    EPS_MEM_BATCH_SIZE,
                                    replace=False
                                )  # Sample without replacement so that we don't sample an example more than once
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x: episodic_images[mem_sample_mask],
                                     model.y_:
                                     episodic_labels[mem_sample_mask],
                                     model.keep_prob: 1.0,
                                     model.output_mask: logit_mask,
                                     model.train_phase: True
                                 })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ],
                                                   feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ],
                                                feed_dict=feed_dict)

                if (iters % 500 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

                # If importance method is '*-GEM' then store the episodic memory for the task
                if 'GEM' in model.imp_method:
                    data_to_sample_from = {
                        'images': task_train_images,
                        'labels': task_train_labels,
                    }
                    if model.imp_method == 'S-GEM':
                        # Get the important samples from the current task
                        if is_herding:  # Sampling based on MoF
                            # Compute the features of training data
                            features_dim = model.image_feature_dim
                            features = np.zeros(
                                [num_train_examples, features_dim])
                            samples_at_a_time = 100
                            for i in range(num_train_examples //
                                           samples_at_a_time):
                                offset = i * samples_at_a_time
                                features[offset:offset +
                                         samples_at_a_time] = sess.run(
                                             model.features,
                                             feed_dict={
                                                 model.x:
                                                 task_train_images[
                                                     offset:offset +
                                                     samples_at_a_time],
                                                 model.y_:
                                                 task_train_labels[
                                                     offset:offset +
                                                     samples_at_a_time],
                                                 model.keep_prob:
                                                 1.0,
                                                 model.output_mask:
                                                 logit_mask,
                                                 model.train_phase:
                                                 False
                                             })
                            imp_images, imp_labels = sample_from_dataset_icarl(
                                data_to_sample_from, features,
                                np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)
                        else:  # Random sampling
                            # Do the uniform sampling
                            importance_array = np.ones(num_train_examples,
                                                       dtype=np.float32)
                            imp_images, imp_labels = sample_from_dataset(
                                data_to_sample_from, importance_array,
                                np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)
                        task_memory = {
                            'images': deepcopy(imp_images),
                            'labels': deepcopy(imp_labels),
                        }
                        task_based_memory.append(task_memory)

                    elif model.imp_method == 'A-GEM':
                        if is_herding:  # Sampling based on MoF
                            # Compute the features of training data
                            features_dim = model.image_feature_dim
                            features = np.zeros(
                                [num_train_examples, features_dim])
                            samples_at_a_time = 100
                            for i in range(num_train_examples //
                                           samples_at_a_time):
                                offset = i * samples_at_a_time
                                features[offset:offset +
                                         samples_at_a_time] = sess.run(
                                             model.features,
                                             feed_dict={
                                                 model.x:
                                                 task_train_images[
                                                     offset:offset +
                                                     samples_at_a_time],
                                                 model.y_:
                                                 task_train_labels[
                                                     offset:offset +
                                                     samples_at_a_time],
                                                 model.keep_prob:
                                                 1.0,
                                                 model.output_mask:
                                                 logit_mask,
                                                 model.train_phase:
                                                 False
                                             })
                            if KEEP_EPISODIC_MEMORY_FULL:
                                update_episodic_memory(
                                    data_to_sample_from,
                                    features,
                                    episodic_mem_size,
                                    task,
                                    episodic_images,
                                    episodic_labels,
                                    task_labels=np.arange(TOTAL_CLASSES),
                                    is_herding=True)
                            else:
                                imp_images, imp_labels = sample_from_dataset_icarl(
                                    data_to_sample_from, features,
                                    np.arange(TOTAL_CLASSES),
                                    SAMPLES_PER_CLASS)
                        else:  # Random sampling
                            # Do the uniform sampling
                            importance_array = np.ones(num_train_examples,
                                                       dtype=np.float32)
                            if KEEP_EPISODIC_MEMORY_FULL:
                                update_episodic_memory(data_to_sample_from,
                                                       importance_array,
                                                       episodic_mem_size, task,
                                                       episodic_images,
                                                       episodic_labels)
                            else:
                                imp_images, imp_labels = sample_from_dataset(
                                    data_to_sample_from, importance_array,
                                    np.arange(TOTAL_CLASSES),
                                    SAMPLES_PER_CLASS)
                        if not KEEP_EPISODIC_MEMORY_FULL:  # Fill the memory to always keep M/T samples per task
                            total_imp_samples = imp_images.shape[0]
                            eps_offset = task * total_imp_samples
                            episodic_images[eps_offset:eps_offset +
                                            total_imp_samples] = imp_images
                            episodic_labels[eps_offset:eps_offset +
                                            total_imp_samples] = imp_labels
                            episodic_filled_counter += total_imp_samples
                        # Inspect episodic memory
                        if DEBUG_EPISODIC_MEMORY:
                            # Which labels are present in the memory
                            unique_labels = np.unique(
                                np.nonzero(episodic_labels)[-1])
                            print(
                                'Unique Labels present in the episodic memory'.
                                format(unique_labels))
                            print('Labels count:')
                            for lbl in unique_labels:
                                print('Label {}: {} samples'.format(
                                    lbl,
                                    np.where(
                                        np.nonzero(episodic_labels)[-1] == lbl)
                                    [0].size))
                            # Is there any space which is not filled
                            print('Empty space: {}'.format(
                                np.where(
                                    np.sum(episodic_labels, axis=1) == 0)))
                        print('Episodic memory of {} images at task {} saved!'.
                              format(episodic_images.shape[0], task))

                # If sampling flag is set, store few of the samples from previous task
                if do_sampling:
                    # Do the uniform sampling/ only get examples from current task
                    importance_array = np.ones(
                        [datasets[task]['train']['images'].shape[0]],
                        dtype=np.float32)
                    # Get the important samples from the current task
                    imp_images, imp_labels = sample_from_dataset(
                        datasets[task]['train'], importance_array,
                        np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)

                    if imp_images is not None:
                        if last_task_x is None:
                            last_task_x = imp_images
                            last_task_y_ = imp_labels
                        else:
                            last_task_x = np.concatenate(
                                (last_task_x, imp_images), axis=0)
                            last_task_y_ = np.concatenate(
                                (last_task_y_, imp_labels), axis=0)

                    # Delete the importance array now that you don't need it in the current run
                    del importance_array

                    print(
                        '\t\t\t\tEpisodic memory of {} is saved for Task {}!'.
                        format(imp_labels.shape[0], task))

            if train_single_epoch and not cross_validate_mode:
                fbatch = test_task_sequence(model,
                                            sess,
                                            datasets,
                                            False,
                                            eval_single_head=eval_single_head)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(
                        model,
                        sess,
                        eps_mem,
                        online_cross_val,
                        eval_single_head=eval_single_head)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(
                        model,
                        sess,
                        datasets,
                        online_cross_val,
                        eval_single_head=eval_single_head)

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs
def train_task_sequence(model, sess, args):
    """
    Train and evaluate LLL system such that we only see a example once
    Args:
    Returns:
        dict    A dictionary containing mean and stds for the experiment
    """
    # List to store accuracy for each run
    runs = []

    batch_size = args.batch_size

    if model.imp_method == 'A-GEM' or model.imp_method == 'MER' or 'ER-' in model.imp_method:
        use_episodic_memory = True
    else:
        use_episodic_memory = False

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        time_start = time.time()
        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)
        time_end = time.time()
        time_spent = time_end - time_start
        print('Data loading time: {}'.format(time_spent))

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        # List to store the classes that we have so far - used at test time
        test_labels = np.arange(TOTAL_CLASSES)

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            episodic_filled_counter = 0
            examples_seen_so_far = 0
            if model.imp_method == 'ER-Hindsight-Anchors':
                avg_img_vectors = np.zeros([TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                anchor_images = np.zeros(
                    [model.num_tasks * TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                anchor_labels = np.zeros(
                    [model.num_tasks * TOTAL_CLASSES, TOTAL_CLASSES])
                anchor_count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            anchors_counter = task * TOTAL_CLASSES  # 1 per class per task

            # If not the first task then restore weights from previous task
            if (task > 0):
                model.restore(sess)

            if MULTI_TASK:
                if task == 0:
                    # Extract training images and labels for the current task
                    task_train_images = datasets[task]['train']['images']
                    task_train_labels = datasets[task]['train']['labels']
                    sample_weights = np.ones([task_train_labels.shape[0]],
                                             dtype=np.float32)
                    total_train_examples = task_train_images.shape[0]
                    # Randomly suffle the training examples
                    perm = np.arange(total_train_examples)
                    np.random.shuffle(perm)
                    train_x = task_train_images[perm][:args.examples_per_task]
                    train_y = task_train_labels[perm][:args.examples_per_task]
                    task_sample_weights = sample_weights[
                        perm][:args.examples_per_task]
                    for t_ in range(1, len(datasets)):
                        task_train_images = datasets[t_]['train']['images']
                        task_train_labels = datasets[t_]['train']['labels']
                        sample_weights = np.ones([task_train_labels.shape[0]],
                                                 dtype=np.float32)
                        total_train_examples = task_train_images.shape[0]
                        # Randomly suffle the training examples
                        perm = np.arange(total_train_examples)
                        np.random.shuffle(perm)
                        train_x = np.concatenate(
                            (train_x,
                             task_train_images[perm][:args.examples_per_task]),
                            axis=0)
                        train_y = np.concatenate(
                            (train_y,
                             task_train_labels[perm][:args.examples_per_task]),
                            axis=0)
                        task_sample_weights = np.concatenate(
                            (task_sample_weights,
                             sample_weights[perm][:args.examples_per_task]),
                            axis=0)

                    perm = np.arange(train_x.shape[0])
                    np.random.shuffle(perm)
                    train_x = train_x[perm]
                    train_y = train_y[perm]
                    task_sample_weights = task_sample_weights[perm]
                else:
                    # Skip training for this task
                    continue

            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']
                # Assign equal weights to all the examples
                task_sample_weights = np.ones([task_train_labels.shape[0]],
                                              dtype=np.float32)
                total_train_examples = task_train_images.shape[0]
                # Randomly suffle the training examples
                perm = np.arange(total_train_examples)
                np.random.shuffle(perm)
                train_x = task_train_images[perm][:args.examples_per_task]
                train_y = task_train_labels[perm][:args.examples_per_task]
                task_sample_weights = task_sample_weights[
                    perm][:args.examples_per_task]

            print('Received {} images, {} labels at task {}'.format(
                train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                num_iters = (num_train_examples + batch_size - 1) // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100
                                        and iters % 10 == 0) or (iters % 100
                                                                 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets,
                                                    args.online_cross_val)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % num_train_examples
                if (offset + batch_size <= num_train_examples):
                    residual = batch_size
                else:
                    residual = num_train_examples - offset

                feed_dict = {
                    model.x:
                    train_x[offset:offset + batch_size],
                    model.y_:
                    train_y[offset:offset + batch_size],
                    model.sample_weights:
                    task_sample_weights[offset:offset + batch_size],
                    model.training_iters:
                    num_iters,
                    model.train_step:
                    iters,
                    model.keep_prob:
                    1.0,
                    model.output_mask:
                    logit_mask,
                    model.learning_rate:
                    args.learning_rate
                }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ],
                                             feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(
                                episodic_filled_counter)
                        else:
                            # Sample a random subset from episodic memory buffer
                            mem_sample_mask = np.random.choice(
                                episodic_filled_counter,
                                args.eps_mem_batch,
                                replace=False
                            )  # Sample without replacement so that we don't sample an example more than once
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x: episodic_images[mem_sample_mask],
                                     model.y_:
                                     episodic_labels[mem_sample_mask],
                                     model.keep_prob: 1.0,
                                     model.output_mask: logit_mask,
                                     model.learning_rate: args.learning_rate
                                 })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ],
                                                   feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ],
                                                feed_dict=feed_dict)

                elif model.imp_method == 'ER-Reservoir':
                    mem_filled_so_far = examples_seen_so_far if (
                        examples_seen_so_far <= episodic_mem_size
                    ) else episodic_mem_size
                    er_mem_indices = np.arange(mem_filled_so_far) if (
                        mem_filled_so_far <= args.eps_mem_batch
                    ) else np.random.choice(
                        mem_filled_so_far, args.eps_mem_batch, replace=False)
                    np.random.shuffle(er_mem_indices)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Reservoir update
                    examples_seen_so_far = update_reservior(
                        train_x[offset:offset + residual],
                        train_y[offset:offset + residual], episodic_images,
                        episodic_labels, episodic_mem_size,
                        examples_seen_so_far)

                elif model.imp_method == 'MER':
                    mem_filled_so_far = examples_seen_so_far if (
                        examples_seen_so_far <= episodic_mem_size
                    ) else episodic_mem_size
                    mer_mem_indices = np.arange(mem_filled_so_far) if (
                        mem_filled_so_far <= args.eps_mem_batch
                    ) else np.random.choice(
                        mem_filled_so_far, args.eps_mem_batch, replace=False)
                    np.random.shuffle(mer_mem_indices)
                    mer_train_x_batch = episodic_images[mer_mem_indices]
                    mer_train_y_batch = episodic_labels[mer_mem_indices]
                    sess.run(model.store_theta_i_not_w)
                    for mer_x, mer_y in zip(mer_train_x_batch,
                                            mer_train_y_batch):
                        feed_dict = {
                            model.x: np.expand_dims(mer_x, axis=0),
                            model.y_: np.expand_dims(mer_y, axis=0),
                            model.training_iters: num_iters,
                            model.train_step: iters,
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.learning_rate: args.learning_rate
                        }
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    feed_dict = {
                        model.x: train_x[offset:offset + residual],
                        model.y_: train_y[offset:offset + residual],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: MER_S * args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    sess.run(model.with_in_batch_reptile_update,
                             feed_dict={model.mer_beta: MER_BETA
                                        })  # In the paper this is 'mer_gamma'
                    # Reservoir update
                    examples_seen_so_far = update_reservior(
                        train_x[offset:offset + residual],
                        train_y[offset:offset + residual], episodic_images,
                        episodic_labels, episodic_mem_size,
                        examples_seen_so_far)

                elif model.imp_method == 'ER-Ring':
                    mem_filled_so_far = episodic_filled_counter if (
                        episodic_filled_counter <= episodic_mem_size
                    ) else episodic_mem_size
                    er_mem_indices = np.arange(mem_filled_so_far) if (
                        mem_filled_so_far <= args.eps_mem_batch
                    ) else np.random.choice(
                        mem_filled_so_far, args.eps_mem_batch, replace=False)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'ER-Hindsight-Anchors':
                    anchor_mem_indices = np.arange(anchors_counter) if (
                        anchors_counter <= args.eps_mem_batch
                    ) else np.random.choice(
                        anchors_counter, args.eps_mem_batch, replace=False)
                    mem_filled_so_far = episodic_filled_counter if (
                        episodic_filled_counter <= episodic_mem_size
                    ) else episodic_mem_size
                    er_mem_indices = np.arange(mem_filled_so_far) if (
                        mem_filled_so_far <= args.eps_mem_batch
                    ) else np.random.choice(
                        mem_filled_so_far, args.eps_mem_batch, replace=False)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    feed_dict[model.phi_hat_alpha] = 0.5
                    if task == 0:
                        # Just train on the current task dataset
                        phi_hat, _, loss = sess.run(
                            [model.phi_hat, model.train, model.reg_loss],
                            feed_dict=feed_dict)
                        anchor_loss = 0.0
                        task_loss = 0.0
                        # Update the FIFO buffer
                        update_fifo_buffer(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual],
                                           anchor_images, anchor_labels,
                                           np.arange(TOTAL_CLASSES), 1,
                                           anchor_count_cls, anchors_counter)
                    else:
                        # Train on the anchoring loss
                        feed_dict[model.anchor_points] = anchor_images[
                            anchor_mem_indices]
                        phi_hat, _, anchor_loss, task_loss, loss = sess.run(
                            [
                                model.phi_hat, model.train_anchor,
                                model.anchor_loss, model.task_loss,
                                model.final_anchor_loss
                            ],
                            feed_dict=feed_dict)

                    # Update the average image vectors
                    update_avg_image_vectors(train_x[offset:offset + residual],
                                             train_y[offset:offset + residual],
                                             avg_img_vectors,
                                             running_alpha=0.9)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                if (iters % 10 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            elif model.imp_method == 'ER-Hindsight-Anchors':
                if task == 0:
                    # Anchors are already populated
                    pass
                else:
                    anchor_x = np.zeros([TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                    anchor_y = np.zeros([TOTAL_CLASSES, TOTAL_CLASSES])
                    task_labels = np.arange(TOTAL_CLASSES)
                    anchor_x, anchor_y = er_mem_update_hindsight(
                        model, sess, anchor_x, anchor_y, episodic_images,
                        episodic_labels, episodic_filled_counter, task_labels,
                        logit_mask, phi_hat, avg_img_vectors, args)

                    anchor_images[anchors_counter:anchors_counter +
                                  TOTAL_CLASSES] = normalize_tensors(anchor_x)
                    anchor_labels[anchors_counter:anchors_counter +
                                  TOTAL_CLASSES] = anchor_y

                # Reset the average image vectors
                avg_img_vectors[:] = 0.0

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES
                print('Unique labels in the episodic memory: {}'.format(
                    np.unique(np.nonzero(episodic_labels)[1])))
                print('Labels in the episodic memory: {}'.format(
                    np.nonzero(episodic_labels)[1]))

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Skip the task updates for the last task, unless performance is measured on the episodic memory
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem,
                                               args.online_cross_val)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets,
                                               args.online_cross_val)
                    print('Task: {}, Acc: {}'.format(task, ftask))

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs