def run_meta_iteration(i_iter):
    # In each meta-iteration we draw a meta-batch of several tasks,
    # then take a gradient step with the meta-parameters (theta).

    # Generate the data sets of the training tasks for the meta-batch:
    mb_data_loaders = task_generator.create_meta_batch(
        prm, meta_batch_size, meta_split='meta_train')

    # For each task, prepare an iterator to generate training batches:
    mb_iterators = [iter(mb_data_loaders[ii]['train'])
                    for ii in range(meta_batch_size)]

    # Get the objective based on the tasks in the meta-batch:
    total_objective, info = meta_step(prm, model, mb_data_loaders,
                                      mb_iterators, loss_criterion)

    # Take a gradient step with the meta-parameters (theta) based on validation data:
    grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr, i_iter)

    # Print status:
    log_interval = 5
    if i_iter % log_interval == 0:
        batch_acc = info['correct_count'] / info['sample_count']
        print(cmn.status_string(i_iter, n_iterations, 1, 1, batch_acc,
                                total_objective.item()))
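
# For reference, a minimal sketch of what the `grad_step` helper could look
# like, assuming a standard PyTorch optimizer. The schedule format used here
# (a dict with 'decay_epochs' and 'decay_factor') and the name
# `grad_step_sketch` are hypothetical illustrations, not necessarily the
# project's actual API.
def grad_step_sketch(objective, optimizer, lr_schedule, initial_lr, i_epoch):
    if lr_schedule:
        # Multiply the initial LR by the decay factor once per passed milestone:
        n_passed = sum(1 for e in lr_schedule['decay_epochs'] if i_epoch >= e)
        lr = initial_lr * (lr_schedule['decay_factor'] ** n_passed)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # Standard backward pass and parameter update:
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
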
def run_train_epoch(i_epoch):
    # For each task, prepare an iterator to generate training batches:
    train_iterators = [iter(train_data_loaders[ii]['train'])
                       for ii in range(n_tasks)]

    # The task order to take batches from:
    task_order = []
    task_ids_list = list(range(n_tasks))
    for i_batch in range(n_batches_per_task):
        random.shuffle(task_ids_list)
        task_order += task_ids_list

    # Each meta-batch includes several tasks;
    # we take a gradient step with theta after each meta-batch.
    meta_batch_starts = list(range(0, len(task_order), prm.meta_batch_size))
    n_meta_batches = len(meta_batch_starts)

    # ----------- meta-batches loop (batches of tasks) ----------------------#
    for i_meta_batch in range(n_meta_batches):

        meta_batch_start = meta_batch_starts[i_meta_batch]
        task_ids_in_meta_batch = task_order[
            meta_batch_start:(meta_batch_start + prm.meta_batch_size)]
        # The last meta-batch may contain fewer than prm.meta_batch_size tasks;
        # note: it is OK if some task appears several times in the meta-batch.
        n_tasks_in_batch = len(task_ids_in_meta_batch)

        mb_data_loaders = [train_data_loaders[task_id]
                           for task_id in task_ids_in_meta_batch]
        mb_iterators = [train_iterators[task_id]
                        for task_id in task_ids_in_meta_batch]

        # Get the objective based on the tasks in the meta-batch:
        total_objective, info = meta_step(prm, model, mb_data_loaders,
                                          mb_iterators, loss_criterion)

        # Take a gradient step with the meta-parameters (theta) based on validation data:
        grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr, i_epoch)

        # Print status:
        log_interval = 200
        if i_meta_batch % log_interval == 0:
            batch_acc = info['correct_count'] / info['sample_count']
            print(cmn.status_string(i_epoch, num_epochs, i_meta_batch,
                                    n_meta_batches, batch_acc,
                                    total_objective.item()))
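
# For reference, a minimal sketch of what `meta_step` could compute, assuming
# the meta-objective is simply the average per-task loss on one training batch
# drawn from each task in the meta-batch. The actual objective may add further
# terms (e.g. a complexity/regularization term over the meta-parameters);
# `meta_step_sketch` and the restart-on-exhaustion logic are illustrative.
def meta_step_sketch(prm, model, mb_data_loaders, mb_iterators, loss_criterion):
    n_tasks_in_batch = len(mb_data_loaders)
    total_objective = 0
    info = {'correct_count': 0, 'sample_count': 0}
    for i_task in range(n_tasks_in_batch):
        # Draw the next training batch of this task; restart its iterator if exhausted:
        try:
            inputs, targets = next(mb_iterators[i_task])
        except StopIteration:
            mb_iterators[i_task] = iter(mb_data_loaders[i_task]['train'])
            inputs, targets = next(mb_iterators[i_task])
        outputs = model(inputs)
        total_objective = total_objective + loss_criterion(outputs, targets)
        # Accumulate accuracy statistics for the status print:
        info['correct_count'] += (outputs.argmax(dim=1) == targets).sum().item()
        info['sample_count'] += targets.size(0)
    # Average the per-task losses so the objective scale is batch-size invariant:
    total_objective = total_objective / n_tasks_in_batch
    return total_objective, info
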