def run_meta_iteration(i_iter):
        """Run one meta-training iteration on a freshly sampled meta-batch.

        A meta-batch of ``meta_batch_size`` training tasks is drawn with
        ``task_generator.create_meta_batch`` (``meta_split='meta_train'``),
        the objective over those tasks is obtained from ``meta_step``, and a
        single gradient step on the meta-parameters (theta) is taken through
        ``grad_step``.

        Args:
            i_iter: Index of the current meta-iteration. Forwarded to
                ``grad_step`` and used to decide when to print a status line
                (every 5 iterations).

        Note:
            This is a closure: it reads ``prm``, ``meta_batch_size``,
            ``task_generator``, ``model``, ``loss_criterion``,
            ``meta_optimizer``, ``lr_schedule``, ``n_iterations``,
            ``meta_step``, ``grad_step`` and ``cmn`` from the enclosing scope.
        """
        # Generate the data sets of the training-tasks for meta-batch:
        mb_data_loaders = task_generator.create_meta_batch(
            prm, meta_batch_size, meta_split='meta_train')

        # For each task, prepare an iterator to generate training batches:
        mb_iterators = [
            iter(mb_data_loaders[ii]['train']) for ii in range(meta_batch_size)
        ]

        # Get objective based on tasks in meta-batch:
        total_objective, info = meta_step(prm, model, mb_data_loaders,
                                          mb_iterators, loss_criterion)

        # Take gradient step with the meta-parameters (theta) based on validation data:
        grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr, i_iter)

        # Print status:
        log_interval = 5
        if i_iter % log_interval == 0:
            batch_acc = info['correct_count'] / info['sample_count']
            # `.item()` converts a one-element tensor to a Python number;
            # `.data[0]` raises IndexError on 0-dim tensors (PyTorch >= 0.4).
            print(
                cmn.status_string(i_iter, n_iterations, 1, 1, batch_acc,
                                  total_objective.item()))
    def run_train_epoch(i_epoch):
        """Run one epoch of meta-training over the fixed set of training tasks.

        The task order for the epoch is ``n_batches_per_task`` independently
        shuffled passes over all ``n_tasks`` task ids, concatenated. That
        order is cut into consecutive meta-batches of up to
        ``prm.meta_batch_size`` tasks, and one gradient step on the
        meta-parameters (theta) is taken per meta-batch.

        Args:
            i_epoch: Index of the current epoch. Forwarded to ``grad_step``
                and to the status printout.

        Note:
            This is a closure: it reads ``prm``, ``model``,
            ``train_data_loaders``, ``n_tasks``, ``n_batches_per_task``,
            ``loss_criterion``, ``meta_optimizer``, ``lr_schedule``,
            ``num_epochs``, ``random``, ``meta_step``, ``grad_step`` and
            ``cmn`` from the enclosing/module scope.
        """
        # One iterator per task, shared by every meta-batch in this epoch, so
        # repeated visits to a task continue from where the last one stopped.
        train_iterators = [
            iter(train_data_loaders[ii]['train']) for ii in range(n_tasks)
        ]

        # The task order to take batches from: each pass re-shuffles the
        # same id list in place and appends a snapshot of it.
        task_order = []
        task_ids_list = list(range(n_tasks))
        for _ in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list

        # Each meta-batch includes several tasks; we take a grad step with
        # theta after each meta-batch. Read the size once so the start
        # offsets and the slice width always agree.
        mb_size = prm.meta_batch_size
        meta_batch_starts = list(range(0, len(task_order), mb_size))
        n_meta_batches = len(meta_batch_starts)

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        for i_meta_batch, meta_batch_start in enumerate(meta_batch_starts):

            # The last meta-batch may hold fewer than `mb_size` tasks.
            # It is OK if some task appears several times in the meta-batch.
            task_ids_in_meta_batch = task_order[
                meta_batch_start:meta_batch_start + mb_size]

            mb_data_loaders = [
                train_data_loaders[task_id]
                for task_id in task_ids_in_meta_batch
            ]
            mb_iterators = [
                train_iterators[task_id] for task_id in task_ids_in_meta_batch
            ]

            # Get objective based on tasks in meta-batch:
            total_objective, info = meta_step(prm, model, mb_data_loaders,
                                              mb_iterators, loss_criterion)

            # Take gradient step with the meta-parameters (theta) based on validation data:
            grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr,
                      i_epoch)

            # Print status:
            log_interval = 200
            if i_meta_batch % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, num_epochs, i_meta_batch,
                                      n_meta_batches, batch_acc,
                                      total_objective.item()))