def train_task_sequence(model, sess, datasets, args): """ Train and evaluate LLL system such that we only see a example once Args: Returns: dict A dictionary containing mean and stds for the experiment """ # List to store accuracy for each run runs = [] task_labels_dataset = [] if model.imp_method == 'A-GEM' or model.imp_method == 'ER' or model.imp_method == 'MEGA' or model.imp_method == 'MEGAD': use_episodic_memory = True else: use_episodic_memory = False batch_size = args.batch_size # Loop over number of runs to average over for runid in range(args.num_runs): print('\t\tRun %d:'%(runid)) # Initialize the random seeds np.random.seed(args.random_seed+runid) # Get the task labels from the total number of tasks and full label space task_labels = [] classes_per_task = TOTAL_CLASSES// NUM_TASKS total_classes = classes_per_task * model.num_tasks if args.online_cross_val: label_array = np.arange(total_classes) else: class_label_offset = K_FOR_CROSS_VAL * classes_per_task label_array = np.arange(class_label_offset, total_classes+class_label_offset) np.random.shuffle(label_array) for tt in range(model.num_tasks): tt_offset = tt*classes_per_task task_labels.append(list(label_array[tt_offset:tt_offset+classes_per_task])) print('Task: {}, Labels:{}'.format(tt, task_labels[tt])) # Store the task labels task_labels_dataset.append(task_labels) # Set episodic memory size episodic_mem_size = args.mem_size * total_classes # Initialize all the variables in the model sess.run(tf.global_variables_initializer()) # Run the init ops model.init_updates(sess) # List to store accuracies for a run evals = [] # List to store the classes that we have so far - used at test time test_labels = [] if use_episodic_memory: # Reserve a space for episodic memory episodic_images = np.zeros([episodic_mem_size, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS]) episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES]) episodic_filled_counter = 0 nd_logit_mask = np.zeros([model.num_tasks, TOTAL_CLASSES]) count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32) episodic_filled_counter = 0 examples_seen_so_far = 0 # Mask for softmax logit_mask = np.zeros(TOTAL_CLASSES) if COUNT_VIOLATONS: violation_count = np.zeros(model.num_tasks) vc = 0 # Training loop for all the tasks for task in range(len(task_labels)): print('\t\tTask %d:'%(task)) # If not the first task then restore weights from previous task if(task > 0 and model.imp_method != 'PNN'): model.restore(sess) if model.imp_method == 'PNN': pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) pnn_train_phase[task] = True pnn_logit_mask = np.zeros([model.num_tasks, TOTAL_CLASSES]) # If not in the cross validation mode then concatenate the train and validation sets task_tr_images, task_tr_labels = load_task_specific_data(datasets[0]['train'], task_labels[task]) task_val_images, task_val_labels = load_task_specific_data(datasets[0]['validation'], task_labels[task]) task_train_images, task_train_labels = concatenate_datasets(task_tr_images, task_tr_labels, task_val_images, task_val_labels) # If multi_task is set then train using all the datasets of all the tasks if MULTI_TASK: if task == 0: for t_ in range(1, len(task_labels)): task_tr_images, task_tr_labels = load_task_specific_data(datasets[0]['train'], task_labels[t_]) task_train_images = np.concatenate((task_train_images, task_tr_images), axis=0) task_train_labels = np.concatenate((task_train_labels, task_tr_labels), axis=0) else: # Skip training for this task continue print('Received {} images, {} labels at task 
{}'.format(task_train_images.shape[0], task_train_labels.shape[0], task)) print('Unique labels in the task: {}'.format(np.unique(np.nonzero(task_train_labels)[1]))) # Test for the tasks that we've seen so far test_labels += task_labels[task] # Assign equal weights to all the examples task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) num_train_examples = task_train_images.shape[0] logit_mask[:] = 0 # Train a task observing sequence of data if args.train_single_epoch: # Ceiling operation num_iters = (num_train_examples + batch_size - 1) // batch_size if args.cross_validate_mode: logit_mask[task_labels[task]] = 1.0 else: num_iters = args.train_iters # Set the mask only once before starting the training for the task logit_mask[task_labels[task]] = 1.0 if MULTI_TASK: logit_mask[:] = 1.0 # Randomly suffle the training examples perm = np.arange(num_train_examples) np.random.shuffle(perm) train_x = task_train_images[perm] train_y = task_train_labels[perm] task_sample_weights = task_sample_weights[perm] # Array to store accuracies when training for task T ftask = [] # Number of iterations after which convergence is checked convergence_iters = int(num_iters * MEASURE_CONVERGENCE_AFTER) # Training loop for task T for iters in range(num_iters): if args.train_single_epoch and not args.cross_validate_mode and not MULTI_TASK: if (iters <= 20) or (iters > 20 and iters % 50 == 0): # Snapshot the current performance across all tasks after each mini-batch fbatch = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) ftask.append(fbatch) if model.imp_method == 'PNN': pnn_train_phase[:] = False pnn_train_phase[task] = True pnn_logit_mask[:] = 0 pnn_logit_mask[task][task_labels[task]] = 1.0 elif model.imp_method == 'A-GEM' or model.imp_method == 'MEGA' or model.imp_method == 'MEGAD': nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 else: # Set the output labels over which the model needs to be trained logit_mask[:] = 0 logit_mask[task_labels[task]] = 1.0 if args.train_single_epoch: offset = iters * batch_size if (offset+batch_size <= num_train_examples): residual = batch_size else: residual = num_train_examples - offset if model.imp_method == 'PNN': feed_dict = {model.x: train_x[offset:offset+residual], model.y_[task]: train_y[offset:offset+residual], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5} train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} feed_dict.update(train_phase_dict) feed_dict.update(logit_mask_dict) else: feed_dict = {model.x: train_x[offset:offset+residual], model.y_: train_y[offset:offset+residual], model.sample_weights: task_sample_weights[offset:offset+residual], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5, model.train_phase: True} else: offset = (iters * batch_size) % (num_train_examples - batch_size) if model.imp_method == 'PNN': feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5} train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)} logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)} feed_dict.update(train_phase_dict) feed_dict.update(logit_mask_dict) else: feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: 
train_y[offset:offset+batch_size], model.sample_weights: task_sample_weights[offset:offset+batch_size], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5, model.train_phase: True} if model.imp_method == 'VAN': feed_dict[model.output_mask] = logit_mask _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PNN': _, loss = sess.run([model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict) elif model.imp_method == 'FTR_EXT': feed_dict[model.output_mask] = logit_mask if task == 0: _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) else: _, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'EWC' or model.imp_method == 'M-EWC': feed_dict[model.output_mask] = logit_mask # If first iteration of the first task then set the initial value of the running fisher if task == 0 and iters == 0: sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) # Update fisher after every few iterations if (iters + 1) % model.fisher_update_after == 0: sess.run(model.set_running_fisher) sess.run(model.reset_tmp_fisher) if (iters >= convergence_iters) and (model.imp_method == 'M-EWC'): _, _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.set_tmp_fisher, model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict) else: _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PI': feed_dict[model.output_mask] = logit_mask _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'MAS': feed_dict[model.output_mask] = logit_mask _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'A-GEM': if task == 0: nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size # Normal application of gradients _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict) else: ## Compute and store the reference gradients on the previous tasks # Set the mask for all the previous tasks so far nd_logit_mask[:] = 0 for tt in range(task): nd_logit_mask[tt][task_labels[tt]] = 1.0 if episodic_filled_counter <= args.eps_mem_batch: mem_sample_mask = np.arange(episodic_filled_counter) else: # Sample a random subset from episodic memory buffer mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once # Store the reference gradient ref_feed_dict = {model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], model.keep_prob: 1.0, model.train_phase: True} logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} ref_feed_dict.update(logit_mask_dict) ref_feed_dict[model.mem_batch_size] = float(len(mem_sample_mask)) sess.run([model.store_ref_grads], feed_dict=ref_feed_dict) # Compute the gradient for current task and project if need be nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size if COUNT_VIOLATONS: vc, _, loss = 
sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) else: _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) # Put the batch in the ring buffer for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): cls = np.unique(np.nonzero(er_y_))[-1] # Write the example at the location pointed by count_cls[cls] cls_to_index_map = np.where(np.array(task_labels[task]) == cls)[0][0] with_in_task_offset = args.mem_size * cls_to_index_map mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter episodic_images[mem_index] = er_x episodic_labels[mem_index] = er_y_ count_cls[cls] = (count_cls[cls] + 1) % args.mem_size elif model.imp_method == 'MEGA': if task == 0: nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size # Normal application of gradients _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict) else: ## Compute and store the reference gradients on the previous tasks # Set the mask for all the previous tasks so far nd_logit_mask[:] = 0 for tt in range(task): nd_logit_mask[tt][task_labels[tt]] = 1.0 if episodic_filled_counter <= args.eps_mem_batch: mem_sample_mask = np.arange(episodic_filled_counter) else: # Sample a random subset from episodic memory buffer mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once # Store the reference gradient ref_feed_dict = {model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], model.keep_prob: 1.0, model.train_phase: True} logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} ref_feed_dict.update(logit_mask_dict) ref_feed_dict[model.mem_batch_size] = float(len(mem_sample_mask)) sess.run([model.store_ref_grads, model.store_ref_loss], feed_dict=ref_feed_dict) # Compute the gradient for current task and project if need be nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size # Training _, loss, theta = sess.run([model.train_subseq_tasks, model.agem_loss, model.theta], feed_dict=feed_dict) # Put the batch in the ring buffer for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): cls = np.unique(np.nonzero(er_y_))[-1] # Write the example at the location pointed by count_cls[cls] cls_to_index_map = np.where(np.array(task_labels[task]) == cls)[0][0] with_in_task_offset = args.mem_size * cls_to_index_map mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter episodic_images[mem_index] = er_x episodic_labels[mem_index] = er_y_ count_cls[cls] = (count_cls[cls] + 1) % args.mem_size elif model.imp_method == 'MEGAD': if task == 0: nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size # Normal application of gradients _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict) else: ## Compute and store the reference gradients on the previous tasks # 
Set the mask for all the previous tasks so far nd_logit_mask[:] = 0 for tt in range(task): nd_logit_mask[tt][task_labels[tt]] = 1.0 if episodic_filled_counter <= args.eps_mem_batch: mem_sample_mask = np.arange(episodic_filled_counter) else: # Sample a random subset from episodic memory buffer mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once # Store the reference gradient ref_feed_dict = {model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], model.keep_prob: 1.0, model.train_phase: True} logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} ref_feed_dict.update(logit_mask_dict) ref_feed_dict[model.mem_batch_size] = float(len(mem_sample_mask)) sess.run([model.store_ref_grads, model.store_ref_loss], feed_dict=ref_feed_dict) # Compute the gradient for current task and project if need be nd_logit_mask[:] = 0 nd_logit_mask[task][task_labels[task]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = batch_size # Training _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) # Put the batch in the ring buffer for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): cls = np.unique(np.nonzero(er_y_))[-1] # Write the example at the location pointed by count_cls[cls] cls_to_index_map = np.where(np.array(task_labels[task]) == cls)[0][0] with_in_task_offset = args.mem_size * cls_to_index_map mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter episodic_images[mem_index] = er_x episodic_labels[mem_index] = er_y_ count_cls[cls] = (count_cls[cls] + 1) % args.mem_size elif model.imp_method == 'RWALK': feed_dict[model.output_mask] = logit_mask # If first iteration of the first task then set the initial value of the running fisher if task == 0 and iters == 0: sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) # Store the current value of the weights sess.run(model.weights_delta_old_grouped) # Update fisher and importance score after every few iterations if (iters + 1) % model.fisher_update_after == 0: # Update the importance score using distance in riemannian manifold sess.run(model.update_big_omega_riemann) # Now that the score is updated, compute the new value for running Fisher sess.run(model.set_running_fisher) # Store the current value of the weights sess.run(model.weights_delta_old_grouped) # Reset the delta_L sess.run([model.reset_small_omega]) _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped, model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'ER': mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size if mem_filled_so_far < args.eps_mem_batch: er_mem_indices = np.arange(mem_filled_so_far) else: er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False) np.random.shuffle(er_mem_indices) nd_logit_mask[:] = 0 for tt in range(task+1): nd_logit_mask[tt][task_labels[tt]] = 1.0 logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, nd_logit_mask)} er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0) er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], 
train_y[offset:offset+residual]), axis=0) feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch, model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.train_phase: True} feed_dict.update(logit_mask_dict) feed_dict[model.mem_batch_size] = float(er_train_x_batch.shape[0]) _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) # Reservoir update for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]): update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far) examples_seen_so_far += 1 if (iters % 100 == 0): print('Step {:d} {:.3f}'.format(iters, loss)) if (math.isnan(loss)): print('ERROR: NaNs NaNs NaNs!!!') sys.exit(0) print('\t\t\t\tTraining for Task%d done!'%(task)) if use_episodic_memory: episodic_filled_counter += args.mem_size * classes_per_task if model.imp_method == 'A-GEM': if COUNT_VIOLATONS: violation_count[task] = vc print('Task {}: Violation Count: {}'.format(task, violation_count)) sess.run(model.reset_violation_count, feed_dict=feed_dict) # Compute the inter-task updates, Fisher/ importance scores etc # Don't calculate the task updates for the last task if (task < (len(task_labels) - 1)) or MEASURE_PERF_ON_EPS_MEMORY: model.task_updates(sess, task, task_train_images, task_labels[task]) # TODO: For MAS, should the gradients be for current task or all the previous tasks print('\t\t\t\tTask updates after Task%d done!'%(task)) if VISUALIZE_IMPORTANCE_MEASURE: if runid == 0: for i in range(len(model.fisher_diagonal_at_minima)): if i == 0: flatten_fisher = np.array(model.fisher_diagonal_at_minima[i].eval()).flatten() else: flatten_fisher = np.concatenate((flatten_fisher, np.array(model.fisher_diagonal_at_minima[i].eval()).flatten())) #flatten_fisher [flatten_fisher > 0.1] = 0.1 if args.train_single_epoch: plot_histogram(flatten_fisher, 100, '/private/home/arslanch/Dropbox/LLL_experiments/Single_Epoch/importance_vis/single_epoch/m_ewc/hist_fisher_task%s.png'%(task)) else: plot_histogram(flatten_fisher, 100, '/private/home/arslanch/Dropbox/LLL_experiments/Single_Epoch/importance_vis/single_epoch/m_ewc/hist_fisher_task%s.png'%(task)) if args.train_single_epoch and not args.cross_validate_mode: fbatch = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) print('Task: {}, Acc: {}'.format(task, fbatch)) ftask.append(fbatch) ftask = np.array(ftask) if model.imp_method == 'PNN': pnn_train_phase[:] = False pnn_train_phase[task] = True pnn_logit_mask[:] = 0 pnn_logit_mask[task][task_labels[task]] = 1.0 else: if MEASURE_PERF_ON_EPS_MEMORY: eps_mem = { 'images': episodic_images, 'labels': episodic_labels, } # Measure perf on episodic memory ftask = test_task_sequence(model, sess, eps_mem, task_labels, task, classes_per_task=classes_per_task) else: # List to store accuracy for all the tasks for the current trained model ftask = test_task_sequence(model, sess, datasets[0]['test'], task_labels, task) print('Task: {}, Acc: {}'.format(task, ftask)) # Store the accuracies computed at task T in a list evals.append(ftask) # Reset the optimizer model.reset_optimizer(sess) #-> End for loop task runs.append(np.array(evals)) # End for loop runid runs = np.array(runs) return runs, task_labels_dataset
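# --- Illustrative sketch (not part of the original file) ---------------------
# The A-GEM/MEGA/MEGAD branches above repeat the same per-class ring-buffer
# write inline after every mini-batch. The helper below is a minimal sketch of
# that logic for readability; the function name is ours and the original code
# keeps this loop inline. It assumes the module-level numpy import used above.
import numpy as np

def ring_buffer_write(batch_x, batch_y, episodic_images, episodic_labels,
                      current_task_labels, mem_size, count_cls,
                      episodic_filled_counter):
    """Write a mini-batch into the per-class episodic ring buffer (sketch)."""
    for er_x, er_y_ in zip(batch_x, batch_y):
        # Class id recovered from the one-hot label
        cls = np.unique(np.nonzero(er_y_))[-1]
        # Position of this class within the current task's label list
        cls_to_index_map = np.where(np.array(current_task_labels) == cls)[0][0]
        with_in_task_offset = mem_size * cls_to_index_map
        # Each class owns `mem_size` slots; the current task's block starts at
        # `episodic_filled_counter`, and writes wrap around modulo `mem_size`.
        mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter
        episodic_images[mem_index] = er_x
        episodic_labels[mem_index] = er_y_
        count_cls[cls] = (count_cls[cls] + 1) % mem_size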
def train_task_sequence(model, sess, args): """ Train and evaluate LLL system such that we only see a example once Args: Returns: dict A dictionary containing mean and stds for the experiment """ # List to store accuracy for each run runs = [] batch_size = args.batch_size if model.imp_method in {'A-GEM', 'MER'} or 'ER-' in model.imp_method: use_episodic_memory = True else: use_episodic_memory = False # Loop over number of runs to average over for runid in range(args.num_runs): print('\t\tRun %d:' % (runid)) # Initialize the random seeds np.random.seed(args.random_seed + runid) time_start = time.time() # Load the permute mnist dataset datasets = construct_permute_mnist(model.num_tasks) print('Data loading time: {}'.format(time.time() - time_start)) episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES # Initialize all the variables in the model sess.run(tf.global_variables_initializer()) # Run the init ops model.init_updates(sess) # List to store accuracies for a run evals = [] # List to store the classes that we have so far - used at test time test_labels = np.arange(TOTAL_CLASSES) if use_episodic_memory: # Reserve a space for episodic memory episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE]) episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES]) count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32) episodic_filled_counter = 0 examples_seen_so_far = 0 # Mask for softmax # Since all the classes are present in all the tasks so nothing to mask logit_mask = np.ones(TOTAL_CLASSES) if model.imp_method == 'PNN': pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool) pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES]) if COUNT_VIOLATIONS: violation_count = np.zeros(model.num_tasks) vc = 0 # Store the projection matrices for each task proj_matrices = generate_projection_matrix( model.num_tasks, feature_dim=model.subspace_proj.get_shape()[0], qr=QR) # Check the sanity of the generated matrices unit_test_projection_matrices(proj_matrices) # TODO: Temp for gradients check prev_task_grads = [] # Training loop for all the tasks for task in range(len(datasets)): print('\t\tTask %d:' % (task)) # If not the first task then restore weights from previous task if (task > 0 and model.imp_method != 'PNN'): model.restore(sess) if MULTI_TASK: if task == 0: # Extract training images and labels for the current task task_train_images = datasets[task]['train']['images'] task_train_labels = datasets[task]['train']['labels'] sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) total_train_examples = task_train_images.shape[0] # Randomly suffle the training examples perm = np.arange(total_train_examples) np.random.shuffle(perm) train_x = task_train_images[perm][:args.examples_per_task] train_y = task_train_labels[perm][:args.examples_per_task] task_sample_weights = sample_weights[ perm][:args.examples_per_task] for t_ in range(1, len(datasets)): task_train_images = datasets[t_]['train']['images'] task_train_labels = datasets[t_]['train']['labels'] sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) total_train_examples = task_train_images.shape[0] # Randomly suffle the training examples perm = np.arange(total_train_examples) np.random.shuffle(perm) train_x = np.concatenate( (train_x, task_train_images[perm][:args.examples_per_task]), axis=0) train_y = np.concatenate( (train_y, task_train_labels[perm][:args.examples_per_task]), axis=0) task_sample_weights = np.concatenate( (task_sample_weights, 
sample_weights[perm][:args.examples_per_task]), axis=0) perm = np.arange(train_x.shape[0]) np.random.shuffle(perm) train_x = train_x[perm] train_y = train_y[perm] task_sample_weights = task_sample_weights[perm] else: # Skip training for this task continue else: # Extract training images and labels for the current task task_train_images = datasets[task]['train']['images'] task_train_labels = datasets[task]['train']['labels'] # Assign equal weights to all the examples task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32) total_train_examples = task_train_images.shape[0] # Randomly suffle the training examples perm = np.arange(total_train_examples) np.random.shuffle(perm) train_x = task_train_images[perm][:args.examples_per_task] train_y = task_train_labels[perm][:args.examples_per_task] task_sample_weights = task_sample_weights[ perm][:args.examples_per_task] print('Received {} images, {} labels at task {}'.format( train_x.shape[0], train_y.shape[0], task)) # Array to store accuracies when training for task T ftask = [] num_train_examples = train_x.shape[0] # Train a task observing sequence of data if args.train_single_epoch: num_iters = (num_train_examples + batch_size - 1) // batch_size else: num_iters = args.train_iters # Training loop for task T for iters in range(num_iters): if args.train_single_epoch and not args.cross_validate_mode: if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0): # Snapshot the current performance across all tasks after each mini-batch fbatch = test_task_sequence(model, sess, datasets, args.online_cross_val, proj_matrices) ftask.append(fbatch) offset = (iters * batch_size) % num_train_examples if (offset + batch_size <= num_train_examples): residual = batch_size else: residual = num_train_examples - offset if model.imp_method == 'PNN': pnn_train_phase[:] = False pnn_train_phase[task] = True feed_dict = { model.x: train_x[offset:offset + batch_size], model.y_[task]: train_y[offset:offset + batch_size], model.sample_weights: task_sample_weights[offset:offset + batch_size], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.learning_rate: args.learning_rate } train_phase_dict = { m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase) } logit_mask_dict = { m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask) } feed_dict.update(train_phase_dict) feed_dict.update(logit_mask_dict) else: feed_dict = { model.x: train_x[offset:offset + batch_size], model.y_: train_y[offset:offset + batch_size], model.sample_weights: task_sample_weights[offset:offset + batch_size], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True, model.learning_rate: args.learning_rate } if model.imp_method == 'VAN': _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PNN': feed_dict[model.task_id] = task _, loss = sess.run( [model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict) elif model.imp_method == 'FTR_EXT': if task == 0: _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) else: _, loss = sess.run( [model.train_classifier, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'EWC': # If first iteration of the first task then set the initial value of the running fisher if task == 0 and iters == 0: sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) # Update fisher after every few iterations if (iters + 1) % 
model.fisher_update_after == 0: sess.run(model.set_running_fisher) sess.run(model.reset_tmp_fisher) _, _, loss = sess.run( [model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PI': _, _, _, loss = sess.run([ model.weights_old_ops_grouped, model.train, model.update_small_omega, model.reg_loss ], feed_dict=feed_dict) elif model.imp_method == 'MAS': _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PROJ-ANCHOR': if task == 0: feed_dict[model.subspace_proj] = proj_matrices[task] _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) reg = 0.0 else: # Store the gradients in the orthogonal compliment feed_dict[model.subspace_proj] = np.eye( proj_matrices[task].shape[0]) - proj_matrices[task] _, reg = sess.run( [model.store_ref_grads, model.anchor_loss], feed_dict=feed_dict) feed_dict[model.subspace_proj] = proj_matrices[task] _, loss = sess.run( [model.train_subspace_proj, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'PROJ-SUBSPACE-GP': if task == 0: feed_dict[model.subspace_proj] = proj_matrices[task] _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) reg = 0.0 else: # Store the gradients in the orthogonal compliment feed_dict[model.subspace_proj] = np.eye( proj_matrices[task].shape[0]) - proj_matrices[task] sess.run(model.store_ref_grads, feed_dict=feed_dict) feed_dict[model.subspace_proj] = proj_matrices[task] _, loss = sess.run( [model.train_gp, model.gp_total_loss], feed_dict=feed_dict) elif model.imp_method == 'SUBSPACE-PROJ': feed_dict[model.subspace_proj] = proj_matrices[task] _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) elif model.imp_method == 'A-GEM': if task == 0: # Normal application of gradients _, loss = sess.run( [model.train_first_task, model.agem_loss], feed_dict=feed_dict) else: ## Compute and store the reference gradients on the previous tasks if episodic_filled_counter <= args.eps_mem_batch: mem_sample_mask = np.arange( episodic_filled_counter) else: # Sample a random subset from episodic memory buffer mem_sample_mask = np.random.choice( episodic_filled_counter, args.eps_mem_batch, replace=False ) # Sample without replacement so that we don't sample an example more than once # Store the reference gradient sess.run(model.store_ref_grads, feed_dict={ model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True, model.learning_rate: args.learning_rate }) if COUNT_VIOLATIONS: vc, _, loss = sess.run([ model.violation_count, model.train_subseq_tasks, model.agem_loss ], feed_dict=feed_dict) else: # Compute the gradient for current task and project if need be _, loss = sess.run( [model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict) # Put the batch in the ring buffer update_fifo_buffer(train_x[offset:offset + residual], train_y[offset:offset + residual], episodic_images, episodic_labels, np.arange(TOTAL_CLASSES), args.mem_size, count_cls, episodic_filled_counter) elif model.imp_method == 'RWALK': # If first iteration of the first task then set the initial value of the running fisher if task == 0 and iters == 0: sess.run([model.set_initial_running_fisher], feed_dict=feed_dict) # Store the current value of the weights sess.run(model.weights_delta_old_grouped) # Update fisher and importance score after every few iterations if (iters + 1) % model.fisher_update_after == 0: # Update the importance score 
using distance in riemannian manifold sess.run(model.update_big_omega_riemann) # Now that the score is updated, compute the new value for running Fisher sess.run(model.set_running_fisher) # Store the current value of the weights sess.run(model.weights_delta_old_grouped) # Reset the delta_L sess.run([model.reset_small_omega]) _, _, _, _, loss = sess.run([ model.set_tmp_fisher, model.weights_old_ops_grouped, model.train, model.update_small_omega, model.reg_loss ], feed_dict=feed_dict) elif model.imp_method == 'MER': mem_filled_so_far = examples_seen_so_far if ( examples_seen_so_far < episodic_mem_size ) else episodic_mem_size if mem_filled_so_far < args.eps_mem_batch: er_mem_indices = np.arange(mem_filled_so_far) else: er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False) np.random.shuffle(er_mem_indices) mer_episodic_x_batch, mer_episodic_y_batch = episodic_images[ er_mem_indices], episodic_labels[er_mem_indices] sess.run(model.store_theta_i_not_w) for mer_x, mer_y in zip(mer_episodic_x_batch, mer_episodic_y_batch): feed_dict = { model.x: np.expand_dims(mer_x, axis=0), model.y_: np.expand_dims(mer_y, axis=0), model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True, model.learning_rate: args.learning_rate } sess.run(model.train, feed_dict=feed_dict) feed_dict = { model.x: train_x[offset:offset + residual], model.y_: train_y[offset:offset + residual], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True, model.learning_rate: MER_S * args.learning_rate } _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) sess.run(model.with_in_batch_reptile_update, feed_dict={model.mer_beta: MER_GAMMA}) # Store the examples in episodic memory using reservior sampling for er_x, er_y_ in zip(train_x[offset:offset + residual], train_y[offset:offset + residual]): update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far) examples_seen_so_far += 1 elif model.imp_method == 'ER-Reservoir': mem_filled_so_far = examples_seen_so_far if ( examples_seen_so_far < episodic_mem_size ) else episodic_mem_size if mem_filled_so_far < args.eps_mem_batch: er_mem_indices = np.arange(mem_filled_so_far) else: er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False) np.random.shuffle(er_mem_indices) er_train_x_batch = np.concatenate( (episodic_images[er_mem_indices], train_x[offset:offset + residual]), axis=0) er_train_y_batch = np.concatenate( (episodic_labels[er_mem_indices], train_y[offset:offset + residual]), axis=0) feed_dict = { model.x: er_train_x_batch, model.y_: er_train_y_batch, model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True, model.learning_rate: args.learning_rate } _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) for er_x, er_y_ in zip(train_x[offset:offset + residual], train_y[offset:offset + residual]): update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far) examples_seen_so_far += 1 elif model.imp_method == 'ER-Ringbuffer': mem_filled_so_far = episodic_filled_counter if ( episodic_filled_counter <= episodic_mem_size ) else episodic_mem_size er_mem_indices = np.arange(mem_filled_so_far) if ( mem_filled_so_far <= args.eps_mem_batch ) else np.random.choice( mem_filled_so_far, 
args.eps_mem_batch, replace=False) er_train_x_batch = np.concatenate( (episodic_images[er_mem_indices], train_x[offset:offset + residual]), axis=0) er_train_y_batch = np.concatenate( (episodic_labels[er_mem_indices], train_y[offset:offset + residual]), axis=0) feed_dict = { model.x: er_train_x_batch, model.y_: er_train_y_batch, model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.learning_rate: args.learning_rate } _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict) # Put the batch in the FIFO ring buffer update_fifo_buffer(train_x[offset:offset + residual], train_y[offset:offset + residual], episodic_images, episodic_labels, np.arange(TOTAL_CLASSES), args.mem_size, count_cls, episodic_filled_counter) elif model.imp_method == 'ER-SUBSPACE': # Zero out all the grads sess.run([model.reset_er_subspace_grads]) # Accumulate grads for all the tasks if task > 0: # Randomly pick a task to replay tt = np.squeeze( np.random.choice(np.arange(task), 1, replace=False)) mem_offset = tt * args.mem_size * TOTAL_CLASSES er_mem_indices = np.arange( mem_offset, mem_offset + args.mem_size * TOTAL_CLASSES) np.random.shuffle(er_mem_indices) er_train_x_batch = episodic_images[er_mem_indices] er_train_y_batch = episodic_labels[er_mem_indices] feed_dict = { model.x: er_train_x_batch, model.y_: er_train_y_batch, model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.task_id: task + 1, model.learning_rate: args.learning_rate } feed_dict[model.subspace_proj] = proj_matrices[tt] sess.run(model.accum_er_subspace_grads, feed_dict=feed_dict) # Train on the current task feed_dict = { model.x: train_x[offset:offset + residual], model.y_: train_y[offset:offset + residual], model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0, model.output_mask: logit_mask, model.task_id: task + 1, model.learning_rate: args.learning_rate } feed_dict[model.subspace_proj] = proj_matrices[task] if args.maintain_orthogonality: _, loss = sess.run( [model.accum_er_subspace_grads, model.reg_loss], feed_dict=feed_dict) sess.run(model.train_stiefel, feed_dict={ model.learning_rate: args.learning_rate }) else: _, _, loss = sess.run([ model.train_er_subspace, model.accum_er_subspace_grads, model.reg_loss ], feed_dict=feed_dict) # Put the batch in the FIFO ring buffer update_fifo_buffer(train_x[offset:offset + residual], train_y[offset:offset + residual], episodic_images, episodic_labels, np.arange(TOTAL_CLASSES), args.mem_size, count_cls, episodic_filled_counter) if (iters % 100 == 0): print('Step {:d} {:.3f}'.format(iters, loss)) #print('Step {:d}\t CE: {:.3f}\t Reg: {:.3f}\t TL: {:.3f}'.format(iters, entropy, reg, loss)) #print('Step {:d}\t Reg: {:.3f}\t TL: {:.3f}'.format(iters, reg, loss)) if (math.isnan(loss)): print('ERROR: NaNs NaNs Nans!!!') sys.exit(0) print('\t\t\t\tTraining for Task%d done!' 
% (task)) if model.imp_method == 'SUBSPACE-PROJ' and GRAD_CHECK: # TODO: Compute the average gradient for the task at \theta^* bbatch_size = 100 grad_sum = [] for iiters in range(train_x.shape[0] // bbatch_size): offset = iiters * bbatch_size feed_dict = { model.x: train_x[offset:offset + bbatch_size], model.y_: train_y[offset:offset + bbatch_size], model.keep_prob: 1.0, model.train_phase: True, model.subspace_proj: proj_matrices[task], model.output_mask: logit_mask, model.learning_rate: args.learning_rate } grad_vars, train_vars = sess.run( [model.reg_gradients_vars, model.trainable_vars], feed_dict=feed_dict) for v in range(len(train_vars)): if iiters == 0: grad_sum.append(grad_vars[v][0]) else: grad_sum[v] += (grad_vars[v][0] - grad_sum[v]) / iiters prev_task_grads.append(grad_sum) # Upaate the episodic memory filled counter if use_episodic_memory: episodic_filled_counter += args.mem_size * TOTAL_CLASSES if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS: violation_count[task] = vc print('Task {}: Violation Count: {}'.format( task, violation_count)) sess.run(model.reset_violation_count, feed_dict=feed_dict) # Compute the inter-task updates, Fisher/ importance scores etc # Don't calculate the task updates for the last task if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY: model.task_updates(sess, task, task_train_images, np.arange(TOTAL_CLASSES)) print('\t\t\t\tTask updates after Task%d done!' % (task)) if args.train_single_epoch and not args.cross_validate_mode: fbatch = test_task_sequence(model, sess, datasets, False, proj_matrices) ftask.append(fbatch) ftask = np.array(ftask) else: if MEASURE_PERF_ON_EPS_MEMORY: eps_mem = { 'images': episodic_images, 'labels': episodic_labels, } # Measure perf on episodic memory ftask = test_task_sequence(model, sess, eps_mem, args.online_cross_val, proj_matrices) else: # List to store accuracy for all the tasks for the current trained model ftask = test_task_sequence(model, sess, datasets, args.online_cross_val, proj_matrices) print('Task: {}, Acc: {}'.format(task, ftask)) # Store the accuracies computed at task T in a list evals.append(ftask) # Reset the optimizer model.reset_optimizer(sess) #-> End for loop task runs.append(np.array(evals)) # End for loop runid runs = np.array(runs) return runs
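# --- Illustrative sketch (not part of the original file) ---------------------
# The MER and ER-Reservoir branches above push incoming examples into the
# episodic memory through `update_reservior(...)`, which is defined elsewhere
# in the repo. The helper below is a minimal reservoir-sampling sketch that is
# consistent with that call signature; the repo's actual implementation may
# differ in details.
import numpy as np

def reservoir_update_sketch(image, label, episodic_images, episodic_labels,
                            buffer_size, examples_seen_so_far):
    """Keep each streamed example with probability buffer_size / (seen + 1)."""
    if examples_seen_so_far < buffer_size:
        # Buffer not full yet: fill the next free slot
        episodic_images[examples_seen_so_far] = image
        episodic_labels[examples_seen_so_far] = label
    else:
        # Buffer full: overwrite a uniformly chosen slot with the right probability
        j = np.random.randint(0, examples_seen_so_far + 1)
        if j < buffer_size:
            episodic_images[j] = image
            episodic_labels[j] = label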
def train_task_sequence(model, sess, args):
    """Train and evaluate the LLL system such that we only see an example once.

    Args:
        model: Model instance implementing the chosen continual-learning method.
        sess: TensorFlow session.
        args: Parsed command-line arguments for the experiment.

    Returns:
        A numpy array with the per-task accuracies recorded for every run.
    """
    # List to store accuracy for each run
    runs = []

    batch_size = args.batch_size

    if model.imp_method == 'A-GEM' or model.imp_method == 'ER':
        use_episodic_memory = True
    else:
        use_episodic_memory = False

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        # Load the permuted MNIST dataset
        datasets = construct_permute_mnist(model.num_tasks)

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        # List to store the classes that we have seen so far - used at test time
        test_labels = np.arange(TOTAL_CLASSES)

        if use_episodic_memory:
            # Reserve space for the episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            episodic_filled_counter = 0
            examples_seen_so_far = 0

        # Mask for softmax.
        # Since all the classes are present in all the tasks, there is nothing to mask.
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from the previous task
            if (task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            # Extract training images and labels for the current task
            task_train_images = datasets[task]['train']['images']
            task_train_labels = datasets[task]['train']['labels']

            # If multi_task is set then train using the datasets of all the tasks
            if MULTI_TASK:
                if task == 0:
                    for t_ in range(1, len(datasets)):
                        task_train_images = np.concatenate((task_train_images, datasets[t_]['train']['images']), axis=0)
                        task_train_labels = np.concatenate((task_train_labels, datasets[t_]['train']['labels']), axis=0)
                else:
                    # Skip training for this task
                    continue

            # Assign equal weights to all the examples
            task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32)
            total_train_examples = task_train_images.shape[0]

            # Randomly shuffle the training examples (in place)
            perm = np.arange(total_train_examples)
            np.random.shuffle(perm)
            train_x = task_train_images[perm][:args.examples_per_task]
            train_y = task_train_labels[perm][:args.examples_per_task]
            task_sample_weights = task_sample_weights[perm][:args.examples_per_task]

            print('Received {} images, {} labels at task {}'.format(train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing a sequence of data
            if args.train_single_epoch:
                num_iters = num_train_examples // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets, args.online_cross_val)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % (num_train_examples - batch_size)
                residual = batch_size

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size],
                                 model.sample_weights: task_sample_weights[offset:offset+batch_size],
                                 model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0}
                    train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)}
                    logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)}
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: train_y[offset:offset+batch_size],
                                 model.sample_weights: task_sample_weights[offset:offset+batch_size],
                                 model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
                                 model.output_mask: logit_mask, model.train_phase: True}

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run([model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
                    else:
                        _, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running Fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                    # Update the Fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)
                    _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train,
                                              model.update_small_omega, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict)
                    else:
                        # Compute and store the reference gradients on the previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Sample a random subset from the episodic memory buffer without replacement,
                            # so that we don't sample an example more than once
                            mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={model.x: episodic_images[mem_sample_mask],
                                            model.y_: episodic_labels[mem_sample_mask],
                                            model.keep_prob: 1.0, model.output_mask: logit_mask,
                                            model.train_phase: True})
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss],
                                                   feed_dict=feed_dict)
                        else:
                            # Compute the gradient for the current task and project if need be
                            _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)

                    # Put the batch in the ring buffer
                    for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
                        cls = np.unique(np.nonzero(er_y_))[-1]
                        # Write the example at the location pointed to by count_cls[cls]
                        cls_to_index_map = cls
                        with_in_task_offset = args.mem_size * cls_to_index_map
                        mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter
                        episodic_images[mem_index] = er_x
                        episodic_labels[mem_index] = er_y_
                        count_cls[cls] = (count_cls[cls] + 1) % args.mem_size

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running Fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update the Fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in the Riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for the running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])
                    _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped, model.train,
                                                 model.update_small_omega, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'ER':
                    mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False)
                    np.random.shuffle(er_mem_indices)
                    # Train on a batch of episodic memory together with the current batch
                    er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0)
                    er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], train_y[offset:offset+residual]), axis=0)
                    feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch,
                                 model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
                                 model.output_mask: logit_mask, model.train_phase: True}
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
                    for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
                        update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far)
                        examples_seen_so_far += 1

                if (iters % 100 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs NaNs!!!')
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/importance scores etc.
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images, np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure performance on the episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem, args.online_cross_val)
                else:
                    # List to store accuracy on all the tasks for the currently trained model
                    ftask = test_task_sequence(model, sess, datasets, args.online_cross_val)

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            # -> End for loop task

        runs.append(np.array(evals))

        # End for loop runid

    runs = np.array(runs)
    return runs
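# --- Illustrative sketch (not part of the original file) ---------------------
# In the A-GEM branches above, `model.store_ref_grads` caches the gradient on a
# batch drawn from episodic memory, and `model.train_subseq_tasks` applies the
# (possibly projected) current-task gradient inside the TF graph. The NumPy
# sketch below shows the projection rule from the A-GEM paper for reference
# only; it is not how the graph op is implemented here.
import numpy as np

def agem_project_sketch(grad, ref_grad):
    """Project `grad` so it does not increase the loss on the memory batch."""
    dot = np.dot(grad, ref_grad)
    if dot >= 0:
        # Gradients agree: apply the current-task gradient unchanged
        return grad
    # Gradients disagree: remove the component along the reference gradient
    return grad - (dot / np.dot(ref_grad, ref_grad)) * ref_grad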