    # Run standard learning for each task and average the parameters:
    avg_param_vec = None
    for i_task in range(n_train_tasks):
        print('Learning train-task {} out of {}'.format(i_task+1, n_train_tasks))
        data_loader = train_data_loaders[i_task]
        test_err, curr_model = learn_single_standard.run_learning(data_loader, prm, verbose=0)
        if i_task == 0:
            avg_param_vec = parameters_to_vector(curr_model.parameters()) * (1 / n_train_tasks)
        else:
            avg_param_vec += parameters_to_vector(curr_model.parameters()) * (1 / n_train_tasks)

    avg_model = deterministic_models.get_model(prm)
    vector_to_parameters(avg_param_vec, avg_model.parameters())

    # create the prior model:
    prior_model = stochastic_models.get_model(prm)
    prior_layers_list = [layer for layer in prior_model.modules() if isinstance(layer, StochasticLayer)]
    avg_model_layers_list = [layer for layer in avg_model.modules()
                             if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear))]
    assert len(avg_model_layers_list) == len(prior_layers_list), \
        "Number of deterministic layers must match the number of stochastic prior layers"

    for i_layer, prior_layer in enumerate(prior_layers_list):
        if hasattr(prior_layer, 'w'):
            prior_layer.w['log_var'] = torch.nn.Parameter(zeros_gpu(1))
            prior_layer.w['mean'] = avg_model_layers_list[i_layer].weight
        if hasattr(prior_layer, 'b'):
            prior_layer.b['log_var'] = torch.nn.Parameter(zeros_gpu(1))
            prior_layer.b['mean'] = avg_model_layers_list[i_layer].bias
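
    # (Note -- an assumption of the pairing above: prior_model.modules() and avg_model.modules()
    #  are expected to yield their layers in the same order, so prior_layers_list[i] corresponds
    #  to avg_model_layers_list[i]. A zero 'log_var' gives a prior variance of exp(0) = 1.)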


    # save learned prior:
Example #2
        save_model_state(prior_model, save_path)
        write_to_log('Trained prior saved in ' + save_path, prm)
    else:
        # In this case we observe new tasks generated from the task-distribution in each meta-iteration.
        write_to_log('---- Infinite train tasks - New training tasks are '
                     'drawn from tasks distribution in each iteration...', prm)

        # Meta-training to learn meta-prior (theta params):
        prior_model = meta_train_Bayes_infinite_tasks.run_meta_learning(task_generator, prm)


elif prm.mode == 'LoadMetaModel':

    # Load a previously trained prior.
    # First, create the model:
    prior_model = get_model(prm)
    prm.load_model_path = '/hdd/shiwei/meta_learning _example/PriorMetaLearning/saved/ShuffledPixels100_TasksN/log 2019-05-23 14:36:30/1/model.pt'
    # prm.load_model_path = '/hdd/shiwei/meta_learning _example/PriorMetaLearning/saved/PermutedLabels_TasksN/log 2019-04-18 12:40:53/5/model.pt'
    # prm.load_model_path = '/hdd/shiwei/meta_learning _example/PriorMetaLearning/saved/model.pt'
    # Then load the weights:
    load_model_state(prior_model, prm.load_model_path)
    write_to_log('Pre-trained prior loaded from ' + prm.load_model_path, prm)
else:
    raise ValueError('Invalid mode')

# -------------------------------------------------------------------------------------------
# Generate the data sets of the test tasks:
# -------------------------------------------------------------------------------------------

n_test_tasks = prm.n_test_tasks
Example #3
def run_learning(task_data, prior_model, prm, init_from_prior=True, verbose=1):

    # -------------------------------------------------------------------------------------------
    #  Setting-up
    # -------------------------------------------------------------------------------------------
    # Unpack parameters:
    optim_func, optim_args, lr_schedule =\
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)

    # Create posterior model for the new task:
    post_model = get_model(prm)

    if init_from_prior:
        post_model.load_state_dict(prior_model.state_dict())

        # prior_model_dict = prior_model.state_dict()
        # post_model_dict = post_model.state_dict()
        #
        # # filter out unnecessary keys:
        # prior_model_dict = {k: v for k, v in prior_model_dict.items() if '_log_var' in k or '_mu' in k}
        # # overwrite entries in the existing state dict:
        # post_model_dict.update(prior_model_dict)
        #
        # # #  load the new state dict
        # post_model.load_state_dict(post_model_dict)

        # add_noise_to_model(post_model, prm.kappa_factor)

    # The data-sets of the new task:
    train_loader = task_data['train']
    test_loader = task_data['test']
    n_train_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    #  Get optimizer:
    optimizer = optim_func(post_model.parameters(), **optim_args)

    # -------------------------------------------------------------------------------------------
    #  Training epoch  function
    # -------------------------------------------------------------------------------------------

    def run_train_epoch(i_epoch):
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:

            total_objective = task_empirical_loss + task_complexity

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        task_empirical_loss.item(), task_complexity.item()))

                data_objective.append(total_objective.item())
                data_accuracy.append(batch_acc)
                data_emp_loss.append(task_empirical_loss.item())
                data_task_comp.append(task_complexity.item())

        # Return the last batch's objective, used as the bound estimate for this epoch:
        return total_objective.item()

    # -----------------------------------------------------------------------------------------------------------#
    # Update Log file
    if verbose == 1:
        write_to_log(
            'Total number of steps: {}'.format(n_batches *
                                               prm.n_meta_test_epochs), prm)

    # -------------------------------------------------------------------------------------------
    #  Run epochs
    # -------------------------------------------------------------------------------------------
    start_time = timeit.default_timer()

    data_objective = []
    data_accuracy = []
    data_emp_loss = []
    data_task_comp = []
    # Training loop:
    for i_epoch in range(prm.n_meta_test_epochs):
        test_bound = run_train_epoch(i_epoch)

    with open(
            os.path.join(prm.result_dir, 'run_test_data_prior_bound_data.pkl'),
            'wb') as f:
        pickle.dump(
            {
                'data_objective': data_objective,
                "data_accuracy": data_accuracy,
                'data_emp_loss': data_emp_loss,
                'data_task_comp': data_task_comp
            }, f)

    # Test:
    test_acc, test_loss = run_test_Bayes(post_model, test_loader,
                                         loss_criterion, prm)

    stop_time = timeit.default_timer()
    cmn.write_final_result(test_acc,
                           stop_time - start_time,
                           prm,
                           result_name=prm.test_type,
                           verbose=verbose)

    test_err = 1 - test_acc
    return test_err, test_loss, test_bound, post_model
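
# A minimal usage sketch (hypothetical variable names): given a meta-learned `prior_model`
# and a new task packaged as a dict of data loaders, adaptation could look like:
#   task_data = {'train': new_task_train_loader, 'test': new_task_test_loader}
#   test_err, test_loss, test_bound, post_model = run_learning(task_data, prior_model, prm)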
Example #4
def run_learning(task_data, prior_model, prm, init_from_prior=True, verbose=1):

    # prm.optim_func, prm.optim_args = optim.EntropySGD, {'llr':0.01, 'lr':0.1, 'momentum':0.9, 'damp':0, 'weight_decay':1e-3, 'nesterov':True,
    #                  'L':20, 'eps':1e-3, 'g0':1e-4, 'g1':1e-3}

    # -------------------------------------------------------------------------------------------
    #  Setting-up
    # -------------------------------------------------------------------------------------------
    # Unpack parameters:
    # prm.optim_args['llr'] = 0.1
    # prm.optim_args['L'] = 20
    # # prm.optim_args['weight_decay'] = 1e-3
    # # prm.optim_args['g1'] = 0
    # prm.optim_args['g0'] = 1e-4
    optim_func, optim_args, lr_schedule =\
        prm.optim_func, prm.optim_args, prm.lr_schedule_test

    # prm.optim_func, prm.optim_args = optim.Adam, {'lr': prm.lr}  # 'weight_decay': 1e-4

    # lr_schedule = {'decay_factor': 0.1, 'decay_epochs': [15, 20]}

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)

    # Create posterior model for the new task:
    post_model = get_model(prm)

    if init_from_prior:
        post_model.load_state_dict(prior_model.state_dict())

        # prior_model_dict = prior_model.state_dict()
        # post_model_dict = post_model.state_dict()
        #
        # # filter out unnecessary keys:
        # prior_model_dict = {k: v for k, v in prior_model_dict.items() if '_log_var' in k or '_mu' in k}
        # # overwrite entries in the existing state dict:
        # post_model_dict.update(prior_model_dict)
        #
        # # #  load the new state dict
        # post_model.load_state_dict(post_model_dict)

        # add_noise_to_model(post_model, prm.kappa_factor)

    # The data-sets of the new task:
    train_loader = task_data
    test_loader = task_data['test']
    # n_train_samples = len(train_loader['train'].dataset)
    n_batches = len(train_loader['train'])

    #  Get optimizer:
    optimizer = optim_func(
        filter(lambda p: p.requires_grad, post_model.parameters()), optim_args)

    # optimizer = optim_func(filter(lambda p: p.requires_grad, post_model.parameters()), optim_args['lr'])

    # -------------------------------------------------------------------------------------------
    #  Training epoch  function
    # -------------------------------------------------------------------------------------------

    def run_train_epoch(i_epoch):
        # log_interval = 500

        post_model.train()

        train_iterators = iter(train_loader['train'])

        for batch_idx, batch_data in enumerate(train_loader['train']):

            task_loss, info = get_objective(prior_model, prm, [train_loader],
                                            object.feval, [train_iterators],
                                            [post_model], loss_criterion, 1,
                                            [0])

            grad_step(task_loss[0], post_model, loss_criterion, optimizer, prm,
                      train_iterators, train_loader['train'], lr_schedule,
                      prm.optim_args['lr'], i_epoch)

            # for log_var in post_model.parameters():
            #     if log_var.requires_grad is False:
            #         log_var.data = log_var.data - (i_epoch + 1) * math.log(1 + prm.gamma1)

            # Print status:
            log_interval = 10
            if (batch_idx) % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc) +
                    ' Empiric-Loss: {:.4f}'.format(info['avg_empirical_loss']))

    # -----------------------------------------------------------------------------------------------------------#
    # Update Log file
    if verbose == 1:
        write_to_log(
            'Total number of steps: {}'.format(n_batches *
                                               prm.n_meta_test_epochs), prm)

    # -------------------------------------------------------------------------------------------
    #  Run epochs
    # -------------------------------------------------------------------------------------------
    start_time = timeit.default_timer()

    # Training loop:
    for i_epoch in range(prm.n_meta_test_epochs):
        run_train_epoch(i_epoch)

    # Test:
    test_acc, test_loss = run_test_Bayes(post_model, test_loader,
                                         loss_criterion, prm)

    stop_time = timeit.default_timer()
    cmn.write_final_result(test_acc,
                           stop_time - start_time,
                           prm,
                           result_name=prm.test_type,
                           verbose=verbose)

    test_err = 1 - test_acc
    return test_err, post_model
Example #5
def run_meta_learning(data_loaders, prm):

    # -------------------------------------------------------------------------------------------
    #  Setting-up
    # -------------------------------------------------------------------------------------------
    # Unpack parameters:
    optim_func, optim_args, lr_schedule =\
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)

    n_train_tasks = len(data_loaders)

    # Create a 'dummy' model to generate the set of parameters of the shared prior:
    prior_model = get_model(prm)

    # Create posterior models for each task:
    posteriors_models = [
        transfer_weights(prior_model, get_model(prm))
        for _ in range(n_train_tasks)
    ]

    # Gather all tasks posterior params:
    # all_post_param = sum([list(posterior_model.parameters()) for posterior_model in posteriors_models], [])

    # Create optimizer for all parameters (posteriors + prior)
    prior_params = filter(lambda p: p.requires_grad, prior_model.parameters())
    prior_optimizer = optim_func(prior_params, optim_args)
    object.grad_init(prm, prior_model, loss_criterion,
                     iter(data_loaders[0]['train']), data_loaders[0]['train'],
                     prior_optimizer)
    # all_params = all_post_param + prior_params
    all_posterior_optimizers = [
        optim_func(
            filter(lambda p: p.requires_grad,
                   posteriors_models[i].parameters()), optim_args)
        for i in range(n_train_tasks)
    ]

    # number of sample-batches in each task:
    n_batch_list = [len(data_loader['train']) for data_loader in data_loaders]

    n_batches_per_task = np.max(n_batch_list)

    L = prm.optim_args['L']

    # -------------------------------------------------------------------------------------------
    #  Training epoch  function
    # -------------------------------------------------------------------------------------------
    def run_train_epoch(i_epoch):

        # optim_args['L'] = L-5*i_epoch

        # For each task, prepare an iterator to generate training batches:
        train_iterators = [
            iter(data_loaders[ii]['train']) for ii in range(n_train_tasks)
        ]

        # The task order to take batches from:
        # The meta-batch will be balanced - i.e, each task will appear roughly the same number of times
        # note: if some tasks have fewer samples than other tasks, they may be sampled more than once in an epoch
        task_order = []
        task_ids_list = list(range(n_train_tasks))
        for i_batch in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list
        # Note: this method ensures each training sample in each task is drawn in each epoch.
        # If all the tasks have the same number of samples, then each sample is drawn exactly once in an epoch.

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        # each meta-batch includes several tasks
        # we take a grad step with theta after each meta-batch
        # meta_batch_starts = list(range(0, len(task_order), n_train_tasks))
        meta_batch_starts = list(range(0, len(task_order),
                                       prm.meta_batch_size))
        n_meta_batches = len(meta_batch_starts)

        for i_meta_batch in range(n_meta_batches):

            meta_batch_start = meta_batch_starts[i_meta_batch]
            task_ids_in_meta_batch = task_order[meta_batch_start:(
                meta_batch_start + prm.meta_batch_size)]
            # the meta-batch size may be smaller than prm.meta_batch_size in the last meta-batch
            # note: it is OK if some tasks appear several times in the meta-batch

            mb_data_loaders = [
                data_loaders[task_id] for task_id in task_ids_in_meta_batch
            ]
            mb_iterators = [
                train_iterators[task_id] for task_id in task_ids_in_meta_batch
            ]
            mb_posteriors_models = [
                posteriors_models[task_id]
                for task_id in task_ids_in_meta_batch
            ]

            #task_loss_list, info = get_objective(prior_model, prm, mb_data_loaders, object.feval,
            #                                      mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks, task_ids_in_meta_batch)

            # Take gradient step with the shared prior and all tasks' posteriors:
            # for i_task in range(n_train_tasks):
            # prior_optimizer.zero_grad()
            #for i_task in range(n_train_tasks):
            #    if isinstance(task_loss_list[i_task], int):
            #        continue
            #    grad_step(task_loss_list[i_task], posteriors_models[i_task], loss_criterion, all_posterior_optimizers[i_task], prm,
            #              train_iterators[i_task], data_loaders[i_task]['train'], lr_schedule, prm.lr, i_epoch)

            # Get objective based on tasks in meta-batch:
            task_loss_list, info = get_objective(prior_model, prm,
                                                 mb_data_loaders, object.feval,
                                                 mb_iterators,
                                                 mb_posteriors_models,
                                                 loss_criterion, n_train_tasks,
                                                 task_ids_in_meta_batch)

            # Take gradient step with the shared prior and all tasks' posteriors:
            # for i_task in range(n_train_tasks):
            prior_optimizer.zero_grad()
            for i_task in range(n_train_tasks):
                if isinstance(task_loss_list[i_task], int):
                    continue
                grad_step(task_loss_list[i_task], posteriors_models[i_task],
                          loss_criterion, all_posterior_optimizers[i_task],
                          prm, train_iterators[i_task],
                          data_loaders[i_task]['train'], lr_schedule, prm.lr,
                          i_epoch)
                # if i_meta_batch==n_meta_batches-1:
                # if i_epoch == prm.n_meta_train_epochs-1:
                prior_get_grad(prior_optimizer,
                               all_posterior_optimizers[i_task])

            # task_loss_list, info = get_objective(prior_model, prm, mb_data_loaders, object.feval,
            #                                       mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks, task_ids_in_meta_batch)

            # prior_grad_step(prior_optimizer, prm.meta_batch_size, prm,prm.prior_lr_schedule, prm.prior_lr, i_epoch)
            prior_updates(prior_optimizer, n_train_tasks, prm)
            for post_model in posteriors_models:
                post_model.load_state_dict(prior_model.state_dict())

            # Print status:
            log_interval = 10
            if (i_meta_batch) % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                      i_meta_batch, n_meta_batches, batch_acc)
                    +
                    ' Empiric-Loss: {:.4f}'.format(info['avg_empirical_loss']))

        # for i in range(20):
        #     prior_grad_step(prior_optimizer, prm.meta_batch_size, prm, lr_schedule, prm.prior_lr, i_epoch)
        # end  meta-batches loop

    # end run_epoch()

    # -------------------------------------------------------------------------------------------
    #  Test evaluation function -
    # Evaluate the mean loss on samples from the test sets of the training tasks
    # --------------------------------------------------------------------------------------------
    def run_test():
        test_acc_avg = 0.0
        n_tests = 0
        for i_task in range(n_train_tasks):
            model = posteriors_models[i_task]
            test_loader = data_loaders[i_task]['test']
            if len(test_loader) > 0:
                test_acc, test_loss = run_test_Bayes(model, test_loader,
                                                     loss_criterion, prm)
                n_tests += 1
                test_acc_avg += test_acc

                n_test_samples = len(test_loader.dataset)

                write_to_log(
                    'Train Task {}, Test set: {} -  Average loss: {:.4}, Accuracy: {:.3} (of {} samples)\n'
                    .format(i_task, prm.test_type, test_loss, test_acc,
                            n_test_samples), prm)
            else:
                print('Train Task {}, Test set: {} - No test data'.format(
                    i_task, prm.test_type))

        if n_tests > 0:
            test_acc_avg /= n_tests
        return test_acc_avg

    # -----------------------------------------------------------------------------------------------------------#
    # Main script
    # -----------------------------------------------------------------------------------------------------------#

    # Update Log file

    write_to_log(cmn.get_model_string(prior_model), prm)
    write_to_log('---- Meta-Training set: {0} tasks'.format(len(data_loaders)),
                 prm)

    # -------------------------------------------------------------------------------------------
    #  Run epochs
    # -------------------------------------------------------------------------------------------
    start_time = timeit.default_timer()

    # Training loop:
    for i_epoch in range(prm.n_meta_train_epochs):
        # if (i_epoch+1) % 50 == 0:
        #     prm.lr = prm.lr/2
        # for post_model in posteriors_models:
        #     post_model.load_state_dict(prior_model.state_dict())
        run_train_epoch(i_epoch)
        # for post_model in posteriors_models:
        #     post_model.load_state_dict(prior_model.state_dict())

    # prior_update(prior_optimizer,prm.meta_batch_size,prm)

    stop_time = timeit.default_timer()

    # Test:
    test_acc_avg = run_test()

    # Update Log file:
    cmn.write_final_result(test_acc_avg,
                           stop_time - start_time,
                           prm,
                           result_name=prm.test_type)

    # Return learned prior:
    return prior_model
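
# A minimal usage sketch (hypothetical names): each entry of `data_loaders` is assumed to be
# a dict holding 'train' and 'test' DataLoaders, as accessed above:
#   data_loaders = [{'train': tr, 'test': te} for (tr, te) in per_task_loaders]
#   learned_prior = run_meta_learning(data_loaders, prm)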
Example #6
def run_learning(task_data, prior_model, prm, init_from_prior=True, verbose=1):

    # -------------------------------------------------------------------------------------------
    #  Setting-up
    # -------------------------------------------------------------------------------------------
    # Unpack parameters:
    optim_func, optim_args, lr_schedule =\
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_func(prm)

    # Create posterior model for the new task:
    post_model = get_model(prm)

    if init_from_prior:
        post_model.load_state_dict(prior_model.state_dict())

        # prior_model_dict = prior_model.state_dict()
        # post_model_dict = post_model.state_dict()
        #
        # # filter out unnecessary keys:
        # prior_model_dict = {k: v for k, v in prior_model_dict.items() if '_log_var' in k or '_mu' in k}
        # # overwrite entries in the existing state dict:
        # post_model_dict.update(prior_model_dict)
        #
        # # #  load the new state dict
        # post_model.load_state_dict(post_model_dict)

        # add_noise_to_model(post_model, prm.kappa_factor)

    # The data-sets of the new task:
    train_loader = task_data['train']
    test_loader = task_data['test']
    n_train_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    #  Get optimizer:
    optimizer = optim_func(post_model.parameters(), **optim_args)

    # -------------------------------------------------------------------------------------------
    #  Training epoch  function
    # -------------------------------------------------------------------------------------------

    def run_train_epoch(i_epoch):
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # get batch data:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            avg_empiric_loss = 0
            complexity_term = 0

            for i_MC in range(n_MC):

                # Calculate empirical loss:
                outputs = post_model(inputs)
                avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                    outputs, targets)

                # complexity_curr = get_task_complexity(prm, prior_model, post_model,
                #                                            n_train_samples, avg_empiric_loss_curr)

                avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
                # complexity_term += (1 / n_MC) * complexity_curr

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)
            # end monte-carlo loop

            complexity_term = get_task_complexity(prm, prior_model, post_model,
                                                  n_train_samples,
                                                  avg_empiric_loss)

            # Approximated total objective (for current batch):
            if prm.complexity_type == 'Variational_Bayes':
                # Note: avg_empiric_loss is estimated by an average over the batch samples,
                # but its weight in the objective should reflect the total number of samples in the task.
                total_objective = avg_empiric_loss * n_train_samples + complexity_term
            else:
                total_objective = avg_empiric_loss + complexity_term
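
            # (In equation form, roughly:
            #    Variational_Bayes: objective ~ n_train_samples * E_batch[loss] + complexity
            #    otherwise:         objective ~ E_batch[loss] + complexity)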

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        avg_empiric_loss.item(), complexity_term.item()))
        # end batch loop

    # end run_train_epoch()

    # -----------------------------------------------------------------------------------------------------------#
    # Update Log file
    if verbose == 1:
        write_to_log(
            'Total number of steps: {}'.format(n_batches *
                                               prm.n_meta_test_epochs), prm)

    # -------------------------------------------------------------------------------------------
    #  Run epochs
    # -------------------------------------------------------------------------------------------
    start_time = timeit.default_timer()

    # Training loop:
    for i_epoch in range(prm.n_meta_test_epochs):
        run_train_epoch(i_epoch)

    # Test:
    test_acc, test_loss = run_eval_Bayes(post_model, test_loader, prm)

    stop_time = timeit.default_timer()
    cmn.write_final_result(test_acc,
                           stop_time - start_time,
                           prm,
                           result_name=prm.test_type,
                           verbose=verbose)

    test_err = 1 - test_acc
    return test_err, post_model
Example #7
def run_learning(data_loader, prm, prior_model=None, init_from_prior=True, verbose=1):

    # -------------------------------------------------------------------------------------------
    #  Setting-up
    # -------------------------------------------------------------------------------------------

    # Unpack parameters:
    optim_func, optim_args, lr_schedule = \
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)

    train_loader = data_loader['train']
    test_loader = data_loader['test']
    n_batches = len(train_loader)
    n_train_samples = data_loader['n_train_samples']

    # get model:
    if prior_model and init_from_prior:
        # init from prior model:
        post_model = deepcopy(prior_model)
    else:
        post_model = get_model(prm)

    # post_model.set_eps_std(0.0) # DEBUG: turn off randomness

    #  Get optimizer:
    optimizer = optim_func(post_model.parameters(), **optim_args)

    # -------------------------------------------------------------------------------------------
    #  Training epoch  function
    # -------------------------------------------------------------------------------------------

    def run_train_epoch(i_epoch):

        # # Adjust randomness (eps_std)
        # if hasattr(prm, 'use_randomness_schedeule') and prm.use_randomness_schedeule:
        #     if i_epoch > prm.randomness_full_epoch:
        #         eps_std = 1.0
        #     elif i_epoch > prm.randomness_init_epoch:
        #         eps_std = (i_epoch - prm.randomness_init_epoch) / (prm.randomness_full_epoch - prm.randomness_init_epoch)
        #     else:
        #         eps_std = 0.0  #  turn off randomness
        #     post_model.set_eps_std(eps_std)

        # post_model.set_eps_std(0.00) # debug

        complexity_term = 0

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # Monte-Carlo iterations:
            empirical_loss = 0
            n_MC = prm.n_MC

            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            
            for i_MC in range(n_MC):           

                # calculate objective:
                outputs = post_model(inputs)
                empirical_loss_c = loss_criterion(outputs, targets)
                empirical_loss += (1 / n_MC) * empirical_loss_c

            #  complexity/prior term:
            if prior_model:
                empirical_loss, complexity_term = get_bayes_task_objective(
                    prm, prior_model, post_model, n_train_samples, empirical_loss)
            else:
                complexity_term = 0.0

            # Total objective:
            objective = empirical_loss + complexity_term

            # Take gradient step:
            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            log_interval = 500
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                print(cmn.status_string(i_epoch, prm.num_epochs, batch_idx, n_batches, batch_acc, get_value(objective)) +
                      ' Loss: {:.4}\t Comp.: {:.4}'.format(get_value(empirical_loss), get_value(complexity_term)))
    # -------------------------------------------------------------------------------------------
    #  Main Script
    # -------------------------------------------------------------------------------------------


    #  Update Log file
    update_file = verbose != 0
    cmn.write_to_log(cmn.get_model_string(post_model), prm, update_file=update_file)
    cmn.write_to_log('Total number of steps: {}'.format(n_batches * prm.num_epochs), prm, update_file=update_file)
    cmn.write_to_log('Number of training samples: {}'.format(data_loader['n_train_samples']), prm, update_file=update_file)

    start_time = timeit.default_timer()

    # Run training epochs:
    for i_epoch in range(prm.num_epochs):
        run_train_epoch(i_epoch)

    # Test:
    test_acc, test_loss = run_test_Bayes(post_model, test_loader, loss_criterion, prm)

    stop_time = timeit.default_timer()
    cmn.write_final_result(test_acc, stop_time - start_time, prm, result_name=prm.test_type)

    test_err = 1 - test_acc
    return test_err, post_model
train_dataset, test_dataset, info = load_pretrain_dataset()
train_loader = DataLoader(train_dataset,
                          batch_size=bz,
                          shuffle=True,
                          num_workers=12,
                          pin_memory=True)
print(len(train_dataset), len(test_dataset))
test_loader = DataLoader(test_dataset,
                         batch_size=bz,
                         shuffle=True,
                         num_workers=12,
                         pin_memory=True)
save_dir = "pretrained_cifar100"
os.makedirs(save_dir, exist_ok=True)

net = get_model(prm)

#load_model_state(net, save_dir + "/" + "epoch-2-acc0.277.pth")

#debug()
print(net)
net = net.cuda()

loss_fn = nn.CrossEntropyLoss()

learning_rate = 1e-3
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
for epoch in range(epoch_num):
    cnt = 0
    for imgs, ys in train_loader:
        cnt += 1