# Example #1
def meta_step(prm, model, mb_data_loaders, mb_iterators, loss_criterion):
    """Perform one MAML meta-training step over a meta-batch of tasks.

    For each task, a copy of the model's parameters ("fast weights") is
    adapted with prm.n_meta_train_grad_steps inner SGD steps (step size
    prm.alpha); the adapted weights are then evaluated on a fresh batch
    and the loss is accumulated into the meta-objective.

    Args:
        prm: hyper-parameter namespace (uses n_meta_train_grad_steps,
            alpha, and optionally MAML_Use_Test_Data).
        model: network whose forward accepts an optional weights dict.
        mb_data_loaders: per-task dicts with 'train'/'test' data loaders.
        mb_iterators: per-task cyclic iterators over the train loaders.
        loss_criterion: callable (outputs, targets) -> scalar loss tensor.

    Returns:
        (total_objective, info): total_objective is the summed validation
        loss over tasks, differentiable w.r.t. the original model
        parameters (via create_graph=True); info holds 'sample_count'
        and 'correct_count' accumulated over the validation batches.
    """

    total_objective = 0
    correct_count = 0
    sample_count = 0

    n_tasks_in_mb = len(mb_data_loaders)

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        # Start the inner adaptation from the current meta-parameters:
        fast_weights = OrderedDict(
            (name, param) for (name, param) in model.named_parameters())

        # ----------- gradient steps loop -----------------------------------#
        for i_step in range(prm.n_meta_train_grad_steps):

            # get batch variables:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # On the first step the fast weights still equal the model's own
            # parameters, so the plain forward pass is equivalent:
            if i_step == 0:
                outputs = model(inputs)
            else:
                outputs = model(inputs, fast_weights)
            # Empirical Loss on current task:
            task_loss = loss_criterion(outputs, targets)
            # create_graph=True keeps the inner-step graph so the meta-update
            # can backpropagate through the adaptation:
            grads = torch.autograd.grad(task_loss,
                                        fast_weights.values(),
                                        create_graph=True)

            # One SGD step on the fast weights (inner learning rate alpha):
            fast_weights = OrderedDict(
                (name, param - prm.alpha * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads))
        # end grad steps loop

        # Sample new  (validation) data batch for this task:
        if hasattr(prm, 'MAML_Use_Test_Data') and prm.MAML_Use_Test_Data:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['test'])
        else:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        outputs = model(inputs, fast_weights)
        total_objective += loss_criterion(outputs, targets)
        correct_count += count_correct(outputs, targets)
        sample_count += inputs.size(0)
    # end loop over tasks in  meta-batch

    info = {'sample_count': sample_count, 'correct_count': correct_count}
    return total_objective, info
def run_test_max_posterior(model, test_loader, loss_criterion, prm):
    """Evaluate the model deterministically by using the posterior mean
    weights (eps_std temporarily set to 0)."""
    n_test_samples = len(test_loader.dataset)

    model.eval()
    loss_total = 0
    correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data,
                                                  prm,
                                                  is_test=True)
        # Disable weight noise for this forward pass, then restore it:
        prev_eps_std = model.set_eps_std(0.0)
        outputs = model(inputs)
        model.set_eps_std(prev_eps_std)
        loss_total += loss_criterion(outputs, targets)
        correct += count_correct(outputs, targets)

    avg_loss = loss_total / n_test_samples
    return {
        'test_acc': correct / n_test_samples,
        'n_correct': correct,
        'test_type': 'max_posterior',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(avg_loss),
    }
def run_eval_expected(model, loader, prm):
    """Estimate the expected loss/accuracy by Monte-Carlo averaging over
    prm.n_MC_eval stochastic forward passes per batch.

    Returns a dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    n_samples = len(loader.dataset)
    loss_criterion = get_loss_func(prm)
    model.eval()
    avg_loss = 0.0
    n_correct = 0
    n_MC = prm.n_MC_eval  # number of monte-carlo runs for expected loss estimation
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        #  monte-carlo runs (each draws new stochastic weights):
        for i_MC in range(n_MC):
            outputs = model(inputs)
            avg_loss += loss_criterion(
                outputs, targets).item()  # sum the loss contributed from batch
            n_correct += count_correct(outputs, targets)

    # Normalize by total number of (sample, MC-draw) pairs:
    avg_loss /= (n_MC * n_samples)
    acc = n_correct / (n_MC * n_samples)
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
def run_eval_max_posterior(model, loader, prm):
    """Estimate the loss/accuracy using the mean network parameters
    (eps_std temporarily set to 0, i.e. max-posterior evaluation).

    Returns a dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    n_samples = len(loader.dataset)
    loss_criterion = get_loss_func(prm)
    model.eval()
    avg_loss = 0
    n_correct = 0
    for batch_data in loader:
        # Extract the batch data:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        old_eps_std = model.set_eps_std(0.0)  # test with max-posterior
        outputs = model(inputs)
        model.set_eps_std(old_eps_std)  # return model to normal behaviour
        avg_loss += loss_criterion(
            outputs, targets).item()  # sum the loss contributed from batch
        n_correct += count_correct(outputs, targets)

    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
# Example #5
def run_eval_majority_vote(model, loader, prm, n_votes=5):
    """Estimate the loss/accuracy of the majority vote over several draws
    from the network's weight distribution.

    Returns a dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    loss_criterion = get_loss_func(prm)
    n_samples = len(loader.dataset)
    model.eval()
    avg_loss = 0
    n_correct = 0
    # Number of classes is loop-invariant — fetch it once, not per batch:
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)

        batch_size = inputs.shape[0]
        votes = torch.zeros((batch_size, n_labels), device=prm.device)
        loss_from_batch = 0.0
        for i_vote in range(n_votes):

            outputs = model(inputs)
            loss_from_batch += loss_criterion(outputs, targets).item()
            pred = outputs.data.max(1, keepdim=True)[1]  # get the index of the max output
            for i_sample in range(batch_size):
                pred_val = pred[i_sample].cpu().numpy()[0]
                votes[i_sample, pred_val] += 1
        avg_loss += loss_from_batch / n_votes # sum the loss contributed from batch

        majority_pred = votes.max(1, keepdim=True)[1] # find argmax class for each sample
        n_correct += majority_pred.eq(targets.data.view_as(majority_pred)).cpu().sum()
    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {'acc': acc, 'n_correct': n_correct,
            'n_samples': n_samples, 'avg_loss': avg_loss}
    return info
# Example #6
def run_eval_avg_vote(model, loader, prm, n_votes=5):
    """Estimate the loss/accuracy of the average vote (summed outputs) over
    several draws from the network's weight distribution.

    Returns a dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    loss_criterion = get_loss_func(prm)
    n_samples = len(loader.dataset)
    model.eval()
    avg_loss = 0
    n_correct = 0
    # Number of classes is loop-invariant — fetch it once, not per batch:
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)

        # BUG FIX: was min(prm.test_batch_size, n_samples), which mismatches
        # the actual size of the last (partial) batch and breaks
        # `votes += outputs.data`; use the real batch size, as
        # run_eval_majority_vote does.
        batch_size = inputs.shape[0]
        votes = torch.zeros((batch_size, n_labels), device=prm.device)
        loss_from_batch = 0.0
        for i_vote in range(n_votes):

            outputs = model(inputs)
            loss_from_batch += loss_criterion(outputs, targets).item()
            votes += outputs.data

        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(targets.data.view_as(majority_pred)).cpu().sum()
        avg_loss += loss_from_batch / n_votes  # sum the loss contributed from batch

    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {'acc': acc, 'n_correct': n_correct,
            'n_samples': n_samples, 'avg_loss': avg_loss}
    return info
# Example #7
    def run_train_epoch(i_epoch):
        """Train post_model for one full epoch over train_loader using the
        Monte-Carlo Bayesian objective.

        Logs every `log_interval` batches and returns the last batch's
        total objective as a float.
        """
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:
            total_objective = task_empirical_loss + task_complexity

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        task_empirical_loss.item(), task_complexity.item()))

                data_objective.append(total_objective.item())
                data_accuracy.append(batch_acc)
                data_emp_loss.append(task_empirical_loss.item())
                data_task_comp.append(task_complexity.item())

        # BUG FIX: the return was previously inside the batch loop, so the
        # "epoch" stopped after the first batch; return after the full epoch.
        return total_objective.item()
# Example #8
    def run_train_epoch(i_epoch):
        """Train post_model for one epoch over train_loader, minimizing the
        Monte-Carlo empirical loss plus (if a prior is given) the Bayesian
        complexity term. Logs every `log_interval` batches.
        """
        complexity_term = 0

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # Monte-Carlo estimate of the empirical loss:
            empirical_loss = 0
            n_MC = prm.n_MC
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # calculate objective:
                outputs = post_model(inputs)
                empirical_loss_c = loss_criterion(outputs, targets)
                empirical_loss += (1 / n_MC) * empirical_loss_c

            #  complexity/prior term:
            if prior_model:
                empirical_loss, complexity_term = get_bayes_task_objective(
                    prm, prior_model, post_model, n_train_samples,
                    empirical_loss)
            else:
                complexity_term = 0.0

            # Total objective:
            objective = empirical_loss + complexity_term

            # Take gradient step:
            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status (batch_acc uses outputs of the last MC draw):
            log_interval = 500
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                # Use get_value() instead of objective.data[0]: indexing a
                # 0-dim tensor fails on modern PyTorch, and get_value is the
                # helper used elsewhere in this file.
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc,
                                      get_value(objective))
                    + ' Loss: {:.4}\t Comp.: {:.4}'.format(
                        get_value(empirical_loss), get_value(complexity_term)))
# Example #9
    def run_train_epoch(i_epoch, log_mat):
        """Run one training epoch of post_model: per batch, average the
        per-sample loss over prm.n_MC stochastic forward passes, add the
        task-complexity term (if a prior model exists), and take one
        gradient step. Periodically logs status and saves figure data.
        """
        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # Fetch the batch once; all MC draws reuse it:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            n_MC = prm.n_MC
            avg_empiric_loss = torch.zeros(1, device=prm.device)
            for _ in range(n_MC):
                outputs = post_model(inputs)
                per_sample_loss = (1 / batch_size) * loss_criterion(
                    outputs, targets)
                avg_empiric_loss += (1 / n_MC) * per_sample_loss

            # Complexity/prior term (zero when no prior model is given):
            if prior_model:
                complexity_term = get_task_complexity(prm, prior_model,
                                                      post_model,
                                                      n_train_samples,
                                                      avg_empiric_loss)
            else:
                complexity_term = torch.zeros(1, device=prm.device)

            objective = avg_empiric_loss + complexity_term

            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Status logging (accuracy from the last MC draw's outputs):
            log_interval = 1000
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc, objective.item()) +
                    ' Loss: {:.4}\t Comp.: {:.4}'.format(
                        avg_empiric_loss.item(), complexity_term.item()))

        # End batch loop

        # Save results for the per-epoch figure:
        if figure_flag and (i_epoch % prm.log_figure['interval_epochs'] == 0):
            save_result_for_figure(post_model, prior_model, data_loader, prm,
                                   log_mat, i_epoch)
    def get_risk(prior_model, prm, mb_data_loaders, mb_iterators,
                 loss_criterion, n_train_tasks):
        '''Calculate the average empirical loss of prior_model over the
        tasks in the meta-batch (Monte-Carlo averaged over prm.n_MC draws).

        Returns the scalar average empirical loss.
        '''
        # note: it is OK if some tasks appear several times in the meta-batch

        n_tasks_in_mb = len(mb_data_loaders)

        sum_empirical_loss = 0
        # NOTE(review): these counters are accumulated but never returned;
        # kept for parity with the original behavior.
        correct_count = 0
        sample_count = 0

        # Train mode is loop-invariant — set it once before the task loop:
        prior_model.train()

        # ----------- loop over tasks in meta-batch -----------------------------------#
        for i_task in range(n_tasks_in_mb):

            # get sample-batch data from current task to calculate the empirical loss estimate:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            # ----------- Monte-Carlo loop  -----------------------------------#
            for i_MC in range(n_MC):
                # get batch variables:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Empirical Loss on current task:
                outputs = prior_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            # end Monte-Carlo loop

            sum_empirical_loss += task_empirical_loss

        # end loop over tasks in meta-batch
        avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss

        return avg_empirical_loss
# Example #11
def run_test(model, test_loader, loss_criterion, prm):
    """Evaluate the model on the test set; print and return the accuracy."""
    model.eval()
    loss_acc = 0
    correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        outputs = model(inputs)
        # Accumulate the per-sample mean loss of each batch:
        loss_acc += (1 / inputs.shape[0]) * loss_criterion(
            outputs, targets).item()
        correct += count_correct(outputs, targets)

    n_test_samples = len(test_loader.dataset)
    avg_loss = loss_acc / len(test_loader)
    test_acc = correct / n_test_samples
    print('\n Standard learning: test loss: {:.4}, test err: {:.3} ( {}/{})\n'.
          format(avg_loss, 1 - test_acc, correct, n_test_samples))
    return test_acc
# Example #12
    def run_test(model, test_loader):
        """Evaluate the model on the test set; print and return the accuracy."""
        model.eval()
        test_loss = 0
        n_correct = 0
        for batch_data in test_loader:
            inputs, targets = data_gen.get_batch_vars(batch_data,
                                                      prm,
                                                      is_test=True)
            outputs = model(inputs)
            test_loss += loss_criterion(outputs,
                                        targets)  # sum the mean loss in batch
            n_correct += count_correct(outputs, targets)

        n_test_samples = len(test_loader.dataset)
        n_test_batches = len(test_loader)
        # Use get_value() instead of test_loss.data[0]: indexing a 0-dim
        # tensor fails on modern PyTorch, and get_value is the helper used
        # elsewhere in this file.
        test_loss = get_value(test_loss) / n_test_batches
        test_acc = n_correct / n_test_samples
        print('\nTest set: Average loss: {:.4}, Accuracy: {:.3} ( {}/{})\n'.
              format(test_loss, test_acc, n_correct, n_test_samples))
        return test_acc
def run_test_majority_vote(model, test_loader, loss_criterion, prm, n_votes=9):
    """Evaluate by majority vote over n_votes stochastic forward passes.

    Returns a dict with 'test_acc', 'n_correct', 'test_type',
    'n_test_samples', 'test_loss'.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    # Number of classes is loop-invariant — fetch it once, not per batch:
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data,
                                                  prm,
                                                  is_test=True)

        batch_size = inputs.shape[0]
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for i_vote in range(n_votes):

            outputs = model(inputs)
            test_loss += loss_criterion(outputs, targets)
            pred = outputs.data.max(
                1, keepdim=True)[1]  # get the index of the max output
            for i_sample in range(batch_size):
                pred_val = pred[i_sample].cpu().numpy()[0]
                votes[i_sample, pred_val] += 1

        majority_pred = votes.max(
            1, keepdim=True)[1]  # find argmax class for each sample
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'majority_vote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
# Example #14
    def run_train_epoch(i_epoch):
        """Run one standard (non-Bayesian) training epoch over train_loader,
        logging the status every `log_interval` batches."""
        log_interval = 500

        model.train()
        for batch_idx, batch_data in enumerate(train_loader):

            # Fetch batch and compute the loss:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            outputs = model(inputs)
            loss = loss_criterion(outputs, targets)

            # One optimizer step:
            grad_step(loss, optimizer, lr_schedule, prm.lr, i_epoch)

            # Periodic status line:
            if batch_idx % log_interval != 0:
                continue
            print(
                cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                  n_batches, correct_rate(outputs, targets),
                                  get_value(loss)))
# Example #15
    def run_meta_test_learning(task_model, train_loader):
        """Adapt task_model to a new task with prm.n_meta_test_grad_steps
        gradient steps on that task's training data; return the adapted model."""
        task_model.train()
        loader_iter = iter(train_loader)

        # Adaptation (training) loop:
        for _ in range(prm.n_meta_test_grad_steps):
            # Draw the next batch, cycling when the loader is exhausted:
            batch = data_gen.get_next_batch_cyclic(loader_iter, train_loader)
            inputs, targets = data_gen.get_batch_vars(batch, prm)

            # Empirical loss and one gradient step on the task weights:
            task_objective = loss_criterion(task_model(inputs), targets)
            grad_step(task_objective, task_optimizer)

        return task_model
# Example #16
    def run_train_epoch(i_epoch):
        """Train prior_model for one epoch over data_prior_loader, minimizing
        the Monte-Carlo averaged empirical loss (no complexity term).
        Logs the objective and batch accuracy every `log_interval` batches.
        """
        prior_model.train()

        for batch_idx, batch_data in enumerate(data_prior_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = prior_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective (empirical loss only — no prior/complexity term):
            total_objective = task_empirical_loss

            # Take gradient step with the posterior:
            grad_step(total_objective, all_optimizer, lr_schedule, prm.lr,
                      i_epoch)
            log_interval = 20
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    'number meta batch:{} \t avg_empiric_loss:{:.3f} \t batch accuracy:{:.3f}'
                    .format(batch_idx, total_objective, batch_acc))
def run_test_avg_vote(model, test_loader, loss_criterion, prm, n_votes=5):
    """Evaluate by averaging the raw outputs of n_votes stochastic forward
    passes and taking the argmax class.

    Returns a dict with 'test_acc', 'n_correct', 'test_type',
    'n_test_samples', 'test_loss'.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    # Number of classes is loop-invariant — fetch it once, not per batch:
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data,
                                                  prm,
                                                  is_test=True)

        # BUG FIX: was min(prm.test_batch_size, n_test_samples), which
        # mismatches the actual size of the last (partial) batch and breaks
        # `votes += outputs.data`; use the real batch size, as
        # run_test_majority_vote does.
        batch_size = inputs.shape[0]
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for i_vote in range(n_votes):

            outputs = model(inputs)
            test_loss += loss_criterion(outputs, targets)
            votes += outputs.data

        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()

    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'AvgVote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
# Example #18
    def run_train_epoch(i_epoch):
        """Train post_model for one epoch with the Monte-Carlo Bayesian
        objective (empirical loss + weighted intra-task complexity).

        Returns a dict with the epoch-averaged 'task_comp' and 'total_loss'.
        """
        log_interval = 500

        post_model.train()

        train_info = {}
        train_info["task_comp"] = 0.0
        train_info["total_loss"] = 0.0

        cnt = 0  # number of batches processed (for epoch averages)
        for batch_idx, batch_data in enumerate(train_loader):
            cnt += 1

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                #hyper_kl = 0 when testing
                curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective (complexity weighted by prm.task_complex_w):
            total_objective = task_empirical_loss + prm.task_complex_w * task_complexity

            # NOTE(review): `.data[0]` indexing of a 0-dim tensor raises on
            # modern PyTorch (use .item()/get_value); confirm the torch
            # version this file targets.
            train_info["task_comp"] += task_complexity.data[0]
            train_info["total_loss"] += total_objective.data[0]

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                write_to_log(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.data[0]) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'
                    .format(task_empirical_loss.data[0],
                            task_complexity.data[0], task_info["w_kld"],
                            task_info["b_kld"]), prm)

        # Epoch averages over all processed batches:
        train_info["task_comp"] /= cnt
        train_info["total_loss"] /= cnt
        return train_info
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    '''Calculate the meta-learning objective based on the tasks in the
    meta-batch: average empirical loss + weighted average intra-task
    complexity + weighted meta (hyper-prior) complexity term.

    Returns (total_objective, info), where info holds sample/correct counts
    and the individual loss components.
    '''
    # note: it is OK if some tasks appear several times in the meta-batch

    n_tasks_in_mb = len(mb_data_loaders)

    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0

    # KLD between hyper-posterior and hyper-prior:
    hyper_kl = (1 / (2 * prm.kappa_prior**2)) * net_norm(
        prior_model, p=2)  #net_norm is L2-regularization

    # Hyper-prior term:
    meta_complex_term = get_meta_complexity_term(hyper_kl, prm, n_train_tasks)
    sum_w_kld = 0.0
    sum_b_kld = 0.0

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        # ----------- Monte-Carlo loop  -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Empirical Loss on current task:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

            # Intra-task complexity of current task:
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_samples,
                curr_empirical_loss,
                hyper_kl,
                n_train_tasks=n_train_tasks)

            sum_w_kld += task_info["w_kld"]
            sum_b_kld += task_info["b_kld"]
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop

        sum_empirical_loss += task_empirical_loss
        sum_intra_task_comp += task_complexity

    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    avg_intra_task_comp = (1 / n_tasks_in_mb) * sum_intra_task_comp
    # BUG FIX: these two used '+=' on names that were never initialized,
    # raising NameError on every call; plain assignment is intended.
    avg_w_kld = (1 / n_tasks_in_mb) * sum_w_kld
    avg_b_kld = (1 / n_tasks_in_mb) * sum_b_kld

    # Approximated total objective:
    total_objective = avg_empirical_loss + prm.task_complex_w * avg_intra_task_comp + prm.meta_complex_w * meta_complex_term

    info = {
        'sample_count': get_value(sample_count),
        'correct_count': get_value(correct_count),
        'avg_empirical_loss': get_value(avg_empirical_loss),
        'avg_intra_task_comp': get_value(avg_intra_task_comp),
        'meta_comp': get_value(meta_complex_term),
        'w_kld': avg_w_kld,
        'b_kld': avg_b_kld
    }
    return total_objective, info
# Example #20
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    '''Calculate the meta-learning objective over the tasks in a meta-batch.

    Args:
        prior_model: shared prior model over network weights.
        mb_data_loaders: per-task dicts; reads the 'train' loader and
            'n_train_samples'.
        mb_iterators: per-task cyclic iterators over the 'train' loaders.
        mb_posteriors_models: per-task posterior models (stochastic networks).
        loss_criterion: loss on a batch; normalized here by the batch size,
            so it is presumably a summed (not averaged) loss — verify.
        n_train_tasks: total number of training tasks; weights the
            complexity terms.

    Returns:
        (total_objective, info): scalar objective tensor and a dict of
        diagnostics for logging.

    Note: it is OK if some tasks appear several times in the meta-batch.
    '''
    n_tasks_in_mb = len(mb_data_loaders)

    correct_count = 0
    sample_count = 0

    # Hyper-prior term (shared across all tasks in the meta-batch):
    hyper_dvrg = get_hyper_divergnce(prm, prior_model)
    meta_complex_term = get_meta_complexity_term(hyper_dvrg, prm,
                                                 n_train_tasks)

    avg_empiric_loss_per_task = torch.zeros(n_tasks_in_mb, device=prm.device)
    complexity_per_task = torch.zeros(n_tasks_in_mb, device=prm.device)
    # Total number of samples in each task (not just in a batch):
    n_samples_per_task = torch.zeros(n_tasks_in_mb, device=prm.device)

    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']
        n_samples_per_task[i_task] = n_samples

        # Get a sample-batch from the current task to estimate the empirical loss:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # get batch variables:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        # Monte-Carlo estimate of the expected empirical loss; each forward
        # pass through the stochastic posterior model re-samples the weights:
        n_MC = prm.n_MC
        avg_empiric_loss = 0.0

        # Monte-Carlo loop
        for i_MC in range(n_MC):

            # Empirical loss on the current task, normalized by batch size:
            outputs = post_model(inputs)
            avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                outputs, targets)

            correct_count += count_correct(outputs, targets)  # for print
            sample_count += inputs.size(0)

            avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
        # end Monte-Carlo loop

        # Intra-task complexity of the current task, computed once on the
        # MC-averaged empirical loss:
        complexity = get_task_complexity(prm,
                                         prior_model,
                                         post_model,
                                         n_samples,
                                         avg_empiric_loss,
                                         hyper_dvrg,
                                         n_train_tasks=n_train_tasks,
                                         noised_prior=True)
        avg_empiric_loss_per_task[i_task] = avg_empiric_loss
        complexity_per_task[i_task] = complexity
    # end loop over tasks in meta-batch

    # Approximated total objective:
    if prm.complexity_type == 'Variational_Bayes':
        # note that avg_empiric_loss_per_task is estimated by an average over
        # batch samples, but its weight in the objective should be considered
        # by how many samples there are total in the task:
        total_objective =\
            (avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() * n_train_tasks + meta_complex_term

    else:
        total_objective =\
            avg_empiric_loss_per_task.mean() + complexity_per_task.mean() + meta_complex_term

    info = {
        'sample_count': sample_count,
        'correct_count': correct_count,
        'avg_empirical_loss': avg_empiric_loss_per_task.mean().item(),
        'avg_intra_task_comp': complexity_per_task.mean().item(),
        'meta_comp': meta_complex_term.item()
    }
    return total_objective, info
Пример #21
0
    def run_train_epoch(i_epoch):
        """Run one training epoch of the posterior model.

        Closure over: post_model, prior_model, train_loader, prm,
        loss_criterion, n_train_samples, optimizer, lr_schedule, n_batches.
        NOTE(review): prm.n_meta_test_epochs in the status print suggests this
        adapts the posterior at meta-test time — confirm against the caller.
        """
        # How often (in batches) to print a status line:
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # get batch data:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            # Accuracy counters are reset per batch (used only for the print):
            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations: average the loss over n_MC stochastic
            # forward passes of the posterior model.
            n_MC = prm.n_MC
            avg_empiric_loss = 0
            complexity_term = 0

            for i_MC in range(n_MC):

                # Calculate empirical loss, normalized by the batch size:
                outputs = post_model(inputs)
                avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                    outputs, targets)

                # complexity_curr = get_task_complexity(prm, prior_model, post_model,
                #                                            n_train_samples, avg_empiric_loss_curr)

                avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
                # complexity_term += (1 / n_MC) * complexity_curr

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)
            # end monte-carlo loop

            # Complexity is computed once on the MC-averaged loss (the per-MC
            # accumulation above is intentionally commented out):
            complexity_term = get_task_complexity(prm, prior_model, post_model,
                                                  n_train_samples,
                                                  avg_empiric_loss)

            # Approximated total objective (for current batch):
            if prm.complexity_type == 'Variational_Bayes':
                # note that avg_empiric_loss_per_task is estimated by an average over batch samples,
                #  but its weight in the objective should be considered by how many samples there are total in the task
                total_objective = avg_empiric_loss * (
                    n_train_samples) + complexity_term
            else:
                total_objective = avg_empiric_loss + complexity_term

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        avg_empiric_loss.item(), complexity_term.item()))