def train_task_sequence(model, sess, args):
    """
    Train the model on the permuted-MNIST task sequence and evaluate it.

    For every run the variables are re-initialised, the tasks returned by
    `construct_permute_mnist` are visited in order, and the update rule
    selected by `model.imp_method` is applied to each mini-batch. After
    each task the model is evaluated with `test_task_sequence`.

    Args:
        model:  Object exposing `imp_method`, `num_tasks`, the placeholders
                fed below (`x`, `y_`, `sample_weights`, ...) and the
                training / bookkeeping ops run below.
        sess:   Session object used for every `sess.run` call
                (presumably a TF1 `tf.Session` - confirm with the caller).
        args:   Namespace providing `batch_size`, `num_runs`, `random_seed`,
                `mem_size`, `examples_per_task`, `train_single_epoch`,
                `train_iters`, `cross_validate_mode`, `online_cross_val`,
                `eps_mem_batch`, `learning_rate` and
                `maintain_orthogonality`.

    Returns:
        np.ndarray built from one entry per run, each entry being the array
        of per-task results produced by `test_task_sequence`.
    """
    # List to store accuracy for each run
    runs = []
    batch_size = args.batch_size

    # A-GEM, MER and every 'ER-*' method replay examples from an episodic memory
    use_episodic_memory = (model.imp_method in {'A-GEM', 'MER'}
                           or 'ER-' in model.imp_method)

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        time_start = time.time()
        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)
        print('Data loading time: {}'.format(time.time() - time_start))

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            episodic_filled_counter = 0
            examples_seen_so_far = 0

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            # Builtin bool: the np.bool alias was removed from NumPy
            pnn_train_phase = np.zeros(model.num_tasks, dtype=bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Store the projection matrices for each task
        proj_matrices = generate_projection_matrix(
            model.num_tasks,
            feature_dim=model.subspace_proj.get_shape()[0],
            qr=QR)
        # Check the sanity of the generated matrices
        unit_test_projection_matrices(proj_matrices)

        # TODO: Temp for gradients check
        prev_task_grads = []

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from previous task
            if (task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            if MULTI_TASK:
                if task == 0:
                    # Extract training images and labels for the current task
                    task_train_images = datasets[task]['train']['images']
                    task_train_labels = datasets[task]['train']['labels']
                    sample_weights = np.ones([task_train_labels.shape[0]],
                                             dtype=np.float32)
                    total_train_examples = task_train_images.shape[0]
                    # Randomly shuffle the training examples
                    perm = np.arange(total_train_examples)
                    np.random.shuffle(perm)
                    train_x = task_train_images[perm][:args.examples_per_task]
                    train_y = task_train_labels[perm][:args.examples_per_task]
                    task_sample_weights = sample_weights[
                        perm][:args.examples_per_task]
                    # Append the (shuffled, truncated) data of every other task
                    for t_ in range(1, len(datasets)):
                        task_train_images = datasets[t_]['train']['images']
                        task_train_labels = datasets[t_]['train']['labels']
                        sample_weights = np.ones(
                            [task_train_labels.shape[0]], dtype=np.float32)
                        total_train_examples = task_train_images.shape[0]
                        # Randomly shuffle the training examples
                        perm = np.arange(total_train_examples)
                        np.random.shuffle(perm)
                        train_x = np.concatenate(
                            (train_x,
                             task_train_images[perm][:args.examples_per_task]),
                            axis=0)
                        train_y = np.concatenate(
                            (train_y,
                             task_train_labels[perm][:args.examples_per_task]),
                            axis=0)
                        task_sample_weights = np.concatenate(
                            (task_sample_weights,
                             sample_weights[perm][:args.examples_per_task]),
                            axis=0)

                    # Mix the examples of all the tasks together
                    perm = np.arange(train_x.shape[0])
                    np.random.shuffle(perm)
                    train_x = train_x[perm]
                    train_y = train_y[perm]
                    task_sample_weights = task_sample_weights[perm]
                else:
                    # Skip training for this task
                    continue
            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']
                # Assign equal weights to all the examples
                task_sample_weights = np.ones([task_train_labels.shape[0]],
                                              dtype=np.float32)
                total_train_examples = task_train_images.shape[0]
                # Randomly shuffle the training examples
                perm = np.arange(total_train_examples)
                np.random.shuffle(perm)
                train_x = task_train_images[perm][:args.examples_per_task]
                train_y = task_train_labels[perm][:args.examples_per_task]
                task_sample_weights = task_sample_weights[
                    perm][:args.examples_per_task]

            print('Received {} images, {} labels at task {}'.format(
                train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                # Ceiling division so that the last partial batch is also seen
                num_iters = (num_train_examples + batch_size - 1) // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets,
                                                    args.online_cross_val,
                                                    proj_matrices)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % num_train_examples
                # `residual` is the number of examples actually left in this batch
                if (offset + batch_size <= num_train_examples):
                    residual = batch_size
                else:
                    residual = num_train_examples - offset

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_[task]: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.learning_rate: args.learning_rate
                    }
                    train_phase_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)
                    }
                    logit_mask_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)
                    }
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: args.learning_rate
                    }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run(
                        [model.train[task], model.unweighted_entropy[task]],
                        feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PROJ-ANCHOR':
                    if task == 0:
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                        reg = 0.0
                    else:
                        # Store the gradients in the orthogonal complement
                        feed_dict[model.subspace_proj] = np.eye(
                            proj_matrices[task].shape[0]) - proj_matrices[task]
                        _, reg = sess.run(
                            [model.store_ref_grads, model.anchor_loss],
                            feed_dict=feed_dict)
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run(
                            [model.train_subspace_proj, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'PROJ-SUBSPACE-GP':
                    if task == 0:
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                        reg = 0.0
                    else:
                        # Store the gradients in the orthogonal complement
                        feed_dict[model.subspace_proj] = np.eye(
                            proj_matrices[task].shape[0]) - proj_matrices[task]
                        sess.run(model.store_ref_grads, feed_dict=feed_dict)
                        feed_dict[model.subspace_proj] = proj_matrices[task]
                        _, loss = sess.run(
                            [model.train_gp, model.gp_total_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'SUBSPACE-PROJ':
                    feed_dict[model.subspace_proj] = proj_matrices[task]
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Sample a random subset from episodic memory buffer.
                            # Sample without replacement so that we don't sample
                            # an example more than once
                            mem_sample_mask = np.random.choice(
                                episodic_filled_counter,
                                args.eps_mem_batch,
                                replace=False)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x: episodic_images[mem_sample_mask],
                                     model.y_: episodic_labels[mem_sample_mask],
                                     model.keep_prob: 1.0,
                                     model.output_mask: logit_mask,
                                     model.train_phase: True,
                                     model.learning_rate: args.learning_rate
                                 })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ], feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)
                    # Put the batch in the ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'MER':
                    mem_filled_so_far = examples_seen_so_far if (
                        examples_seen_so_far < episodic_mem_size
                    ) else episodic_mem_size
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far,
                                                          args.eps_mem_batch,
                                                          replace=False)
                    np.random.shuffle(er_mem_indices)
                    mer_episodic_x_batch, mer_episodic_y_batch = episodic_images[
                        er_mem_indices], episodic_labels[er_mem_indices]
                    sess.run(model.store_theta_i_not_w)
                    # One training step per replayed memory example
                    for mer_x, mer_y in zip(mer_episodic_x_batch,
                                            mer_episodic_y_batch):
                        feed_dict = {
                            model.x: np.expand_dims(mer_x, axis=0),
                            model.y_: np.expand_dims(mer_y, axis=0),
                            model.training_iters: num_iters,
                            model.train_step: iters,
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.train_phase: True,
                            model.learning_rate: args.learning_rate
                        }
                        sess.run(model.train, feed_dict=feed_dict)
                    # Step on the current batch with the learning rate scaled by MER_S
                    feed_dict = {
                        model.x: train_x[offset:offset + residual],
                        model.y_: train_y[offset:offset + residual],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: MER_S * args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    sess.run(model.with_in_batch_reptile_update,
                             feed_dict={model.mer_beta: MER_GAMMA})
                    # Store the examples in episodic memory using reservoir sampling
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        update_reservior(er_x, er_y_, episodic_images,
                                         episodic_labels, episodic_mem_size,
                                         examples_seen_so_far)
                        examples_seen_so_far += 1

                elif model.imp_method == 'ER-Reservoir':
                    mem_filled_so_far = examples_seen_so_far if (
                        examples_seen_so_far < episodic_mem_size
                    ) else episodic_mem_size
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far,
                                                          args.eps_mem_batch,
                                                          replace=False)
                    np.random.shuffle(er_mem_indices)
                    # Joint batch: replayed memory examples followed by the current batch
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        update_reservior(er_x, er_y_, episodic_images,
                                         episodic_labels, episodic_mem_size,
                                         examples_seen_so_far)
                        examples_seen_so_far += 1

                elif model.imp_method == 'ER-Ringbuffer':
                    mem_filled_so_far = episodic_filled_counter if (
                        episodic_filled_counter <= episodic_mem_size
                    ) else episodic_mem_size
                    er_mem_indices = np.arange(mem_filled_so_far) if (
                        mem_filled_so_far <= args.eps_mem_batch
                    ) else np.random.choice(
                        mem_filled_so_far, args.eps_mem_batch, replace=False)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'ER-SUBSPACE':
                    # Zero out all the grads
                    sess.run([model.reset_er_subspace_grads])
                    # Accumulate grads for all the tasks
                    if task > 0:
                        # Randomly pick a task to replay
                        tt = np.squeeze(
                            np.random.choice(np.arange(task), 1, replace=False))
                        mem_offset = tt * args.mem_size * TOTAL_CLASSES
                        er_mem_indices = np.arange(
                            mem_offset,
                            mem_offset + args.mem_size * TOTAL_CLASSES)
                        np.random.shuffle(er_mem_indices)
                        er_train_x_batch = episodic_images[er_mem_indices]
                        er_train_y_batch = episodic_labels[er_mem_indices]
                        feed_dict = {
                            model.x: er_train_x_batch,
                            model.y_: er_train_y_batch,
                            model.training_iters: num_iters,
                            model.train_step: iters,
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.task_id: task + 1,
                            model.learning_rate: args.learning_rate
                        }
                        feed_dict[model.subspace_proj] = proj_matrices[tt]
                        sess.run(model.accum_er_subspace_grads,
                                 feed_dict=feed_dict)
                    # Train on the current task
                    feed_dict = {
                        model.x: train_x[offset:offset + residual],
                        model.y_: train_y[offset:offset + residual],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.task_id: task + 1,
                        model.learning_rate: args.learning_rate
                    }
                    feed_dict[model.subspace_proj] = proj_matrices[task]
                    if args.maintain_orthogonality:
                        _, loss = sess.run(
                            [model.accum_er_subspace_grads, model.reg_loss],
                            feed_dict=feed_dict)
                        sess.run(model.train_stiefel,
                                 feed_dict={
                                     model.learning_rate: args.learning_rate
                                 })
                    else:
                        _, _, loss = sess.run([
                            model.train_er_subspace,
                            model.accum_er_subspace_grads, model.reg_loss
                        ], feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(train_x[offset:offset + residual],
                                       train_y[offset:offset + residual],
                                       episodic_images, episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                if (iters % 100 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 although this is an
                    # error path - confirm whether callers rely on that.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'SUBSPACE-PROJ' and GRAD_CHECK:
                # TODO: Compute the average gradient for the task at \theta^*
                bbatch_size = 100
                grad_sum = []
                for iiters in range(train_x.shape[0] // bbatch_size):
                    offset = iiters * bbatch_size
                    feed_dict = {
                        model.x: train_x[offset:offset + bbatch_size],
                        model.y_: train_y[offset:offset + bbatch_size],
                        model.keep_prob: 1.0,
                        model.train_phase: True,
                        model.subspace_proj: proj_matrices[task],
                        model.output_mask: logit_mask,
                        model.learning_rate: args.learning_rate
                    }
                    grad_vars, train_vars = sess.run(
                        [model.reg_gradients_vars, model.trainable_vars],
                        feed_dict=feed_dict)
                    for v in range(len(train_vars)):
                        if iiters == 0:
                            grad_sum.append(grad_vars[v][0])
                        else:
                            # Incremental mean over (iiters + 1) batches seen so
                            # far; dividing by iiters would drop the first batch
                            grad_sum[v] += (grad_vars[v][0] - grad_sum[v]) / (iiters + 1)
                prev_task_grads.append(grad_sum)

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                # Final snapshot, appended to the ones taken during training
                fbatch = test_task_sequence(model, sess, datasets, False,
                                            proj_matrices)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem,
                                               args.online_cross_val,
                                               proj_matrices)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets,
                                               args.online_cross_val,
                                               proj_matrices)
                    print('Task: {}, Acc: {}'.format(task, ftask))

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs
def train_task_sequence(model, sess, args):
    """
    Train the model on the permuted-MNIST task sequence and evaluate it.

    For every run the variables are re-initialised, the tasks returned by
    `construct_permute_mnist` are visited in order, and the update rule
    selected by `model.imp_method` is applied to each mini-batch. After
    each task the model is evaluated with `test_task_sequence`.

    Args:
        model:  Object exposing `imp_method`, `num_tasks`, the placeholders
                fed below (`x`, `y_`, `sample_weights`, ...) and the
                training / bookkeeping ops run below.
        sess:   Session object used for every `sess.run` call
                (presumably a TF1 `tf.Session` - confirm with the caller).
        args:   Namespace providing `batch_size`, `num_runs`, `random_seed`,
                `mem_size`, `examples_per_task`, `train_single_epoch`,
                `train_iters`, `cross_validate_mode`, `online_cross_val`
                and `eps_mem_batch`.

    Returns:
        np.ndarray built from one entry per run, each entry being the array
        of per-task results produced by `test_task_sequence`.
    """
    # List to store accuracy for each run
    runs = []
    batch_size = args.batch_size

    # Only A-GEM and ER replay examples from an episodic memory
    use_episodic_memory = model.imp_method in ('A-GEM', 'ER')

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))

        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            episodic_filled_counter = 0
            examples_seen_so_far = 0

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            # Builtin bool: the np.bool alias was removed from NumPy
            pnn_train_phase = np.zeros(model.num_tasks, dtype=bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from previous task
            if (task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            # Extract training images and labels for the current task
            task_train_images = datasets[task]['train']['images']
            task_train_labels = datasets[task]['train']['labels']

            # If multi_task is set then train using datasets of all the tasks
            if MULTI_TASK:
                if task == 0:
                    for t_ in range(1, len(datasets)):
                        task_train_images = np.concatenate(
                            (task_train_images, datasets[t_]['train']['images']),
                            axis=0)
                        task_train_labels = np.concatenate(
                            (task_train_labels, datasets[t_]['train']['labels']),
                            axis=0)
                else:
                    # Skip training for this task
                    continue

            # Assign equal weights to all the examples
            task_sample_weights = np.ones([task_train_labels.shape[0]],
                                          dtype=np.float32)
            total_train_examples = task_train_images.shape[0]
            # Randomly shuffle the training examples.
            # Shuffle `perm` itself: shuffling `list(perm)` only permutes a
            # temporary copy and leaves the data in its original order.
            perm = np.arange(total_train_examples)
            np.random.shuffle(perm)
            train_x = task_train_images[perm][:args.examples_per_task]
            train_y = task_train_labels[perm][:args.examples_per_task]
            task_sample_weights = task_sample_weights[perm][:args.examples_per_task]

            print('Received {} images, {} labels at task {}'.format(
                train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                num_iters = num_train_examples // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets,
                                                    args.online_cross_val)
                        ftask.append(fbatch)

                # Clamp the modulus to >= 1 so a task holding exactly
                # `batch_size` examples (or fewer) does not raise
                # ZeroDivisionError or produce a negative offset.
                offset = (iters * batch_size) % max(num_train_examples - batch_size, 1)
                residual = batch_size

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_[task]: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0
                    }
                    train_phase_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)
                    }
                    logit_mask_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)
                    }
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True
                    }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run(
                        [model.train[task], model.unweighted_entropy[task]],
                        feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        if episodic_filled_counter <= args.eps_mem_batch:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Sample a random subset from episodic memory buffer.
                            # Sample without replacement so that we don't sample
                            # an example more than once
                            mem_sample_mask = np.random.choice(
                                episodic_filled_counter,
                                args.eps_mem_batch,
                                replace=False)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads,
                                 feed_dict={
                                     model.x: episodic_images[mem_sample_mask],
                                     model.y_: episodic_labels[mem_sample_mask],
                                     model.keep_prob: 1.0,
                                     model.output_mask: logit_mask,
                                     model.train_phase: True
                                 })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ], feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)
                    # Put the batch in the ring buffer
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        # Largest index holding a non-zero label entry
                        cls = np.unique(np.nonzero(er_y_))[-1]
                        # Write the example at the location pointed by count_cls[cls]
                        cls_to_index_map = cls
                        with_in_task_offset = args.mem_size * cls_to_index_map
                        mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter
                        episodic_images[mem_index] = er_x
                        episodic_labels[mem_index] = er_y_
                        # Wrap around so each class keeps at most mem_size slots per task
                        count_cls[cls] = (count_cls[cls] + 1) % args.mem_size

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'ER':
                    mem_filled_so_far = examples_seen_so_far if (
                        examples_seen_so_far < episodic_mem_size
                    ) else episodic_mem_size
                    if mem_filled_so_far < args.eps_mem_batch:
                        er_mem_indices = np.arange(mem_filled_so_far)
                    else:
                        er_mem_indices = np.random.choice(mem_filled_so_far,
                                                          args.eps_mem_batch,
                                                          replace=False)
                    np.random.shuffle(er_mem_indices)
                    # Joint batch: replayed memory examples followed by the current batch
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices],
                         train_x[offset:offset + residual]),
                        axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices],
                         train_y[offset:offset + residual]),
                        axis=0)
                    feed_dict = {
                        model.x: er_train_x_batch,
                        model.y_: er_train_y_batch,
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True
                    }
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    for er_x, er_y_ in zip(train_x[offset:offset + residual],
                                           train_y[offset:offset + residual]):
                        update_reservior(er_x, er_y_, episodic_images,
                                         episodic_labels, episodic_mem_size,
                                         examples_seen_so_far)
                        examples_seen_so_far += 1

                if (iters % 100 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 although this is an
                    # error path - confirm whether callers rely on that.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                # Final snapshot, appended to the ones taken during training
                fbatch = test_task_sequence(model, sess, datasets, False)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem,
                                               args.online_cross_val)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets,
                                               args.online_cross_val)

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs
def _compute_image_features(model, sess, images, labels, logit_mask,
                            chunk_size=100):
    """Evaluate ``model.features`` over ``images`` chunk by chunk.

    Args:
        model: Model object exposing ``features``, ``image_feature_dim`` and
            the ``x``/``y_``/``keep_prob``/``output_mask``/``train_phase``
            feed targets.
        sess: Session-like object whose ``run`` evaluates ``model.features``.
        images: Array of inputs, indexed along axis 0.
        labels: Array of targets aligned with ``images``.
        logit_mask: Mask fed to ``model.output_mask``.
        chunk_size: Number of examples evaluated per ``sess.run`` call.

    Returns:
        A ``[len(images), model.image_feature_dim]`` array of features.
    """
    num_examples = images.shape[0]
    features = np.zeros([num_examples, model.image_feature_dim])
    # Step through the data with a stride so the final, possibly shorter,
    # chunk is evaluated too instead of being left as all-zero rows.
    for start in range(0, num_examples, chunk_size):
        stop = start + chunk_size
        features[start:stop] = sess.run(
            model.features,
            feed_dict={
                model.x: images[start:stop],
                model.y_: labels[start:stop],
                model.keep_prob: 1.0,
                model.output_mask: logit_mask,
                model.train_phase: False
            })
    return features


def train_task_sequence(model, sess, cross_validate_mode, train_single_epoch,
                        eval_single_head, do_sampling, is_herding,
                        mem_per_class, train_iters, batch_size, num_runs,
                        online_cross_val, random_seed):
    """
    Train and evaluate LLL system such that we only see a example once

    Args:
        model: Model object; ``model.imp_method`` selects the training branch
            (VAN, PNN, FTR_EXT, EWC, PI, MAS, S-GEM, A-GEM or RWALK).
        sess: Session-like object used to run the model ops.
        cross_validate_mode: When true, the per-mini-batch evaluation
            snapshots are skipped.
        train_single_epoch: When true, the number of iterations per task is
            derived from the dataset size; otherwise ``train_iters`` is used.
        eval_single_head: Forwarded to ``test_task_sequence``.
        do_sampling: When true, samples of earlier tasks are appended to the
            training data of later tasks and sample weights are computed.
        is_herding: When true, episodic memory is selected from model
            features; otherwise it is sampled uniformly.
        mem_per_class: Episodic-memory slots per class and task.
        train_iters: Iterations per task when not training a single epoch.
        batch_size: Mini-batch size.
        num_runs: Number of independent runs.
        online_cross_val: Forwarded to ``test_task_sequence``.
        random_seed: Base seed; run ``r`` is seeded with ``random_seed + r``.

    Returns:
        numpy.ndarray built from the per-run arrays of per-task evaluations.
    """
    # List to store accuracy for each run
    runs = []

    # Loop over number of runs to average over
    for runid in range(num_runs):
        print('\t\tRun %d:' % (runid))
        # Initialize the random seeds
        np.random.seed(random_seed + runid)

        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)

        episodic_mem_size = mem_per_class * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        # List to store the classes that we have so far - used at test time
        test_labels = np.arange(TOTAL_CLASSES)

        if model.imp_method == 'S-GEM':
            # List to store the episodic memories of the previous tasks
            task_based_memory = []

        if model.imp_method == 'A-GEM':
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            episodic_filled_counter = 0

        if do_sampling:
            # List to store important samples from the previous tasks
            last_task_x = None
            last_task_y_ = None

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)
        if model.imp_method == 'PNN':
            # Builtin ``bool``: the ``np.bool`` alias no longer exists in
            # current NumPy releases.
            pnn_train_phase = np.zeros(model.num_tasks, dtype=bool)
            pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            # If not the first task then restore weights from previous task
            if (task > 0 and model.imp_method != 'PNN'):
                model.restore(sess)

            # If sampling flag is set append the previous datasets
            if (do_sampling and task > 0):
                task_train_images, task_train_labels = concatenate_datasets(
                    datasets[task]['train']['images'],
                    datasets[task]['train']['labels'],
                    last_task_x, last_task_y_)
            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']

            # If multi_task is set the train using datasets of all the tasks
            if MULTI_TASK:
                if task == 0:
                    for t_ in range(1, len(datasets)):
                        task_train_images = np.concatenate(
                            (task_train_images,
                             datasets[t_]['train']['images']), axis=0)
                        task_train_labels = np.concatenate(
                            (task_train_labels,
                             datasets[t_]['train']['labels']), axis=0)
                else:
                    # Skip training for this task
                    continue

            print('Received {} images, {} labels at task {}'.format(
                task_train_images.shape[0], task_train_labels.shape[0], task))

            # Declare variables to store sample importance if sampling flag is set
            if do_sampling:
                # Get the sample weighting
                task_sample_weights = get_sample_weights(
                    task_train_labels, test_labels)
            else:
                # Assign equal weights to all the examples
                task_sample_weights = np.ones([task_train_labels.shape[0]],
                                              dtype=np.float32)

            num_train_examples = task_train_images.shape[0]

            # Train a task observing sequence of data
            if train_single_epoch:
                num_iters = num_train_examples // batch_size
            else:
                num_iters = train_iters

            # Randomly shuffle the training examples
            perm = np.arange(num_train_examples)
            np.random.shuffle(perm)
            train_x = task_train_images[perm]
            train_y = task_train_labels[perm]
            task_sample_weights = task_sample_weights[perm]

            # Array to store accuracies when training for task T
            ftask = []

            # Largest valid start index of a full batch; non-positive when the
            # task holds no more than one batch worth of examples.
            max_offset = num_train_examples - batch_size

            # Training loop for task T
            for iters in range(num_iters):

                if train_single_epoch and not cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(
                            model, sess, datasets, online_cross_val,
                            eval_single_head=eval_single_head)
                        ftask.append(fbatch)

                # Guard the modulo: a zero or negative divisor would raise or
                # produce negative offsets.
                if max_offset > 0:
                    offset = (iters * batch_size) % max_offset
                else:
                    offset = 0

                if model.imp_method == 'PNN':
                    pnn_train_phase[:] = False
                    pnn_train_phase[task] = True
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_[task]: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0
                    }
                    train_phase_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)
                    }
                    logit_mask_dict = {
                        m_t: i_t
                        for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)
                    }
                    feed_dict.update(train_phase_dict)
                    feed_dict.update(logit_mask_dict)
                else:
                    feed_dict = {
                        model.x: train_x[offset:offset + batch_size],
                        model.y_: train_y[offset:offset + batch_size],
                        model.sample_weights: task_sample_weights[offset:offset + batch_size],
                        model.training_iters: num_iters,
                        model.train_step: iters,
                        model.keep_prob: 1.0,
                        model.output_mask: logit_mask,
                        model.train_phase: True
                    }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'PNN':
                    feed_dict[model.task_id] = task
                    _, loss = sess.run(
                        [model.train[task], model.unweighted_entropy[task]],
                        feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'S-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        # Randomly sample a task from the previous tasks
                        prev_task = np.random.randint(0, task)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads, feed_dict={
                            model.x: task_based_memory[prev_task]['images'],
                            model.y_: task_based_memory[prev_task]['labels'],
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.train_phase: True
                        })
                        # Compute the gradient for current task and project if need be
                        _, loss = sess.run(
                            [model.train_subseq_tasks, model.agem_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        # Sampling is without replacement so that an example
                        # is not drawn more than once.
                        if KEEP_EPISODIC_MEMORY_FULL:
                            mem_sample_mask = np.random.choice(
                                episodic_mem_size, EPS_MEM_BATCH_SIZE,
                                replace=False)
                        elif episodic_filled_counter <= EPS_MEM_BATCH_SIZE:
                            mem_sample_mask = np.arange(episodic_filled_counter)
                        else:
                            # Sample a random subset from episodic memory buffer
                            mem_sample_mask = np.random.choice(
                                episodic_filled_counter, EPS_MEM_BATCH_SIZE,
                                replace=False)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads, feed_dict={
                            model.x: episodic_images[mem_sample_mask],
                            model.y_: episodic_labels[mem_sample_mask],
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.train_phase: True
                        })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ], feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                if (iters % 500 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 although this is an
                    # error path - confirm callers do not rely on it.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

                # If importance method is '*-GEM' then store the episodic memory for the task
                if 'GEM' in model.imp_method:
                    data_to_sample_from = {
                        'images': task_train_images,
                        'labels': task_train_labels,
                    }
                    if model.imp_method == 'S-GEM':
                        # Get the important samples from the current task
                        if is_herding:  # Sampling based on MoF
                            # Compute the features of training data
                            features = _compute_image_features(
                                model, sess, task_train_images,
                                task_train_labels, logit_mask)
                            imp_images, imp_labels = sample_from_dataset_icarl(
                                data_to_sample_from, features,
                                np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)
                        else:  # Random sampling
                            # Do the uniform sampling
                            importance_array = np.ones(num_train_examples,
                                                       dtype=np.float32)
                            imp_images, imp_labels = sample_from_dataset(
                                data_to_sample_from, importance_array,
                                np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)
                        task_memory = {
                            'images': deepcopy(imp_images),
                            'labels': deepcopy(imp_labels),
                        }
                        task_based_memory.append(task_memory)

                    elif model.imp_method == 'A-GEM':
                        if is_herding:  # Sampling based on MoF
                            # Compute the features of training data
                            features = _compute_image_features(
                                model, sess, task_train_images,
                                task_train_labels, logit_mask)
                            if KEEP_EPISODIC_MEMORY_FULL:
                                update_episodic_memory(
                                    data_to_sample_from, features,
                                    episodic_mem_size, task, episodic_images,
                                    episodic_labels,
                                    task_labels=np.arange(TOTAL_CLASSES),
                                    is_herding=True)
                            else:
                                imp_images, imp_labels = sample_from_dataset_icarl(
                                    data_to_sample_from, features,
                                    np.arange(TOTAL_CLASSES),
                                    SAMPLES_PER_CLASS)
                        else:  # Random sampling
                            # Do the uniform sampling
                            importance_array = np.ones(num_train_examples,
                                                       dtype=np.float32)
                            if KEEP_EPISODIC_MEMORY_FULL:
                                update_episodic_memory(
                                    data_to_sample_from, importance_array,
                                    episodic_mem_size, task, episodic_images,
                                    episodic_labels)
                            else:
                                imp_images, imp_labels = sample_from_dataset(
                                    data_to_sample_from, importance_array,
                                    np.arange(TOTAL_CLASSES),
                                    SAMPLES_PER_CLASS)
                        if not KEEP_EPISODIC_MEMORY_FULL:
                            # Fill the memory to always keep M/T samples per task
                            total_imp_samples = imp_images.shape[0]
                            eps_offset = task * total_imp_samples
                            episodic_images[eps_offset:eps_offset + total_imp_samples] = imp_images
                            episodic_labels[eps_offset:eps_offset + total_imp_samples] = imp_labels
                            episodic_filled_counter += total_imp_samples
                        # Inspect episodic memory
                        if DEBUG_EPISODIC_MEMORY:
                            # Which labels are present in the memory
                            stored_label_ids = np.nonzero(episodic_labels)[-1]
                            unique_labels = np.unique(stored_label_ids)
                            # The placeholder was missing, so the labels were
                            # never actually printed.
                            print('Unique Labels present in the episodic memory: {}'.
                                  format(unique_labels))
                            print('Labels count:')
                            for lbl in unique_labels:
                                print('Label {}: {} samples'.format(
                                    lbl,
                                    np.where(stored_label_ids == lbl)[0].size))
                            # Is there any space which is not filled
                            print('Empty space: {}'.format(
                                np.where(
                                    np.sum(episodic_labels, axis=1) == 0)))
                        print('Episodic memory of {} images at task {} saved!'.
                              format(episodic_images.shape[0], task))

                # If sampling flag is set, store few of the samples from previous task
                if do_sampling:
                    # Do the uniform sampling/ only get examples from current task
                    importance_array = np.ones(
                        [datasets[task]['train']['images'].shape[0]],
                        dtype=np.float32)
                    # Get the important samples from the current task
                    imp_images, imp_labels = sample_from_dataset(
                        datasets[task]['train'], importance_array,
                        np.arange(TOTAL_CLASSES), SAMPLES_PER_CLASS)

                    if imp_images is not None:
                        if last_task_x is None:
                            last_task_x = imp_images
                            last_task_y_ = imp_labels
                        else:
                            last_task_x = np.concatenate(
                                (last_task_x, imp_images), axis=0)
                            last_task_y_ = np.concatenate(
                                (last_task_y_, imp_labels), axis=0)

                    # Delete the importance array now that you don't need it in the current run
                    del importance_array

                    # Nothing was stored when the sampler returned no images;
                    # report 0 instead of dereferencing the labels.
                    num_saved = imp_labels.shape[0] if imp_images is not None else 0
                    print(
                        '\t\t\t\tEpisodic memory of {} is saved for Task {}!'.
                        format(num_saved, task))

            if train_single_epoch and not cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False,
                                            eval_single_head=eval_single_head)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(
                        model, sess, eps_mem, online_cross_val,
                        eval_single_head=eval_single_head)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(
                        model, sess, datasets, online_cross_val,
                        eval_single_head=eval_single_head)

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs
def _sample_replay_indices(num_filled, max_batch):
    """Pick the memory slots to replay for one training step.

    Args:
        num_filled: Number of leading memory slots that may be sampled.
        max_batch: Maximum number of slots to return.

    Returns:
        ``np.arange(num_filled)`` when at most ``max_batch`` slots are
        available, otherwise ``max_batch`` distinct random indices.
    """
    if num_filled <= max_batch:
        return np.arange(num_filled)
    # Sample without replacement so that no slot is drawn more than once
    return np.random.choice(num_filled, max_batch, replace=False)


def _build_train_feed(model, images, labels, num_iters, iters, logit_mask,
                      learning_rate):
    """Assemble the feed dict shared by the replay-based training steps."""
    return {
        model.x: images,
        model.y_: labels,
        model.training_iters: num_iters,
        model.train_step: iters,
        model.keep_prob: 1.0,
        model.output_mask: logit_mask,
        model.learning_rate: learning_rate
    }


def train_task_sequence(model, sess, args):
    """
    Train and evaluate LLL system such that we only see a example once

    Args:
        model: Model object; ``model.imp_method`` selects the training branch
            (VAN, FTR_EXT, EWC, PI, MAS, A-GEM, RWALK, ER-Reservoir, MER,
            ER-Ring or ER-Hindsight-Anchors).
        sess: Session-like object used to run the model ops.
        args: Namespace providing ``batch_size``, ``num_runs``,
            ``random_seed``, ``mem_size``, ``examples_per_task``,
            ``train_single_epoch``, ``train_iters``, ``cross_validate_mode``,
            ``online_cross_val``, ``learning_rate`` and ``eps_mem_batch``.

    Returns:
        numpy.ndarray built from the per-run arrays of per-task evaluations.
    """
    # List to store accuracy for each run
    runs = []
    batch_size = args.batch_size

    # Methods that keep an episodic memory of past examples
    use_episodic_memory = (model.imp_method in {'A-GEM', 'MER'}
                           or 'ER-' in model.imp_method)

    # Loop over number of runs to average over
    for runid in range(args.num_runs):
        print('\t\tRun %d:' % (runid))
        # Initialize the random seeds
        np.random.seed(args.random_seed + runid)

        time_start = time.time()
        # Load the permute mnist dataset
        datasets = construct_permute_mnist(model.num_tasks)
        print('Data loading time: {}'.format(time.time() - time_start))

        episodic_mem_size = args.mem_size * model.num_tasks * TOTAL_CLASSES

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if use_episodic_memory:
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
            episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
            count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
            episodic_filled_counter = 0
            examples_seen_so_far = 0
            if model.imp_method == 'ER-Hindsight-Anchors':
                avg_img_vectors = np.zeros([TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                anchor_images = np.zeros(
                    [model.num_tasks * TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                anchor_labels = np.zeros(
                    [model.num_tasks * TOTAL_CLASSES, TOTAL_CLASSES])
                anchor_count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)

        # Mask for softmax
        # Since all the classes are present in all the tasks so nothing to mask
        logit_mask = np.ones(TOTAL_CLASSES)

        if COUNT_VIOLATIONS:
            violation_count = np.zeros(model.num_tasks)
            vc = 0

        # Training loop for all the tasks
        for task in range(len(datasets)):
            print('\t\tTask %d:' % (task))

            anchors_counter = task * TOTAL_CLASSES  # 1 per class per task

            # If not the first task then restore weights from previous task
            if (task > 0):
                model.restore(sess)

            if MULTI_TASK:
                if task == 0:
                    # Extract training images and labels for the current task
                    task_train_images = datasets[task]['train']['images']
                    task_train_labels = datasets[task]['train']['labels']
                    sample_weights = np.ones([task_train_labels.shape[0]],
                                             dtype=np.float32)
                    total_train_examples = task_train_images.shape[0]
                    # Randomly shuffle the training examples
                    perm = np.arange(total_train_examples)
                    np.random.shuffle(perm)
                    train_x = task_train_images[perm][:args.examples_per_task]
                    train_y = task_train_labels[perm][:args.examples_per_task]
                    task_sample_weights = sample_weights[
                        perm][:args.examples_per_task]
                    for t_ in range(1, len(datasets)):
                        task_train_images = datasets[t_]['train']['images']
                        task_train_labels = datasets[t_]['train']['labels']
                        sample_weights = np.ones([task_train_labels.shape[0]],
                                                 dtype=np.float32)
                        total_train_examples = task_train_images.shape[0]
                        # Randomly shuffle the training examples
                        perm = np.arange(total_train_examples)
                        np.random.shuffle(perm)
                        train_x = np.concatenate(
                            (train_x,
                             task_train_images[perm][:args.examples_per_task]),
                            axis=0)
                        train_y = np.concatenate(
                            (train_y,
                             task_train_labels[perm][:args.examples_per_task]),
                            axis=0)
                        task_sample_weights = np.concatenate(
                            (task_sample_weights,
                             sample_weights[perm][:args.examples_per_task]),
                            axis=0)

                    # Shuffle the pooled examples of all the tasks
                    perm = np.arange(train_x.shape[0])
                    np.random.shuffle(perm)
                    train_x = train_x[perm]
                    train_y = train_y[perm]
                    task_sample_weights = task_sample_weights[perm]
                else:
                    # Skip training for this task
                    continue

            else:
                # Extract training images and labels for the current task
                task_train_images = datasets[task]['train']['images']
                task_train_labels = datasets[task]['train']['labels']
                # Assign equal weights to all the examples
                task_sample_weights = np.ones([task_train_labels.shape[0]],
                                              dtype=np.float32)
                total_train_examples = task_train_images.shape[0]
                # Randomly shuffle the training examples
                perm = np.arange(total_train_examples)
                np.random.shuffle(perm)
                train_x = task_train_images[perm][:args.examples_per_task]
                train_y = task_train_labels[perm][:args.examples_per_task]
                task_sample_weights = task_sample_weights[
                    perm][:args.examples_per_task]

            print('Received {} images, {} labels at task {}'.format(
                train_x.shape[0], train_y.shape[0], task))

            # Array to store accuracies when training for task T
            ftask = []

            num_train_examples = train_x.shape[0]

            # Train a task observing sequence of data
            if args.train_single_epoch:
                # Round up so that the final partial batch is also visited
                num_iters = (num_train_examples + batch_size - 1) // batch_size
            else:
                num_iters = args.train_iters

            # Training loop for task T
            for iters in range(num_iters):

                if args.train_single_epoch and not args.cross_validate_mode:
                    if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
                        # Snapshot the current performance across all tasks after each mini-batch
                        fbatch = test_task_sequence(model, sess, datasets,
                                                    args.online_cross_val)
                        ftask.append(fbatch)

                offset = (iters * batch_size) % num_train_examples
                # Number of examples actually available for this batch
                residual = min(batch_size, num_train_examples - offset)
                batch_x = train_x[offset:offset + residual]
                batch_y = train_y[offset:offset + residual]

                feed_dict = {
                    model.x: train_x[offset:offset + batch_size],
                    model.y_: train_y[offset:offset + batch_size],
                    model.sample_weights: task_sample_weights[offset:offset + batch_size],
                    model.training_iters: num_iters,
                    model.train_step: iters,
                    model.keep_prob: 1.0,
                    model.output_mask: logit_mask,
                    model.learning_rate: args.learning_rate
                }

                if model.imp_method == 'VAN':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'FTR_EXT':
                    if task == 0:
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    else:
                        _, loss = sess.run(
                            [model.train_classifier, model.reg_loss],
                            feed_dict=feed_dict)

                elif model.imp_method == 'EWC':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)

                    _, _, loss = sess.run(
                        [model.set_tmp_fisher, model.train, model.reg_loss],
                        feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    _, _, _, loss = sess.run([
                        model.weights_old_ops_grouped, model.train,
                        model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
                    if task == 0:
                        # Normal application of gradients
                        _, loss = sess.run(
                            [model.train_first_task, model.agem_loss],
                            feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        mem_sample_mask = _sample_replay_indices(
                            episodic_filled_counter, args.eps_mem_batch)
                        # Store the reference gradient
                        sess.run(model.store_ref_grads, feed_dict={
                            model.x: episodic_images[mem_sample_mask],
                            model.y_: episodic_labels[mem_sample_mask],
                            model.keep_prob: 1.0,
                            model.output_mask: logit_mask,
                            model.learning_rate: args.learning_rate
                        })
                        if COUNT_VIOLATIONS:
                            vc, _, loss = sess.run([
                                model.violation_count,
                                model.train_subseq_tasks, model.agem_loss
                            ], feed_dict=feed_dict)
                        else:
                            # Compute the gradient for current task and project if need be
                            _, loss = sess.run(
                                [model.train_subseq_tasks, model.agem_loss],
                                feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer (for every task,
                    # so that the first task also populates the memory)
                    update_fifo_buffer(batch_x, batch_y, episodic_images,
                                       episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'RWALK':
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher],
                                 feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using distance in riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([
                        model.set_tmp_fisher, model.weights_old_ops_grouped,
                        model.train, model.update_small_omega, model.reg_loss
                    ], feed_dict=feed_dict)

                elif model.imp_method == 'ER-Reservoir':
                    mem_filled_so_far = min(examples_seen_so_far,
                                            episodic_mem_size)
                    er_mem_indices = _sample_replay_indices(
                        mem_filled_so_far, args.eps_mem_batch)
                    np.random.shuffle(er_mem_indices)
                    # Train on the replayed memory and the new batch jointly
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices], batch_x), axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices], batch_y), axis=0)
                    feed_dict = _build_train_feed(
                        model, er_train_x_batch, er_train_y_batch, num_iters,
                        iters, logit_mask, args.learning_rate)
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Reservoir update
                    examples_seen_so_far = update_reservior(
                        batch_x, batch_y, episodic_images, episodic_labels,
                        episodic_mem_size, examples_seen_so_far)

                elif model.imp_method == 'MER':
                    mem_filled_so_far = min(examples_seen_so_far,
                                            episodic_mem_size)
                    mer_mem_indices = _sample_replay_indices(
                        mem_filled_so_far, args.eps_mem_batch)
                    np.random.shuffle(mer_mem_indices)
                    mer_train_x_batch = episodic_images[mer_mem_indices]
                    mer_train_y_batch = episodic_labels[mer_mem_indices]
                    sess.run(model.store_theta_i_not_w)
                    # One training step per replayed memory example
                    for mer_x, mer_y in zip(mer_train_x_batch,
                                            mer_train_y_batch):
                        feed_dict = _build_train_feed(
                            model, np.expand_dims(mer_x, axis=0),
                            np.expand_dims(mer_y, axis=0), num_iters, iters,
                            logit_mask, args.learning_rate)
                        _, loss = sess.run([model.train, model.reg_loss],
                                           feed_dict=feed_dict)
                    # Then the new batch, with the learning rate scaled by MER_S
                    feed_dict = _build_train_feed(
                        model, batch_x, batch_y, num_iters, iters, logit_mask,
                        MER_S * args.learning_rate)
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)
                    sess.run(model.with_in_batch_reptile_update,
                             feed_dict={model.mer_beta: MER_BETA
                                        })  # In the paper this is 'mer_gamma'

                    # Reservoir update
                    examples_seen_so_far = update_reservior(
                        batch_x, batch_y, episodic_images, episodic_labels,
                        episodic_mem_size, examples_seen_so_far)

                elif model.imp_method == 'ER-Ring':
                    mem_filled_so_far = min(episodic_filled_counter,
                                            episodic_mem_size)
                    er_mem_indices = _sample_replay_indices(
                        mem_filled_so_far, args.eps_mem_batch)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices], batch_x), axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices], batch_y), axis=0)
                    feed_dict = _build_train_feed(
                        model, er_train_x_batch, er_train_y_batch, num_iters,
                        iters, logit_mask, args.learning_rate)
                    _, loss = sess.run([model.train, model.reg_loss],
                                       feed_dict=feed_dict)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(batch_x, batch_y, episodic_images,
                                       episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                elif model.imp_method == 'ER-Hindsight-Anchors':
                    # Anchors are drawn first, then the replay samples; the
                    # order matters for the random stream.
                    anchor_mem_indices = _sample_replay_indices(
                        anchors_counter, args.eps_mem_batch)
                    mem_filled_so_far = min(episodic_filled_counter,
                                            episodic_mem_size)
                    er_mem_indices = _sample_replay_indices(
                        mem_filled_so_far, args.eps_mem_batch)
                    er_train_x_batch = np.concatenate(
                        (episodic_images[er_mem_indices], batch_x), axis=0)
                    er_train_y_batch = np.concatenate(
                        (episodic_labels[er_mem_indices], batch_y), axis=0)
                    feed_dict = _build_train_feed(
                        model, er_train_x_batch, er_train_y_batch, num_iters,
                        iters, logit_mask, args.learning_rate)
                    feed_dict[model.phi_hat_alpha] = 0.5
                    if task == 0:
                        # Just train on the current task dataset
                        phi_hat, _, loss = sess.run(
                            [model.phi_hat, model.train, model.reg_loss],
                            feed_dict=feed_dict)
                        # Update the FIFO buffer
                        update_fifo_buffer(batch_x, batch_y, anchor_images,
                                           anchor_labels,
                                           np.arange(TOTAL_CLASSES), 1,
                                           anchor_count_cls, anchors_counter)
                    else:
                        # Train on the anchoring loss
                        feed_dict[model.anchor_points] = anchor_images[
                            anchor_mem_indices]
                        phi_hat, _, _, _, loss = sess.run(
                            [
                                model.phi_hat, model.train_anchor,
                                model.anchor_loss, model.task_loss,
                                model.final_anchor_loss
                            ],
                            feed_dict=feed_dict)

                    # NOTE(review): the nesting of the next two calls could
                    # not be recovered from the source formatting; they are
                    # assumed to run for every task - confirm.
                    # Update the average image vectors
                    update_avg_image_vectors(batch_x, batch_y,
                                             avg_img_vectors,
                                             running_alpha=0.9)

                    # Put the batch in the FIFO ring buffer
                    update_fifo_buffer(batch_x, batch_y, episodic_images,
                                       episodic_labels,
                                       np.arange(TOTAL_CLASSES), args.mem_size,
                                       count_cls, episodic_filled_counter)

                if (iters % 10 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs Nans!!!')
                    # NOTE(review): exits with status 0 although this is an
                    # error path - confirm callers do not rely on it.
                    sys.exit(0)

            print('\t\t\t\tTraining for Task%d done!' % (task))

            if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
                violation_count[task] = vc
                print('Task {}: Violation Count: {}'.format(
                    task, violation_count))
                sess.run(model.reset_violation_count, feed_dict=feed_dict)

            elif model.imp_method == 'ER-Hindsight-Anchors':
                # For the first task the anchors are already populated by the
                # FIFO buffer during training.
                if task > 0:
                    anchor_x = np.zeros([TOTAL_CLASSES, INPUT_FEATURE_SIZE])
                    anchor_y = np.zeros([TOTAL_CLASSES, TOTAL_CLASSES])
                    task_labels = np.arange(TOTAL_CLASSES)
                    anchor_x, anchor_y = er_mem_update_hindsight(
                        model, sess, anchor_x, anchor_y, episodic_images,
                        episodic_labels, episodic_filled_counter, task_labels,
                        logit_mask, phi_hat, avg_img_vectors, args)
                    anchor_images[anchors_counter:anchors_counter +
                                  TOTAL_CLASSES] = normalize_tensors(anchor_x)
                    anchor_labels[anchors_counter:anchors_counter +
                                  TOTAL_CLASSES] = anchor_y

                # Reset the average image vectors
                # NOTE(review): assumed to run after every task so that each
                # task starts from a clean running average - confirm.
                avg_img_vectors[:] = 0.0

            # Update the episodic memory filled counter
            if use_episodic_memory:
                episodic_filled_counter += args.mem_size * TOTAL_CLASSES
                # Column index of every non-zero label entry
                stored_label_ids = np.nonzero(episodic_labels)[1]
                print('Unique labels in the episodic memory: {}'.format(
                    np.unique(stored_label_ids)))
                print('Labels in the episodic memory: {}'.format(
                    stored_label_ids))

            # Compute the inter-task updates, Fisher/ importance scores etc
            # Don't calculate the task updates for the last task
            if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
                model.task_updates(sess, task, task_train_images,
                                   np.arange(TOTAL_CLASSES))
                print('\t\t\t\tTask updates after Task%d done!' % (task))

            if args.train_single_epoch and not args.cross_validate_mode:
                fbatch = test_task_sequence(model, sess, datasets, False)
                ftask.append(fbatch)
                ftask = np.array(ftask)
            else:
                if MEASURE_PERF_ON_EPS_MEMORY:
                    eps_mem = {
                        'images': episodic_images,
                        'labels': episodic_labels,
                    }
                    # Measure perf on episodic memory
                    ftask = test_task_sequence(model, sess, eps_mem,
                                               args.online_cross_val)
                else:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets,
                                               args.online_cross_val)
                    print('Task: {}, Acc: {}'.format(task, ftask))

            # Store the accuracies computed at task T in a list
            evals.append(ftask)

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        runs.append(np.array(evals))
        # End for loop runid

    runs = np.array(runs)

    return runs