def run_meta_iteration(i_iter):
        """Run a single meta-iteration: draw a meta-batch and step theta once.

        A meta-batch of ``meta_batch_size`` tasks is requested from
        ``task_generator``, the objective for that meta-batch is obtained from
        ``meta_step`` and handed to ``grad_step`` together with
        ``meta_optimizer``. Every fifth iteration a status line is printed.

        Args:
            i_iter: index of the current meta-iteration; forwarded to
                ``grad_step`` and used to decide when to print status.
        """
        # Data loaders of the training tasks that make up this meta-batch.
        task_loaders = task_generator.create_meta_batch(
            prm, meta_batch_size, meta_split='meta_train')

        # One training-batch iterator per task of the meta-batch.
        task_iterators = []
        for task_idx in range(meta_batch_size):
            task_iterators.append(iter(task_loaders[task_idx]['train']))

        # Objective based on the tasks in the meta-batch.
        meta_objective, stats = meta_step(prm, model, task_loaders,
                                          task_iterators, loss_criterion)

        # Gradient step with the meta-parameters (theta).
        grad_step(meta_objective, meta_optimizer, lr_schedule, prm.lr, i_iter)

        # Only report on every `report_every`-th iteration.
        report_every = 5
        if i_iter % report_every != 0:
            return

        accuracy = stats['correct_count'] / stats['sample_count']
        # NOTE(review): `.data[0]` assumes an indexable (non 0-dim) tensor;
        # confirm against the PyTorch version this code targets.
        status_line = cmn.status_string(i_iter, n_iterations, 1, 1, accuracy,
                                        meta_objective.data[0])
        print(status_line)
Example #2
0
    def run_train_epoch(i_epoch):
        """Run one training epoch of ``post_model`` on ``train_loader['train']``.

        One ``get_objective`` / ``grad_step`` pair is executed per batch of
        ``train_loader['train']``. The batches themselves are not consumed
        from the ``enumerate`` loop; a separate iterator over the same loader
        is passed to ``get_objective`` and ``grad_step`` instead.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        post_model.train()

        train_data = train_loader['train']
        # Iterator handed to the helpers; independent of the loop below.
        batch_source = iter(train_data)
        report_every = 10

        # The loop variable only drives the number of steps in the epoch.
        for step, _ in enumerate(train_data):

            # Single task (index 0) wrapped in lists, as the helper expects
            # list arguments.
            task_losses, stats = get_objective(
                prior_model, prm, [train_loader], object.feval,
                [batch_source], [post_model], loss_criterion, 1, [0])

            grad_step(task_losses[0], post_model, loss_criterion, optimizer,
                      prm, batch_source, train_data, lr_schedule,
                      prm.optim_args['lr'], i_epoch)

            # Skip reporting on all but every `report_every`-th step.
            if step % report_every:
                continue

            accuracy = stats['correct_count'] / stats['sample_count']
            header = cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                       step, n_batches, accuracy)
            loss_text = ' Empiric-Loss: {:.4f}'.format(
                stats['avg_empirical_loss'])
            print(header + loss_text)
Example #3
0
    def run_train_epoch(i_epoch):
        """Train ``post_model`` for one full pass over ``train_loader``.

        For every batch, ``prm.n_MC`` Monte-Carlo forward passes are averaged
        into an empirical-loss term and a complexity term (both returned by
        ``get_bayes_task_objective``). Their sum is the objective passed to
        ``grad_step``. On logging batches a status line is printed and the
        current values are appended to ``data_objective``, ``data_accuracy``,
        ``data_emp_loss`` and ``data_task_comp``.

        The ``return`` used to sit inside the batch loop, which ended the
        epoch after the very first batch. It now runs after the loop so every
        batch of the loader is trained on.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).

        Returns:
            The objective value of the last processed batch (via ``.item()``),
            or ``None`` when ``train_loader`` yields no batches.
        """
        log_interval = 500

        post_model.train()

        # Stays None if the loader is empty, so the empty case still yields
        # None instead of raising on an unbound name.
        total_objective = None

        for batch_idx, batch_data in enumerate(train_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for _ in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                # The helper returns a (possibly adjusted) empirical loss
                # together with the complexity term for this draw.
                curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                # Running averages over the Monte-Carlo draws.
                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:
            total_objective = task_empirical_loss + task_complexity

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status and record the values for later inspection:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        task_empirical_loss.item(), task_complexity.item()))

                data_objective.append(total_objective.item())
                data_accuracy.append(batch_acc)
                data_emp_loss.append(task_empirical_loss.item())
                data_task_comp.append(task_complexity.item())

        # End of batch loop: report the objective of the final batch.
        if total_objective is None:
            return None
        return total_objective.item()
Example #4
0
    def run_train_epoch(i_epoch, i_step=0):
        """Run one meta-training epoch over all training tasks.

        Tasks are scheduled in a balanced order (each task id appears once per
        shuffled round, for ``n_batches_per_task`` rounds) and processed in
        meta-batches of up to ``prm.meta_batch_size`` tasks. One ``grad_step``
        with ``all_optimizer`` is taken per meta-batch.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
            i_step: running count of meta-batch steps taken before this call.

        Returns:
            ``i_step`` incremented by the number of meta-batches processed.
        """
        # One batch iterator per training task, indexed by task id.
        task_iterators = [
            iter(data_loaders[task_id]['train'])
            for task_id in range(n_train_tasks)
        ]

        # Balanced schedule: every round is a fresh in-place shuffle of all
        # task ids, so each task appears equally often over the epoch.
        schedule = []
        task_ids = list(range(n_train_tasks))
        for _ in range(n_batches_per_task):
            random.shuffle(task_ids)
            schedule.extend(task_ids)

        # Start offsets of consecutive meta-batches inside the schedule.
        mb_size = prm.meta_batch_size
        chunk_starts = range(0, len(schedule), mb_size)
        n_meta_batches = len(chunk_starts)
        report_every = 200

        for i_meta_batch, start in enumerate(chunk_starts):
            # The last slice may hold fewer than `mb_size` tasks, and a task
            # may appear more than once inside one meta-batch.
            batch_task_ids = schedule[start:start + mb_size]

            mb_data_loaders = []
            mb_iterators = []
            mb_posteriors_models = []
            for task_id in batch_task_ids:
                mb_data_loaders.append(data_loaders[task_id])
                mb_iterators.append(task_iterators[task_id])
                mb_posteriors_models.append(posteriors_models[task_id])

            i_step += 1

            # Objective based on the tasks in this meta-batch.
            total_objective, info = get_objective(
                prior_model, prm, mb_data_loaders, mb_iterators,
                mb_posteriors_models, loss_criterion, n_train_tasks)

            # Gradient step with the shared prior and all tasks' posteriors.
            grad_step(total_objective, all_optimizer, lr_schedule, prm.lr,
                      i_epoch)

            if i_meta_batch % report_every == 0:
                accuracy = info['correct_count'] / info['sample_count']
                header = cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                           i_meta_batch, n_meta_batches,
                                           accuracy, total_objective.item())
                details = (' Empiric-Loss: {:.4}\t Task-Comp. {:.4}\t'
                           ' Meta-Comp.: {:.4}\t').format(
                               info['avg_empirical_loss'],
                               info['avg_intra_task_comp'],
                               info['meta_comp'])
                print(header + details)

        return i_step
Example #5
0
    def run_train_epoch(i_epoch):
        """Run one training epoch of ``post_model`` over ``train_loader``.

        For each batch the loss is averaged over ``prm.n_MC`` Monte-Carlo
        forward passes. When ``prior_model`` is truthy, the loss is passed
        through ``get_bayes_task_objective`` to obtain the complexity term;
        otherwise the complexity term is ``0.0``. The sum is handed to
        ``grad_step``.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        post_model.train()
        report_every = 500

        for batch_idx, batch_data in enumerate(train_loader):

            # Average the loss over the Monte-Carlo draws.
            n_MC = prm.n_MC
            empirical_loss = 0
            for _ in range(n_MC):
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)
                outputs = post_model(inputs)
                draw_loss = loss_criterion(outputs, targets)
                empirical_loss += (1 / n_MC) * draw_loss

            # Complexity / prior term (only when a prior is available).
            if not prior_model:
                complexity_term = 0.0
            else:
                empirical_loss, complexity_term = get_bayes_task_objective(
                    prm, prior_model, post_model, n_train_samples,
                    empirical_loss)

            objective = empirical_loss + complexity_term

            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            if batch_idx % report_every != 0:
                continue

            # Accuracy is measured on the outputs of the last Monte-Carlo draw.
            accuracy = correct_rate(outputs, targets)
            # NOTE(review): `.data[0]` assumes an indexable (non 0-dim)
            # tensor; confirm against the targeted PyTorch version.
            header = cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                       n_batches, accuracy, objective.data[0])
            details = ' Loss: {:.4}\t Comp.: {:.4}'.format(
                get_value(empirical_loss), get_value(complexity_term))
            print(header + details)
    def run_train_epoch(i_epoch):
        """Run one meta-training epoch, stepping theta once per meta-batch.

        Task ids are scheduled in ``n_batches_per_task`` shuffled rounds and
        consumed in meta-batches of up to ``prm.meta_batch_size`` tasks. For
        each meta-batch the objective from ``meta_step`` is passed to
        ``grad_step`` with ``meta_optimizer``.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        # One batch iterator per task, indexed by task id.
        task_iterators = [
            iter(train_data_loaders[task_id]['train'])
            for task_id in range(n_tasks)
        ]

        # Task schedule: concatenation of repeatedly re-shuffled id lists.
        schedule = []
        task_ids = list(range(n_tasks))
        for _ in range(n_batches_per_task):
            random.shuffle(task_ids)
            schedule.extend(task_ids)

        mb_size = prm.meta_batch_size
        chunk_starts = range(0, len(schedule), mb_size)
        n_meta_batches = len(chunk_starts)
        report_every = 200

        # ----------- meta-batches loop (batches of tasks) -----------
        for i_meta_batch, start in enumerate(chunk_starts):
            # The final slice may be shorter than `mb_size`; repeated task
            # ids inside one meta-batch are acceptable.
            batch_task_ids = schedule[start:start + mb_size]

            mb_data_loaders = []
            mb_iterators = []
            for task_id in batch_task_ids:
                mb_data_loaders.append(train_data_loaders[task_id])
                mb_iterators.append(task_iterators[task_id])

            # Objective based on the tasks in this meta-batch.
            total_objective, info = meta_step(prm, model, mb_data_loaders,
                                              mb_iterators, loss_criterion)

            # Gradient step with the meta-parameters (theta).
            grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr,
                      i_epoch)

            if i_meta_batch % report_every != 0:
                continue

            accuracy = info['correct_count'] / info['sample_count']
            print(
                cmn.status_string(i_epoch, num_epochs, i_meta_batch,
                                  n_meta_batches, accuracy,
                                  total_objective.item()))
Example #7
0
    def run_train_epoch(i_epoch, log_mat):
        """Run one training epoch of ``post_model`` over ``train_loader``.

        Per batch, the criterion output is scaled by ``1 / batch_size`` and
        averaged over ``prm.n_MC`` Monte-Carlo forward passes. A complexity
        term from ``get_task_complexity`` is added when ``prior_model`` is
        truthy (a zero tensor otherwise). After the batch loop, results are
        optionally saved via ``save_result_for_figure``.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``,
                the status line and the figure logging).
            log_mat: passed through unchanged to ``save_result_for_figure``.
        """
        post_model.train()
        report_every = 1000

        for batch_idx, batch_data in enumerate(train_loader):

            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            # Accumulator for the Monte-Carlo average of the per-sample loss.
            avg_empiric_loss = torch.zeros(1, device=prm.device)
            n_MC = prm.n_MC

            for _ in range(n_MC):
                outputs = post_model(inputs)
                draw_loss = (1 / batch_size) * loss_criterion(outputs, targets)
                avg_empiric_loss += (1 / n_MC) * draw_loss

            # Complexity / prior term.
            complexity_term = (
                get_task_complexity(prm, prior_model, post_model,
                                    n_train_samples, avg_empiric_loss)
                if prior_model else torch.zeros(1, device=prm.device))

            objective = avg_empiric_loss + complexity_term

            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            if batch_idx % report_every != 0:
                continue

            # Accuracy is measured on the outputs of the last Monte-Carlo draw.
            accuracy = correct_rate(outputs, targets)
            header = cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                       n_batches, accuracy, objective.item())
            details = ' Loss: {:.4}\t Comp.: {:.4}'.format(
                avg_empiric_loss.item(), complexity_term.item())
            print(header + details)

        # Save results for the epochs-figure, if enabled and due this epoch.
        if not figure_flag:
            return
        if i_epoch % prm.log_figure['interval_epochs'] == 0:
            save_result_for_figure(post_model, prior_model, data_loader, prm,
                                   log_mat, i_epoch)
Example #8
0
    def run_train_epoch(i_epoch):
        """Run one standard training epoch of ``model`` over ``train_loader``.

        For each batch: forward pass, loss from ``loss_criterion``, one
        ``grad_step``; a status line is printed every 500 batches.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        report_every = 500

        model.train()
        for batch_idx, batch_data in enumerate(train_loader):
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Forward pass and loss.
            predictions = model(inputs)
            batch_loss = loss_criterion(predictions, targets)

            grad_step(batch_loss, optimizer, lr_schedule, prm.lr, i_epoch)

            if batch_idx % report_every != 0:
                continue

            accuracy = correct_rate(predictions, targets)
            status_line = cmn.status_string(i_epoch, prm.num_epochs,
                                            batch_idx, n_batches, accuracy,
                                            get_value(batch_loss))
            print(status_line)
Example #9
0
def run_meta_iteration(i_iter, prior_model, task_generator, prm):
    """Run one meta-iteration with a freshly drawn meta-batch of tasks.

    A posterior model is created per task and initialised from the prior's
    state dict. A single optimizer over all posterior parameters plus the
    prior parameters is built, and ``prm.n_inner_steps`` gradient steps are
    taken on the objective returned by ``get_objective``.

    Args:
        i_iter: index of the current meta-iteration.
        prior_model: model whose parameters are optimised jointly with the
            per-task posteriors.
        task_generator: object providing ``create_meta_batch``.
        prm: parameter container (optimizer settings, sizes, learning rate).

    Returns:
        Tuple ``(prior_model, posteriors_models, test_acc_avg)``.
        ``test_acc_avg`` is ``0.0`` unless this iteration ran the
        meta-batch test (every 10th iteration, excluding iteration 0).
    """
    optim_func = prm.optim_func
    optim_args = prm.optim_args
    lr_schedule = prm.lr_schedule

    loss_criterion = get_loss_criterion(prm.loss_type)
    meta_batch_size = prm.meta_batch_size
    n_inner_steps = prm.n_inner_steps
    n_meta_iterations = prm.n_meta_train_epochs

    # Data sets of the training tasks for this meta-batch.
    mb_data_loaders = task_generator.create_meta_batch(
        prm, meta_batch_size, meta_split='meta_train')

    # One training-batch iterator per task.
    mb_iterators = []
    for task_idx in range(meta_batch_size):
        mb_iterators.append(iter(mb_data_loaders[task_idx]['train']))

    # The posterior models adjust to the new tasks of each meta-batch.
    # Create them first, then initialise every one from the prior.
    posteriors_models = [get_model(prm) for _ in range(meta_batch_size)]
    for task_model in posteriors_models:
        task_model.load_state_dict(prior_model.state_dict())

    # Flatten the parameters of all posteriors, then append the prior's.
    all_params = [
        param
        for task_model in posteriors_models
        for param in task_model.parameters()
    ]
    all_params = all_params + list(prior_model.parameters())

    # One optimizer for everything (posteriors + prior).
    all_optimizer = optim_func(all_params, **optim_args)

    report_every = 20
    for i_inner_step in range(n_inner_steps):
        # Objective based on the tasks in the meta-batch.
        total_objective, info = get_objective(
            prior_model, prm, mb_data_loaders, mb_iterators,
            posteriors_models, loss_criterion, prm.n_train_tasks)

        grad_step(total_objective, all_optimizer, lr_schedule, prm.lr, i_iter)

        if i_inner_step % report_every != 0:
            continue

        accuracy = info['correct_count'] / info['sample_count']
        # NOTE(review): `.data[0]` assumes an indexable (non 0-dim) tensor;
        # confirm against the targeted PyTorch version.
        header = cmn.status_string(i_iter, n_meta_iterations, i_inner_step,
                                   n_inner_steps, accuracy,
                                   total_objective.data[0])
        details = ' Empiric-Loss: {:.4}\t Task-Comp. {:.4}\t'.format(
            info['avg_empirical_loss'], info['avg_intra_task_comp'])
        print(header + details)

    # Periodic evaluation on the test sets of the meta-batch.
    test_acc_avg = 0.0
    eval_every = 10
    if i_iter % eval_every == 0 and i_iter > 0:
        test_acc_avg = run_test(mb_data_loaders, posteriors_models,
                                loss_criterion, prm)
        print('Meta-iter: {} \t Meta-Batch Test Acc: {:1.3}\t'.format(
            i_iter, test_acc_avg))

    return prior_model, posteriors_models, test_acc_avg
Example #10
0
    def run_train_epoch(i_epoch):
        """Train ``post_model`` for one epoch and return averaged statistics.

        For every batch, ``prm.n_MC`` Monte-Carlo forward passes are averaged
        into an empirical-loss term and a complexity term (both returned by
        ``get_bayes_task_objective``). The objective passed to ``grad_step``
        is ``empirical_loss + prm.task_complex_w * complexity``.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).

        Returns:
            dict with keys ``"task_comp"`` and ``"total_loss"``, each the mean
            over the batches of this epoch. If ``train_loader`` yields no
            batches both values stay ``0.0`` (previously this case raised
            ``ZeroDivisionError``).
        """
        log_interval = 500

        post_model.train()

        train_info = {"task_comp": 0.0, "total_loss": 0.0}

        n_seen_batches = 0
        for batch_idx, batch_data in enumerate(train_loader):
            n_seen_batches += 1

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for _ in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                # hyper_kl = 0 when testing
                curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                # Running averages over the Monte-Carlo draws.
                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:
            total_objective = task_empirical_loss + prm.task_complex_w * task_complexity

            # NOTE(review): `.data[0]` assumes indexable (non 0-dim) tensors;
            # confirm against the targeted PyTorch version.
            train_info["task_comp"] += task_complexity.data[0]
            train_info["total_loss"] += total_objective.data[0]

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status (task_info holds the values of the last MC draw):
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                write_to_log(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.data[0]) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'
                    .format(task_empirical_loss.data[0],
                            task_complexity.data[0], task_info["w_kld"],
                            task_info["b_kld"]), prm)

        # Average over the batches; skip when nothing was processed so an
        # empty loader does not trigger a division by zero.
        if n_seen_batches > 0:
            train_info["task_comp"] /= n_seen_batches
            train_info["total_loss"] /= n_seen_batches
        return train_info
def run_meta_iteration(i_iter, prior_model, task_generator, prm):
    """Run one meta-iteration with separate prior / per-task optimizers.

    A posterior model per task is created and initialised from the prior's
    state dict. The prior gets its own optimizer and each posterior gets one
    too. In every inner step the per-task losses from ``get_objective`` are
    used for one ``grad_step`` per task, ``prior_get_grad`` is called after
    each task step, and ``prior_grad_step`` is called once at the end.

    Args:
        i_iter: index of the current meta-iteration.
        prior_model: shared prior model.
        task_generator: object providing ``create_meta_batch``.
        prm: parameter container (optimizer settings, sizes, learning rates).

    Returns:
        Tuple ``(prior_model, posteriors_models, test_acc_avg)``.
    """
    optim_func = prm.optim_func
    optim_args = prm.optim_args
    lr_schedule = prm.lr_schedule

    loss_criterion = get_loss_criterion(prm.loss_type)
    meta_batch_size = prm.meta_batch_size
    n_inner_steps = prm.n_inner_steps
    n_meta_iterations = prm.n_meta_train_epochs

    def trainable_params(net):
        # Lazy iterable over the parameters of `net` that require gradients.
        return filter(lambda p: p.requires_grad, net.parameters())

    # Data sets of the training tasks for this meta-batch.
    mb_data_loaders = task_generator.create_meta_batch(
        prm, meta_batch_size, meta_split='meta_train')

    # One training-batch iterator per task.
    mb_iterators = []
    for task_idx in range(meta_batch_size):
        mb_iterators.append(iter(mb_data_loaders[task_idx]['train']))

    # The posterior models adjust to the new tasks of each meta-batch.
    # Create them first, then initialise every one from the prior.
    posteriors_models = [get_model(prm) for _ in range(meta_batch_size)]
    for task_model in posteriors_models:
        task_model.load_state_dict(prior_model.state_dict())

    # Optimizer for the prior; optim_args is passed positionally here.
    prior_optimizer = optim_func(trainable_params(prior_model), optim_args)

    # If the first prior parameter has no gradient yet, let grad_init set
    # things up using the first task's training loader.
    first_prior_param = prior_optimizer.param_groups[0]['params'][0]
    if first_prior_param.grad is None:
        first_train_loader = mb_data_loaders[0]['train']
        object.grad_init(prm, prior_model, loss_criterion,
                         iter(first_train_loader), first_train_loader,
                         prior_optimizer)

    # One optimizer per task posterior.
    all_posterior_optimizers = []
    for task_idx in range(meta_batch_size):
        task_params = trainable_params(posteriors_models[task_idx])
        all_posterior_optimizers.append(optim_func(task_params, optim_args))

    for i_inner_step in range(n_inner_steps):
        # Per-task losses for the tasks in the meta-batch.
        task_loss_list, info = get_objective(
            prior_model, prm, mb_data_loaders, object.feval, mb_iterators,
            posteriors_models, loss_criterion, meta_batch_size,
            range(meta_batch_size))

        prior_optimizer.zero_grad()
        for i_task in range(meta_batch_size):
            task_loss = task_loss_list[i_task]
            # Entries that are plain ints are skipped (no step for that task).
            if not isinstance(task_loss, int):
                task_optimizer = all_posterior_optimizers[i_task]
                grad_step(task_loss, posteriors_models[i_task],
                          loss_criterion, task_optimizer, prm,
                          mb_iterators[i_task],
                          mb_data_loaders[i_task]['train'], lr_schedule,
                          prm.lr, i_iter)
                prior_get_grad(prior_optimizer, task_optimizer)

        prior_grad_step(prior_optimizer, prm.meta_batch_size, prm,
                        prm.prior_lr_schedule, prm.prior_lr, i_iter)

        # Print status:
        report_every = 1
        if i_inner_step % report_every == 0:
            accuracy = info['correct_count'] / info['sample_count']
            header = cmn.status_string(i_iter, n_meta_iterations,
                                       i_inner_step, n_inner_steps, accuracy)
            loss_text = ' Empiric-Loss: {:.4f}'.format(
                info['avg_empirical_loss'])
            print(header + loss_text)

    # Evaluation on the test sets of the meta-batch.
    test_acc_avg = 0.0
    eval_every = 1
    if i_iter % eval_every == 0:
        test_acc_avg = run_test(mb_data_loaders, posteriors_models,
                                loss_criterion, prm)
        print('Meta-iter: {} \t Meta-Batch Test Acc: {:1.3}\t'.format(
            i_iter, test_acc_avg))

    return prior_model, posteriors_models, test_acc_avg
Example #12
0
    def run_train_epoch(i_epoch):
        """Run one meta-training epoch with per-task posterior optimizers.

        Task ids are scheduled in ``n_batches_per_task`` shuffled rounds and
        consumed in meta-batches of up to ``prm.meta_batch_size`` tasks. For
        each meta-batch, per-task losses are obtained from ``get_objective``,
        every task with a non-int loss entry gets one ``grad_step`` followed
        by ``prior_get_grad``, then ``prior_updates`` is called and all
        posteriors are reset to the prior's state dict.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        # One batch iterator per training task, indexed by task id.
        task_iterators = [
            iter(data_loaders[task_id]['train'])
            for task_id in range(n_train_tasks)
        ]

        # Balanced schedule: every round is a fresh in-place shuffle of all
        # task ids, so each task appears equally often over the epoch.
        schedule = []
        task_ids = list(range(n_train_tasks))
        for _ in range(n_batches_per_task):
            random.shuffle(task_ids)
            schedule.extend(task_ids)

        # ----------- meta-batches loop (batches of tasks) -----------
        mb_size = prm.meta_batch_size
        chunk_starts = range(0, len(schedule), mb_size)
        n_meta_batches = len(chunk_starts)
        report_every = 10

        for i_meta_batch, start in enumerate(chunk_starts):
            # The last slice may hold fewer than `mb_size` tasks, and a task
            # may appear more than once inside one meta-batch.
            task_ids_in_meta_batch = schedule[start:start + mb_size]

            mb_data_loaders = []
            mb_iterators = []
            mb_posteriors_models = []
            for task_id in task_ids_in_meta_batch:
                mb_data_loaders.append(data_loaders[task_id])
                mb_iterators.append(task_iterators[task_id])
                mb_posteriors_models.append(posteriors_models[task_id])

            # Per-task losses for the tasks in this meta-batch.
            task_loss_list, info = get_objective(
                prior_model, prm, mb_data_loaders, object.feval,
                mb_iterators, mb_posteriors_models, loss_criterion,
                n_train_tasks, task_ids_in_meta_batch)

            # Step each task's posterior, collecting gradients for the prior.
            prior_optimizer.zero_grad()
            for i_task in range(n_train_tasks):
                task_loss = task_loss_list[i_task]
                # Entries that are plain ints are skipped (no step taken).
                if not isinstance(task_loss, int):
                    task_optimizer = all_posterior_optimizers[i_task]
                    grad_step(task_loss, posteriors_models[i_task],
                              loss_criterion, task_optimizer, prm,
                              task_iterators[i_task],
                              data_loaders[i_task]['train'], lr_schedule,
                              prm.lr, i_epoch)
                    prior_get_grad(prior_optimizer, task_optimizer)

            # Update the prior, then restart every posterior from it.
            prior_updates(prior_optimizer, n_train_tasks, prm)
            prior_state_consumers = posteriors_models
            for task_model in prior_state_consumers:
                task_model.load_state_dict(prior_model.state_dict())

            if i_meta_batch % report_every != 0:
                continue

            accuracy = info['correct_count'] / info['sample_count']
            header = cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                       i_meta_batch, n_meta_batches, accuracy)
            loss_text = ' Empiric-Loss: {:.4f}'.format(
                info['avg_empirical_loss'])
            print(header + loss_text)
Example #13
0
    def run_train_epoch(i_epoch):
        """Run one training epoch of ``post_model`` over ``train_loader``.

        Per batch, the criterion output is scaled by ``1 / batch_size`` and
        averaged over ``prm.n_MC`` Monte-Carlo forward passes; the complexity
        term is computed once per batch from that average via
        ``get_task_complexity``. For ``prm.complexity_type ==
        'Variational_Bayes'`` the averaged loss is multiplied by
        ``n_train_samples`` before the complexity term is added.

        Args:
            i_epoch: index of the current epoch (forwarded to ``grad_step``
                and the status line).
        """
        report_every = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            n_correct = 0
            n_samples = 0

            # Monte-Carlo average of the per-sample empirical loss.
            n_MC = prm.n_MC
            avg_empiric_loss = 0
            for _ in range(n_MC):
                outputs = post_model(inputs)
                draw_loss = (1 / batch_size) * loss_criterion(outputs, targets)
                avg_empiric_loss += (1 / n_MC) * draw_loss

                n_correct += count_correct(outputs, targets)
                n_samples += inputs.size(0)

            complexity_term = get_task_complexity(prm, prior_model, post_model,
                                                  n_train_samples,
                                                  avg_empiric_loss)

            # Approximated total objective for the current batch. In the
            # variational case the batch-average loss is weighted by the
            # total number of samples in the task.
            if prm.complexity_type != 'Variational_Bayes':
                total_objective = avg_empiric_loss + complexity_term
            else:
                total_objective = (avg_empiric_loss * (n_train_samples)
                                   + complexity_term)

            # Gradient step with the posterior.
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            if batch_idx % report_every != 0:
                continue

            accuracy = n_correct / n_samples
            header = cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                       batch_idx, n_batches, accuracy,
                                       total_objective.item())
            details = ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                avg_empiric_loss.item(), complexity_term.item())
            print(header + details)