Example #1
def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion
    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of many ensembling networks to perform gradient descent for the meta update
        num_meta_updates_count = 0

        meta_loss_avg_print = 0  # running loss average used for printing

        meta_loss_avg_save = []  # meta loss to save

        task_count = 0  # counter to decide when a minibatch of tasks is complete and a meta update should be performed

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print /
                                                  num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count,
                                                       meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))

                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
        print()
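Example #1 (like most of the listings on this page) reads its configuration from module-level globals rather than arguments and delegates the per-task work to helpers such as get_task_prediction. The sketch below shows the kind of setup the function assumes before it is called; the variable names follow the listing, but every value here is an illustrative placeholder, not the original repository's defaults.

# Illustrative module-level setup assumed by meta_train() in Example #1.
# The names mirror the listing above; all values are placeholders.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

datasource = 'sine_line'             # or an image dataset handled by the else-branch
num_classes_per_task = 5             # N-way
num_training_samples_per_class = 5   # k-shot (support samples per class)
num_total_samples_per_class = 15     # support + query samples per class
num_tasks_per_epoch = 10000
num_tasks_per_minibatch = 25         # tasks averaged into one meta update
num_tasks_save_loss = 1000
num_meta_updates_print = 1
num_epochs, resume_epoch, num_epochs_save = 10, 0, 1
p_sine = 0.5                         # probability of drawing a sine task instead of a line task
dst_folder = './checkpoints'

# meta-parameters stored as a dict of tensors and optimized by a single Adam instance
theta = {'w': torch.randn(40, 1, device=device, requires_grad=True)}
op_theta = torch.optim.Adam(params=theta.values(), lr=1e-3)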
Example #2
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device
        )

        # create dummy sampler
        all_class = [0]*100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(all_class.keys()), replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = [] # meta loss to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0 # counter to decide when a minibatch of tasks is complete and a meta update should be performed
        meta_loss = 0 # accumulate the loss of many ensembling networks to perform gradient descent for the meta update
        num_meta_updates_count = 0

        meta_loss_avg_print = 0 # running loss average used for printing

        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = [] # meta loss to save
        kl_loss_avg_save = []

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device
                    )
                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_base_dropout
                )

                meta_loss = meta_loss + loss_NLL

                kl_loss = kl_loss + KL_loss
                task_count = task_count + 1

                if torch.isnan(meta_loss).item():
                    sys.exit('nan')

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss/num_tasks_per_minibatch
                    kl_loss = kl_loss/num_tasks_per_minibatch

                    # accumulate for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()
                    
                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print/num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0

                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0
                        
                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))

                        meta_loss_avg_save = []
                        kl_loss_avg_save = []
                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1)% num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
            # scheduler.step()
        print()
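A detail worth calling out in Examples #1, #2, and #4: for the sine_line data source the DataLoader is built over a dummy list of zeros purely to drive the inner for-loop, and the batches it yields are never used; the actual regression task comes from get_task_sine_line_data. Below is a standalone sketch of that dummy-sampler pattern, self-contained and with placeholder sizes.

# Standalone sketch of the "dummy sampler" pattern from the sine_line branch.
# The loader only controls how many iterations the task loop makes per pass;
# its batch contents (tensors of zeros) are ignored for regression tasks.
import torch

num_classes_per_task = 5
all_class = [0] * 100  # placeholder "dataset" of 100 dummy entries
sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
train_loader = torch.utils.data.DataLoader(
    dataset=all_class,
    batch_size=num_classes_per_task,
    sampler=sampler,
    drop_last=True)

for class_labels in train_loader:
    print(class_labels)  # tensor([0, 0, 0, 0, 0]) -- unused placeholder labels
    break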
Example #3
def meta_train(params, amine=None):
    # Start by unpacking the variables that we need
    datasource = params['datasource']
    num_total_samples_per_class = params['num_total_samples_per_class']
    device = params['device']
    num_classes_per_task = params['num_classes_per_task']
    num_training_samples_per_class = params['num_training_samples_per_class']
    num_tasks_save_loss = params['num_tasks_save_loss']

    # Epoch variables
    num_epochs = params['num_epochs']
    resume_epoch = params['resume_epoch']
    num_tasks_per_epoch = params['num_tasks_per_epoch']

    # Note we have lowercase theta here vs with PLATIPUS
    theta = params['theta']
    op_theta = params['op_theta']

    # How often should we do a printout?
    num_meta_updates_print = 1
    # How often should we save?
    num_epochs_save = 1000

    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        print(f"Starting epoch {epoch}")

        if datasource == 'drp_chem':
            training_batches = params['training_batches']
            if params['cross_validate']:
                b_num = np.random.choice(len(training_batches[amine]))
                batch = training_batches[amine][b_num]
            else:
                b_num = np.random.choice(len(training_batches))
                batch = training_batches[b_num]
            x_train = torch.from_numpy(batch[0]).float().to(params['device'])
            y_train = torch.from_numpy(batch[1]).long().to(params['device'])
            x_val = torch.from_numpy(batch[2]).float().to(params['device'])
            y_val = torch.from_numpy(batch[3]).long().to(params['device'])

        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        task_count = 0  # counter to decide when a minibatch of tasks is complete and a meta update should be performed
        meta_loss = 0  # accumulate the loss of many ensembling networks to perform gradient descent for the meta update
        num_meta_updates_count = 0

        meta_loss_avg_print = 0  # running loss average used for printing

        meta_loss_avg_save = []  # meta loss to save

        while (task_count < num_tasks_per_epoch):
            if datasource == 'sine_line':
                p_sine = params['p_sine']
                x_t, y_t, x_v, y_v = get_task_sine_line_data(
                    data_generator=data_generator,
                    p_sine=p_sine,
                    num_training_samples=num_training_samples_per_class,
                    noise_flag=True)
            elif datasource == 'drp_chem':
                x_t, y_t, x_v, y_v = x_train[task_count], y_train[
                    task_count], x_val[task_count], y_val[task_count]
            else:
                sys.exit('Unknown dataset')

            loss_NLL = get_task_prediction(x_t, y_t, x_v, params, y_v)

            if torch.isnan(loss_NLL).item():
                sys.exit('NaN error')

            # accumulate meta loss
            meta_loss = meta_loss + loss_NLL

            task_count = task_count + 1

            if task_count % num_tasks_per_epoch == 0:
                meta_loss = meta_loss / num_tasks_per_epoch

                # accumulate into different variables for printing purposes
                meta_loss_avg_print += meta_loss.item()

                op_theta.zero_grad()
                meta_loss.backward()

                # Clip gradients to prevent exploding gradient problem
                torch.nn.utils.clip_grad_norm_(parameters=theta.values(),
                                               max_norm=3)

                op_theta.step()

                # Printing losses
                num_meta_updates_count += 1
                if (num_meta_updates_count % num_meta_updates_print == 0):
                    meta_loss_avg_save.append(meta_loss_avg_print /
                                              num_meta_updates_count)
                    print('{0:d}, {1:2.4f}'.format(task_count,
                                                   meta_loss_avg_save[-1]))

                    num_meta_updates_count = 0
                    meta_loss_avg_print = 0

                if (task_count % num_tasks_save_loss == 0):
                    meta_loss_saved.append(np.mean(meta_loss_avg_save))

                    meta_loss_avg_save = []

                    # print('Saving loss...')
                    # if datasource != 'sine_line':
                    #     val_accs, _ = meta_validation(
                    #         datasubset=val_set,
                    #         num_val_tasks=num_val_tasks,
                    #         return_uncertainty=False)
                    #     val_acc = np.mean(val_accs)
                    #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                    #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                    #     val_accuracies.append(val_acc)

                    #     train_accs, _ = meta_validation(
                    #         datasubset=train_set,
                    #         num_val_tasks=num_val_tasks,
                    #         return_uncertainty=False)
                    #     train_acc = np.mean(train_accs)
                    #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                    #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                    #     train_accuracies.append(train_acc)

                # Reset meta loss
                meta_loss = 0

            if (task_count >= num_tasks_per_epoch):
                break

        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            dst_folder = params['dst_folder']
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
        print()
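Example #3 passes all configuration through a params dictionary instead of globals. Below is a hedged usage sketch showing how such a dictionary might be assembled and the function called for the sine_line branch; the key names follow the listing, the values are placeholders, and get_task_prediction may read additional keys not shown here. The drp_chem branch would additionally need 'training_batches', 'cross_validate', and an amine argument.

# Illustrative call site for meta_train(params) in Example #3; values are placeholders.
import torch

device = torch.device('cpu')
theta = {'w1': torch.randn(40, 1, device=device, requires_grad=True)}

params = {
    'datasource': 'sine_line',
    'num_total_samples_per_class': 15,
    'device': device,
    'num_classes_per_task': 5,
    'num_training_samples_per_class': 5,
    'num_tasks_save_loss': 1000,
    'num_epochs': 1,
    'resume_epoch': 0,
    'num_tasks_per_epoch': 100,
    'p_sine': 0.5,
    'theta': theta,
    'op_theta': torch.optim.Adam(params=theta.values(), lr=1e-3),
    'dst_folder': './checkpoints',
}

meta_train(params)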
Example #4
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)

        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0  # counter to decide when a minibatch of tasks is complete and a meta update should be performed
        meta_loss = 0  # accumulate the loss of many ensembling networks to perform gradient descent for the meta update
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0

        kl_loss = 0
        kl_loss_avg_print = 0

        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)
                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)

                if torch.isnan(loss_NLL).item():
                    sys.exit('nan')

                loss_NLL_v += loss_NLL.item()

                if (loss_NLL.item() > 1):
                    loss_NLL.data = torch.tensor([1.], device=device)

                # if (discriminator_loss.item() > d_loss_const):
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss

                if (KL_loss.item() < 0):
                    KL_loss.data = torch.tensor([0.], device=device)

                Ri = torch.sqrt((KL_loss + ri_const) /
                                (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri

                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # adding R0
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const) /
                                    (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(),
                                                   max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(),
                        max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(),
                        max_norm=clip_grad_value)
                    op_discriminator.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        loss_NLL_avg_save.append(loss_NLL_avg_print /
                                                 num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print /
                                                num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print /
                                               num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count, loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1], d_loss_avg_save[-1]))
                        num_meta_updates_count = 0

                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))

                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        # train_accs = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
            # scheduler.step()
        print()
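Example #4 adds two square-root regularizers to the meta objective: a per-task term Ri built from the KL loss and a per-minibatch term R0 built from the L2 norm of the meta-parameters, both of the form sqrt((x + const) / (2 * (m - 1))). The snippet below isolates the R0 computation so it can be inspected on its own; the constants are illustrative assumptions, not values from the original code.

# Standalone sketch of the R0 regularizer from Example #4: the L2 norm of the
# meta-parameters wrapped in the square-root form used in the listing.
# L2_regularization, r0_const, and num_tasks_per_minibatch are placeholders.
import torch

theta = {
    'w1': torch.randn(40, 1, requires_grad=True),
    'b1': torch.zeros(40, requires_grad=True),
}
L2_regularization = 1e-2
r0_const = 1e-3
num_tasks_per_minibatch = 25

R0 = 0
for key in theta.keys():
    R0 += theta[key].norm(2)
R0 = torch.sqrt((L2_regularization * R0 + r0_const)
                / (2 * (num_tasks_per_minibatch - 1)))
print(R0.item())  # scalar added to meta_loss before the backward pass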
Example #5
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_samples_per_class)
        # create dummy sampler
        all_class_train = [0] * 10
    else:
        all_class_train, all_data_train = load_dataset(
            dataset_name=datasource,
            subset=train_set
        )
        all_class_val, all_data_val = load_dataset(
            dataset_name=datasource,
            subset=val_set
        )
        all_class_train.update(all_class_val)
        all_data_train.update(all_data_val)
        
    # initialize data loader
    train_loader = initialize_dataloader(
        all_classes=[class_label for class_label in all_class_train],
        num_classes_per_task=num_classes_per_task
    )

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = [] # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0 # accumulate the loss of many ensembling networks to perform gradient descent for the meta update
        num_meta_updates_count = 0

        meta_loss_avg_print = 0 # running loss average used for printing

        meta_loss_avg_save = [] # meta loss to save

        task_count = 0 # counter to decide when a minibatch of tasks is complete and a meta update should be performed
        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                    x_t = torch.tensor(x_t, dtype=torch.float, device=device)
                    y_t = torch.tensor(y_t, dtype=torch.float, device=device)
                    x_v = torch.tensor(x_v, dtype=torch.float, device=device)
                    y_v = torch.tensor(y_v, dtype=torch.float, device=device)
                else:
                    x_t, y_t, x_v, y_v = get_train_val_task_data(
                        all_classes=all_class_train,
                        all_data=all_data_train,
                        class_labels=class_labels,
                        num_samples_per_class=num_samples_per_class,
                        num_training_samples_per_class=num_training_samples_per_class,
                        device=device
                    )

                w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
                y_pred = net.forward(x=x_v, w=w_task)
                
                loss_NLL = loss_fn(input=y_pred, target=y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss/num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=10)
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1]
                        ))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                    
                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))

                        meta_loss_avg_save = []

                        print('Saving loss...')
                        if datasource != 'sine_line':
                            val_accs = validate_classification(
                                all_classes=all_class_val,
                                all_data=all_data_val,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            val_acc = np.mean(val_accs)
                            val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                            print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            val_accuracies.append(val_acc)

                            train_accs = validate_classification(
                                all_classes=all_class_train,
                                all_data=all_data_train,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            train_acc = np.mean(train_accs)
                            train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                            print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            train_accuracies.append(train_acc)
                    
                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1)% num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        scheduler.step()
        print()
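Example #5 factors the inner loop out into adapt_to_task, which is not shown on this page. The sketch below is a minimal MAML-style inner loop that matches how the listing calls it (x, y, w0), with the forward pass and loss passed in explicitly; net_forward, loss_fn, inner_lr, and num_inner_updates are assumptions for illustration, not the original implementation.

# Minimal sketch of an inner-loop adaptation step in the style of Example #5.
# `net_forward` and `loss_fn` stand in for the module's net.forward and loss_fn;
# inner_lr and num_inner_updates are illustrative hyperparameters.
import torch

def adapt_to_task(x, y, w0, net_forward, loss_fn,
                  inner_lr=0.01, num_inner_updates=5):
    """Return task-specific weights after a few gradient steps from w0.

    The tensors in w0 must have requires_grad=True so gradients can flow
    back to the meta-parameters during the outer (meta) update.
    """
    w_task = {key: w0[key] for key in w0}
    for _ in range(num_inner_updates):
        y_pred = net_forward(x=x, w=w_task)
        loss = loss_fn(input=y_pred, target=y)
        grads = torch.autograd.grad(
            outputs=loss,
            inputs=list(w_task.values()),
            create_graph=True)  # keep the graph so the meta update can backprop through
        w_task = {key: w_task[key] - inner_lr * g
                  for key, g in zip(w_task.keys(), grads)}
    return w_task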