Example 1
def meta_step(prm, model, mb_data_loaders, mb_iterators, loss_criterion):

    total_objective = 0
    correct_count = 0
    sample_count = 0

    n_tasks_in_mb = len(mb_data_loaders)

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        fast_weights = OrderedDict(
            (name, param) for (name, param) in model.named_parameters())

        # ----------- gradient steps loop -----------------------------------#
        for i_step in range(prm.n_meta_train_grad_steps):

            # get batch variables:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Debug
            # print(targets[0].data[0])  # print first image label
            # import matplotlib.pyplot as plt
            # plt.imshow(inputs[0].cpu().data[0].numpy())  # show first image
            # plt.show()

            if i_step == 0:
                outputs = model(inputs)
            else:
                outputs = model(inputs, fast_weights)
            # Empirical Loss on current task:
            task_loss = loss_criterion(outputs, targets)
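            # create_graph=True below keeps the graph of this inner gradient step,
            # so the meta-objective can later backpropagate through the adapted weights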
            grads = torch.autograd.grad(task_loss,
                                        fast_weights.values(),
                                        create_graph=True)

            fast_weights = OrderedDict(
                (name, param - prm.alpha * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads))
        # end grad steps loop

        # Sample a new (validation) data batch for this task:
        if hasattr(prm, 'MAML_Use_Test_Data') and prm.MAML_Use_Test_Data:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['test'])
        else:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        outputs = model(inputs, fast_weights)
        total_objective += loss_criterion(outputs, targets)
        correct_count += count_correct(outputs, targets)
        sample_count += inputs.size(0)
    # end loop over tasks in meta-batch

    info = {'sample_count': sample_count, 'correct_count': correct_count}
    return total_objective, info
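
A minimal sketch (assumed, not part of the source) of how the value returned by
meta_step could drive the outer meta-update; `meta_optimizer` is a hypothetical
optimizer over model.parameters():

meta_optimizer.zero_grad()
total_objective, info = meta_step(prm, model, mb_data_loaders, mb_iterators, loss_criterion)
total_objective.backward()  # gradients flow back through the inner steps (create_graph=True)
meta_optimizer.step()
accuracy = info['correct_count'] / info['sample_count']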

Example 2
def get_risk(prior_model, prm, mb_data_loaders, mb_iterators,
             loss_criterion, n_train_tasks):
    ''' Calculate the average empirical loss (risk) over the tasks in the meta-batch '''
    # note: it is OK if some tasks appear several times in the meta-batch

    n_tasks_in_mb = len(mb_data_loaders)

    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # The posterior model corresponding to the task in the batch:
        # post_model = mb_posteriors_models[i_task]
        prior_model.train()

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        # task_complexity = 0
        # ----------- Monte-Carlo loop  -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Empirical Loss on current task:
            outputs = prior_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
        # end Monte-Carlo loop

        sum_empirical_loss += task_empirical_loss

    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss

    return avg_empirical_loss
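
Spelled out, get_risk returns only the doubly averaged empirical loss, with no complexity
terms (T = n_tasks_in_mb, K = prm.n_MC; the Monte-Carlo loop is meaningful when the prior
model's forward pass is stochastic, since the same batch is re-used in every pass):

$$\hat{L} \;=\; \frac{1}{T}\sum_{t=1}^{T} \frac{1}{K}\sum_{k=1}^{K} \ell\big(f_{\text{prior}}^{(k)}(x_t),\, y_t\big)$$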
Example 3
    def run_meta_test_learning(task_model, train_loader):
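        # Note: this nested function relies on names defined outside it
        # (`prm`, `loss_criterion`, `task_optimizer`, `grad_step`, `data_gen`).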

        task_model.train()
        train_loader_iter = iter(train_loader)

        # Gradient steps (training) loop
        for i_grad_step in range(prm.n_meta_test_grad_steps):
            # get batch:
            batch_data = data_gen.get_next_batch_cyclic(
                train_loader_iter, train_loader)
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Calculate empirical loss:
            outputs = task_model(inputs)
            task_objective = loss_criterion(outputs, targets)

            # Take gradient step with the task weights:
            grad_step(task_objective, task_optimizer)

        # end gradient step loop

        return task_model
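
A possible follow-up (assumed, not from the source): after the adaptation steps above, the
returned model would typically be evaluated on the task's held-out split, sketched here with
a hypothetical `test_loader`:

task_model.eval()
n_correct, n_total = 0, 0
with torch.no_grad():
    for batch_data in test_loader:  # `test_loader` is assumed to exist
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        outputs = task_model(inputs)
        n_correct += count_correct(outputs, targets)
        n_total += inputs.size(0)
test_acc = n_correct / n_total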

Example 4
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    '''  Calculate objective based on tasks in meta-batch '''
    # note: it is OK if some tasks appear several times in the meta-batch

    n_tasks_in_mb = len(mb_data_loaders)

    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0

    # KLD between hyper-posterior and hyper-prior:
    hyper_kl = (1 / (2 * prm.kappa_prior**2)) * net_norm(
        prior_model, p=2)  # net_norm is the L2-regularization term

    # Hyper-prior term:
    meta_complex_term = get_meta_complexity_term(hyper_kl, prm, n_train_tasks)
    sum_w_kld = 0.0
    sum_b_kld = 0.0

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        # ----------- Monte-Carlo loop  -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Debug
            # print(targets[0].data[0])  # print first image label
            # import matplotlib.pyplot as plt
            # plt.imshow(inputs[0].cpu().data[0].numpy())  # show first image
            # plt.show()

            # Empirical Loss on current task:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

            # Intra-task complexity of current task:
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_samples,
                curr_empirical_loss,
                hyper_kl,
                n_train_tasks=n_train_tasks)

            sum_w_kld += task_info["w_kld"]
            sum_b_kld += task_info["b_kld"]
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop

        sum_empirical_loss += task_empirical_loss
        sum_intra_task_comp += task_complexity

    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    avg_intra_task_comp = (1 / n_tasks_in_mb) * sum_intra_task_comp
    avg_w_kld = (1 / n_tasks_in_mb) * sum_w_kld
    avg_b_kld = (1 / n_tasks_in_mb) * sum_b_kld

    # Approximated total objective:
    total_objective = avg_empirical_loss + prm.task_complex_w * avg_intra_task_comp + prm.meta_complex_w * meta_complex_term

    info = {
        'sample_count': get_value(sample_count),
        'correct_count': get_value(correct_count),
        'avg_empirical_loss': get_value(avg_empirical_loss),
        'avg_intra_task_comp': get_value(avg_intra_task_comp),
        'meta_comp': get_value(meta_complex_term),
        'w_kld': avg_w_kld,
        'b_kld': avg_b_kld
    }
    return total_objective, info
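
In formula form, the objective assembled above is (with $T$ = n_tasks_in_mb, $\hat{L}_t$ and
$C_t$ the Monte-Carlo averaged empirical loss and intra-task complexity of task $t$, and
$C_{\text{meta}}$ the hyper-prior term):

$$\text{objective} \;=\; \frac{1}{T}\sum_{t=1}^{T}\hat{L}_t \;+\; w_{\text{task}}\,\frac{1}{T}\sum_{t=1}^{T}C_t \;+\; w_{\text{meta}}\,C_{\text{meta}}$$

where $w_{\text{task}}$ = prm.task_complex_w and $w_{\text{meta}}$ = prm.meta_complex_w.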
Example 5
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    '''  Calculate objective based on tasks in meta-batch '''
    # note: it is OK if some tasks appear several times in the meta-batch

    n_tasks_in_mb = len(mb_data_loaders)

    correct_count = 0
    sample_count = 0

    # Hyper-prior term:
    hyper_dvrg = get_hyper_divergnce(prm, prior_model)
    meta_complex_term = get_meta_complexity_term(hyper_dvrg, prm,
                                                 n_train_tasks)

    avg_empiric_loss_per_task = torch.zeros(n_tasks_in_mb, device=prm.device)
    complexity_per_task = torch.zeros(n_tasks_in_mb, device=prm.device)
    n_samples_per_task = torch.zeros(
        n_tasks_in_mb, device=prm.device
    )  # total number of samples in each task (not just in the current batch)

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']
        n_samples_per_task[i_task] = n_samples

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # get batch variables:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        # Monte-Carlo iterations:
        n_MC = prm.n_MC

        avg_empiric_loss = 0.0
        complexity = 0.0

        # Monte-Carlo loop
        for i_MC in range(n_MC):

            # Debug
            # print(targets[0].data[0])  # print first image label
            # import matplotlib.pyplot as plt
            # plt.imshow(inputs[0].cpu().data[0].numpy())  # show first image
            # plt.show()

            # Empirical Loss on current task:
            outputs = post_model(inputs)
            avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                outputs, targets)

            correct_count += count_correct(outputs, targets)  # for print
            sample_count += inputs.size(0)

            # Intra-task complexity of current task:
            # curr_complexity = get_task_complexity(prm, prior_model, post_model,
            #     n_samples, avg_empiric_loss_curr, hyper_dvrg, n_train_tasks=n_train_tasks, noised_prior=True)

            avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
            # complexity +=  (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop

        complexity = get_task_complexity(prm,
                                         prior_model,
                                         post_model,
                                         n_samples,
                                         avg_empiric_loss,
                                         hyper_dvrg,
                                         n_train_tasks=n_train_tasks,
                                         noised_prior=True)
        avg_empiric_loss_per_task[i_task] = avg_empiric_loss
        complexity_per_task[i_task] = complexity
    # end loop over tasks in meta-batch

    # Approximated total objective:
    if prm.complexity_type == 'Variational_Bayes':
        # note that avg_empiric_loss_per_task is estimated by an average over batch samples,
        #  but its weight in the objective should be considered by how many samples there are total in the task
        total_objective =\
            (avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() * n_train_tasks + meta_complex_term
        # total_objective = ( avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() + meta_complex_term

    else:
        total_objective =\
            avg_empiric_loss_per_task.mean() + complexity_per_task.mean() + meta_complex_term

    info = {
        'sample_count': sample_count,
        'correct_count': correct_count,
        'avg_empirical_loss': avg_empiric_loss_per_task.mean().item(),
        'avg_intra_task_comp': complexity_per_task.mean().item(),
        'meta_comp': meta_complex_term.item()
    }
    return total_objective, info
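
For the 'Variational_Bayes' branch the total objective computed above corresponds to (with
$T$ = n_tasks_in_mb, $N$ = n_train_tasks, $m_t$ = n_samples_per_task[t], $\hat{L}_t$ the
Monte-Carlo averaged per-sample empirical loss and $C_t$ the task complexity):

$$\text{objective} \;=\; N \cdot \frac{1}{T}\sum_{t=1}^{T}\big(m_t\,\hat{L}_t + C_t\big) \;+\; C_{\text{meta}}$$

so each task's batch-estimated loss is re-weighted by the total number of samples in that task,
as the in-code comment notes; in the other branch the per-task losses and complexities are simply
averaged and added to the meta-complexity term.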