Example #1
0
    def run_train_epoch(i_epoch):
        """Train the posterior model for one epoch over ``train_loader``.

        For every mini-batch the empirical loss and the complexity term
        returned by ``get_bayes_task_objective`` are averaged over
        ``prm.n_MC`` Monte-Carlo forward passes, added together to form the
        total objective, and handed to ``grad_step``.  Every ``log_interval``
        batches a status line is printed and the current values are appended
        to the enclosing scope's ``data_objective``, ``data_accuracy``,
        ``data_emp_loss`` and ``data_task_comp`` lists.

        Args:
            i_epoch: Index of the current epoch; forwarded to ``grad_step``
                and to the status line.

        Returns:
            ``total_objective.item()`` of the last processed batch, or
            ``None`` when ``train_loader`` yields no batches.
        """
        log_interval = 500
        n_MC = prm.n_MC  # number of Monte-Carlo forward passes per batch

        post_model.train()

        # Stays None if the loader is empty, so the function then returns None.
        total_objective = None

        for batch_idx, batch_data in enumerate(train_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                # Running average over the Monte-Carlo passes:
                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:
            total_objective = task_empirical_loss + task_complexity

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        task_empirical_loss.item(), task_complexity.item()))

                data_objective.append(total_objective.item())
                data_accuracy.append(batch_acc)
                data_emp_loss.append(task_empirical_loss.item())
                data_task_comp.append(task_complexity.item())

        # Return only after every batch has been processed.  Returning from
        # inside the loop body would end the epoch after the first batch.
        if total_objective is None:
            return None
        return total_objective.item()
Example #2
0
    def run_train_epoch(i_epoch):
        """Train the posterior model for one epoch over ``train_loader``.

        For every mini-batch the empirical loss is averaged over ``prm.n_MC``
        Monte-Carlo forward passes.  When ``prior_model`` is truthy, the loss
        is passed through ``get_bayes_task_objective`` to obtain the
        complexity term; otherwise the complexity term is ``0.0``.  The sum
        of the two is handed to ``grad_step``.  Every ``log_interval``
        batches a status line is printed; the accuracy shown there is
        computed from the outputs of the last Monte-Carlo pass only.

        Args:
            i_epoch: Index of the current epoch; forwarded to ``grad_step``
                and to the status line.

        Returns:
            None.
        """

        # # Adjust randomness (eps_std)
        # if hasattr(prm, 'use_randomness_schedeule') and prm.use_randomness_schedeule:
        #     if i_epoch > prm.randomness_full_epoch:
        #         eps_std = 1.0
        #     elif i_epoch > prm.randomness_init_epoch:
        #         eps_std = (i_epoch - prm.randomness_init_epoch) / (prm.randomness_full_epoch - prm.randomness_init_epoch)
        #     else:
        #         eps_std = 0.0  #  turn off randomness
        #     post_model.set_eps_std(eps_std)

        # post_model.set_eps_std(0.00) # debug

        log_interval = 500
        n_MC = prm.n_MC  # number of Monte-Carlo forward passes per batch

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # Monte-Carlo iterations:
            empirical_loss = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # calculate objective:
                outputs = post_model(inputs)
                empirical_loss_c = loss_criterion(outputs, targets)
                empirical_loss += (1 / n_MC) * empirical_loss_c

            # complexity/prior term:
            if prior_model:
                empirical_loss, complexity_term = get_bayes_task_objective(
                    prm, prior_model, post_model, n_train_samples,
                    empirical_loss)
            else:
                complexity_term = 0.0

            # Total objective:
            objective = empirical_loss + complexity_term

            # Take gradient step:
            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                # Use get_value() for the objective as well, matching how the
                # loss and complexity are read below; indexing a 0-dim tensor
                # with .data[0] fails on current PyTorch versions.
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc,
                                      get_value(objective))
                    + ' Loss: {:.4}\t Comp.: {:.4}'.format(
                        get_value(empirical_loss), get_value(complexity_term)))
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    """Calculate the meta-training objective for the tasks in a meta-batch.

    It is OK if some tasks appear several times in the meta-batch.

    Args:
        prior_model: Prior model; its norm (``net_norm(prior_model, p=2)``)
            scaled by ``1 / (2 * prm.kappa_prior**2)`` gives the hyper-KL
            term.
        prm: Parameter object; ``kappa_prior``, ``n_MC``, ``task_complex_w``
            and ``meta_complex_w`` are read here.
        mb_data_loaders: Per-task mappings; the ``'n_train_samples'`` and
            ``'train'`` entries are read for each task.
        mb_iterators: Per-task iterators passed to
            ``data_gen.get_next_batch_cyclic``.
        mb_posteriors_models: Per-task posterior models, indexed like
            ``mb_data_loaders``.
        loss_criterion: Callable ``(outputs, targets) -> loss``.
        n_train_tasks: Number of training tasks; forwarded to the
            complexity-term helpers.

    Returns:
        A tuple ``(total_objective, info)`` where ``total_objective`` is
        ``avg_empirical_loss + prm.task_complex_w * avg_intra_task_comp
        + prm.meta_complex_w * meta_complex_term`` and ``info`` is a dict
        with the keys ``sample_count``, ``correct_count``,
        ``avg_empirical_loss``, ``avg_intra_task_comp``, ``meta_comp``,
        ``w_kld`` and ``b_kld``.
    """
    n_tasks_in_mb = len(mb_data_loaders)
    n_MC = prm.n_MC  # number of Monte-Carlo forward passes per task

    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0

    # KLD between hyper-posterior and hyper-prior:
    hyper_kl = (1 / (2 * prm.kappa_prior**2)) * net_norm(
        prior_model, p=2)  # net_norm is L2-regularization

    # Hyper-prior term:
    meta_complex_term = get_meta_complexity_term(hyper_kl, prm, n_train_tasks)
    sum_w_kld = 0.0
    sum_b_kld = 0.0

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        task_empirical_loss = 0
        task_complexity = 0
        # ----------- Monte-Carlo loop  -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Debug
            # print(targets[0].data[0])  # print first image label
            # import matplotlib.pyplot as plt
            # plt.imshow(inputs[0].cpu().data[0].numpy())  # show first image
            # plt.show()

            # Empirical Loss on current task:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

            # Intra-task complexity of current task:
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_samples,
                curr_empirical_loss,
                hyper_kl,
                n_train_tasks=n_train_tasks)

            # NOTE(review): unlike the loss/complexity below, the KLD sums are
            # accumulated over the Monte-Carlo passes without a 1/n_MC
            # weight - confirm whether that is intended.
            sum_w_kld += task_info["w_kld"]
            sum_b_kld += task_info["b_kld"]
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop

        sum_empirical_loss += task_empirical_loss
        sum_intra_task_comp += task_complexity

    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    avg_intra_task_comp = (1 / n_tasks_in_mb) * sum_intra_task_comp
    # Plain assignment: these names have no earlier value, so an augmented
    # assignment (+=) here would raise UnboundLocalError.
    avg_w_kld = (1 / n_tasks_in_mb) * sum_w_kld
    avg_b_kld = (1 / n_tasks_in_mb) * sum_b_kld

    # Approximated total objective:
    total_objective = (avg_empirical_loss
                       + prm.task_complex_w * avg_intra_task_comp
                       + prm.meta_complex_w * meta_complex_term)

    info = {
        'sample_count': get_value(sample_count),
        'correct_count': get_value(correct_count),
        'avg_empirical_loss': get_value(avg_empirical_loss),
        'avg_intra_task_comp': get_value(avg_intra_task_comp),
        'meta_comp': get_value(meta_complex_term),
        'w_kld': avg_w_kld,
        'b_kld': avg_b_kld
    }
    return total_objective, info
Example #4
0
    def run_train_epoch(i_epoch):
        """Train the posterior model for one epoch and report epoch averages.

        For every mini-batch the empirical loss and the complexity term
        returned by ``get_bayes_task_objective`` are averaged over
        ``prm.n_MC`` Monte-Carlo forward passes.  The total objective is
        ``task_empirical_loss + prm.task_complex_w * task_complexity`` and is
        handed to ``grad_step``.  Every ``log_interval`` batches a status
        line is written with ``write_to_log``.

        Args:
            i_epoch: Index of the current epoch; forwarded to ``grad_step``
                and to the status line.

        Returns:
            A dict with the keys ``"task_comp"`` and ``"total_loss"``
            holding the per-batch complexity and total objective averaged
            over the batches of this epoch.  Both are ``0.0`` when
            ``train_loader`` yields no batches.
        """

        def _scalar(value):
            # ``.item()`` is the scalar accessor from PyTorch 0.4 on, where
            # indexing a 0-dim tensor with ``.data[0]`` is no longer valid;
            # fall back to ``.data[0]`` for objects without ``.item()``.
            return value.item() if hasattr(value, 'item') else value.data[0]

        log_interval = 500
        n_MC = prm.n_MC  # number of Monte-Carlo forward passes per batch

        post_model.train()

        train_info = {"task_comp": 0.0, "total_loss": 0.0}

        cnt = 0  # number of batches processed in this epoch
        for batch_idx, batch_data in enumerate(train_loader):
            cnt += 1

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                # hyper_kl = 0 when testing
                curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                # Running average over the Monte-Carlo passes:
                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:
            total_objective = task_empirical_loss + prm.task_complex_w * task_complexity

            train_info["task_comp"] += _scalar(task_complexity)
            train_info["total_loss"] += _scalar(total_objective)

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                write_to_log(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      _scalar(total_objective)) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'
                    .format(_scalar(task_empirical_loss),
                            _scalar(task_complexity), task_info["w_kld"],
                            task_info["b_kld"]), prm)

        # Average over the batches; skip the division for an empty loader so
        # the zero-initialised totals are returned instead of raising
        # ZeroDivisionError.
        if cnt:
            train_info["task_comp"] /= cnt
            train_info["total_loss"] /= cnt
        return train_info