    def run_meta_iteration(i_iter):
        # In each meta-iteration we draw a meta-batch of several tasks
        # Then we take a grad step with theta.

        # Generate the data sets of the training-tasks for meta-batch:
        mb_data_loaders = task_generator.create_meta_batch(
            prm, meta_batch_size, meta_split='meta_train')

        # For each task, prepare an iterator to generate training batches:
        mb_iterators = [
            iter(mb_data_loaders[ii]['train']) for ii in range(meta_batch_size)
        ]

        # Get objective based on tasks in meta-batch:
        total_objective, info = meta_step(prm, model, mb_data_loaders,
                                          mb_iterators, loss_criterion)

        # Take gradient step with the meta-parameters (theta) based on validation data:
        grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr, i_iter)

        # Print status:
        log_interval = 5
        if (i_iter) % log_interval == 0:
            batch_acc = info['correct_count'] / info['sample_count']
            print(
                cmn.status_string(i_iter, n_iterations, 1, 1, batch_acc,
                                  total_objective.data[0]))
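
All of the examples on this page delegate the actual parameter update to a small `grad_step` helper. The following is a minimal sketch of what such a helper typically does; the signature and the schedule format (a dict with 'decay_factor' and 'decay_epochs' keys) are assumptions, and the original helper may differ.

def grad_step_sketch(objective, optimizer, lr_schedule=None, initial_lr=None, i_epoch=None):
    # Optionally decay the learning rate according to the (assumed) schedule format.
    if lr_schedule is not None and initial_lr is not None and i_epoch is not None:
        lr = initial_lr
        for epoch_threshold in lr_schedule.get('decay_epochs', []):
            if i_epoch >= epoch_threshold:
                lr *= lr_schedule.get('decay_factor', 1.0)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # Standard gradient step: clear old gradients, backpropagate the objective, update.
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()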
Example #2
    def run_train_epoch(i_epoch):
        # log_interval = 500

        post_model.train()

        train_iterators = iter(train_loader['train'])

        for batch_idx, batch_data in enumerate(train_loader['train']):

            task_loss, info = get_objective(prior_model, prm, [train_loader],
                                            object.feval, [train_iterators],
                                            [post_model], loss_criterion, 1,
                                            [0])

            grad_step(task_loss[0], post_model, loss_criterion, optimizer, prm,
                      train_iterators, train_loader['train'], lr_schedule,
                      prm.optim_args['lr'], i_epoch)

            # for log_var in post_model.parameters():
            #     if log_var.requires_grad is False:
            #         log_var.data = log_var.data - (i_epoch + 1) * math.log(1 + prm.gamma1)

            # Print status:
            log_interval = 10
            if (batch_idx) % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                      batch_idx, n_batches, batch_acc) +
                    ' Empiric-Loss: {:.4f}'.format(info['avg_empirical_loss']))
    def run_train_epoch(i_epoch, i_step=0):

        # For each task, prepare an iterator to generate training batches:
        train_iterators = [
            iter(data_loaders[ii]['data_prior']) for ii in range(n_train_tasks)
        ]

        # The task order to take batches from:
        # The meta-batch will be balanced - i.e., each task will appear roughly the same number of times
        # note: if some tasks have less data than others, their samples may be drawn more than once in an epoch
        task_order = []
        task_ids_list = list(range(n_train_tasks))
        # create a list of n_batches_per_task * n_train_tasks task ids
        for i_batch in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list
        # Note: this method ensures each training sample in each task is drawn in each epoch.
        # If all the tasks have the same number of samples, then each sample is drawn exactly once in an epoch.

        # random.shuffle(task_ids_list) # -- ############ -- TEMP

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        # each meta-batch includes several tasks
        # we take a grad step with theta after each meta-batch
        # - at most prm.meta_batch_size tasks in each meta-batch
        # - in total, about len(task_order) / prm.meta_batch_size meta-batches
        meta_batch_starts = list(range(0, len(task_order),
                                       prm.meta_batch_size))
        # theta is updated len(meta_batch_starts) times in total
        n_meta_batches = len(meta_batch_starts)

        for i_meta_batch in range(n_meta_batches):

            # select at most prm.meta_batch_size tasks in each meta-batch
            meta_batch_start = meta_batch_starts[i_meta_batch]
            task_ids_in_meta_batch = task_order[meta_batch_start:(
                meta_batch_start + prm.meta_batch_size)]
            # meta-batch size may be less than prm.meta_batch_size at the last one
            # note: it is OK if some tasks appear several times in the meta-batch

            mb_data_loaders = [
                data_loaders[task_id] for task_id in task_ids_in_meta_batch
            ]
            mb_iterators = [
                train_iterators[task_id] for task_id in task_ids_in_meta_batch
            ]

            i_step += 1

            # Get objective based on tasks in meta-batch:
            empirical_error = get_risk(prior_model, prm, mb_data_loaders,
                                       mb_iterators, loss_criterion,
                                       n_train_tasks)

            grad_step(empirical_error, all_optimizer, lr_schedule, prm.lr,
                      i_epoch)
            log_interval = 20
            if i_meta_batch % log_interval == 0:
                print('number meta batch:{} \t avg_empiric_loss:{:.3f}'.format(
                    i_meta_batch, empirical_error))
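
The balanced task-ordering scheme used above can be illustrated with a small standalone snippet (the numbers below are hypothetical; the real code takes n_train_tasks and prm.meta_batch_size from the surrounding scope):

import random

n_train_tasks = 3        # hypothetical
n_batches_per_task = 2   # hypothetical
meta_batch_size = 2      # hypothetical

task_order = []
task_ids_list = list(range(n_train_tasks))
for i_batch in range(n_batches_per_task):
    random.shuffle(task_ids_list)
    task_order += task_ids_list
# Every task id appears exactly n_batches_per_task times, e.g. [1, 0, 2, 2, 1, 0].

meta_batch_starts = list(range(0, len(task_order), meta_batch_size))
meta_batches = [task_order[s:s + meta_batch_size] for s in meta_batch_starts]
print(meta_batches)  # e.g. [[1, 0], [2, 2], [1, 0]]
# A task may repeat within a meta-batch, and the last meta-batch may hold
# fewer than meta_batch_size tasks.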
Example #4
    def run_train_epoch(i_epoch):
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:

            total_objective = task_empirical_loss + task_complexity

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        task_empirical_loss.item(), task_complexity.item()))

                data_objective.append(total_objective.item())
                data_accuracy.append(batch_acc)
                data_emp_loss.append(task_empirical_loss.item())
                data_task_comp.append(task_complexity.item())

            return total_objective.item()
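
The `get_bayes_task_objective` call above splits the objective into an empirical-loss term and a complexity term. One common choice for such a complexity term is a McAllester-style PAC-Bayes bound computed from the KL divergence between the posterior and the prior; the sketch below assumes that form and a precomputed scalar KL, while the original function may use a different bound and compute the KL internally.

import math
import torch

def pac_bayes_complexity_sketch(kl_post_prior, n_samples, delta=0.1):
    # McAllester-style bound: sqrt((KL + log(2*sqrt(n)/delta)) / (2*(n - 1))).
    # kl_post_prior: scalar tensor, KL(posterior || prior) over the model weights.
    log_term = math.log(2.0 * math.sqrt(n_samples) / delta)
    return torch.sqrt((kl_post_prior + log_term) / (2.0 * (n_samples - 1)))

# Usage with hypothetical numbers:
kl = torch.tensor(15.0)
complexity = pac_bayes_complexity_sketch(kl, n_samples=60000)
# A per-task objective would then be: task_empirical_loss + complexity.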
Example #5
    def run_train_epoch(i_epoch, i_step = 0):

        # For each task, prepare an iterator to generate training batches:
        train_iterators = [iter(data_loaders[ii]['train']) for ii in range(n_train_tasks)]

        # The task order to take batches from:
        # The meta-batch will be balanced - i.e., each task will appear roughly the same number of times
        # note: if some tasks have less data than others, their samples may be drawn more than once in an epoch
        task_order = []
        task_ids_list = list(range(n_train_tasks))
        for i_batch in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list
        # Note: this method ensures each training sample in each task is drawn in each epoch.
        # If all the tasks have the same number of samples, then each sample is drawn exactly once in an epoch.

        # random.shuffle(task_ids_list) # --############ TEMP

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        # each meta-batch includes several tasks
        # we take a grad step with theta after each meta-batch
        meta_batch_starts = list(range(0, len(task_order), prm.meta_batch_size))
        n_meta_batches = len(meta_batch_starts)

        for i_meta_batch in range(n_meta_batches):

            meta_batch_start = meta_batch_starts[i_meta_batch]
            task_ids_in_meta_batch = task_order[meta_batch_start: (meta_batch_start + prm.meta_batch_size)]
            # meta-batch size may be less than  prm.meta_batch_size at the last one
            # note: it is OK if some tasks appear several times in the meta-batch

            mb_data_loaders = [data_loaders[task_id] for task_id in task_ids_in_meta_batch]
            mb_iterators = [train_iterators[task_id] for task_id in task_ids_in_meta_batch]
            mb_posteriors_models = [posteriors_models[task_id] for task_id in task_ids_in_meta_batch]

            # prior_weight_steps = 10000
            # # prior_weight = 1 - math.exp(-i_step/prior_weight_steps)
            # prior_weight = min(i_step / prior_weight_steps, 1.0)
            i_step += 1

            # Get objective based on tasks in meta-batch:
            total_objective, info = get_objective(prior_model, prm, mb_data_loaders,
                                                  mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks)

            # Take gradient step with the shared prior and all tasks' posteriors:
            grad_step(total_objective, all_optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            log_interval = 200
            if i_meta_batch % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(cmn.status_string(i_epoch,  prm.n_meta_train_epochs, i_meta_batch,
                                        n_meta_batches, batch_acc, total_objective.item()) +
                      ' Empiric-Loss: {:.4}\t Task-Comp. {:.4}\t Meta-Comp.: {:.4}\t'.
                      format(info['avg_empirical_loss'], info['avg_intra_task_comp'], info['meta_comp']))
        # end  meta-batches loop
        return i_step
Example #6
    def run_train_epoch(i_epoch):

        # # Adjust randomness (eps_std)
        # if hasattr(prm, 'use_randomness_schedeule') and prm.use_randomness_schedeule:
        #     if i_epoch > prm.randomness_full_epoch:
        #         eps_std = 1.0
        #     elif i_epoch > prm.randomness_init_epoch:
        #         eps_std = (i_epoch - prm.randomness_init_epoch) / (prm.randomness_full_epoch - prm.randomness_init_epoch)
        #     else:
        #         eps_std = 0.0  #  turn off randomness
        #     post_model.set_eps_std(eps_std)

        # post_model.set_eps_std(0.00) # debug

        complexity_term = 0

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # Monte-Carlo iterations:
            empirical_loss = 0
            n_MC = prm.n_MC
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # calculate objective:
                outputs = post_model(inputs)
                empirical_loss_c = loss_criterion(outputs, targets)
                empirical_loss += (1 / n_MC) * empirical_loss_c

            #  complexity/prior term:
            if prior_model:
                empirical_loss, complexity_term = get_bayes_task_objective(
                    prm, prior_model, post_model, n_train_samples,
                    empirical_loss)
            else:
                complexity_term = 0.0

            # Total objective:
            objective = empirical_loss + complexity_term

            # Take gradient step:
            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            log_interval = 500
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc, objective.data[0])
                    + ' Loss: {:.4}\t Comp.: {:.4}'.format(
                        get_value(empirical_loss), get_value(complexity_term)))
    def run_train_epoch(i_epoch):

        # For each task, prepare an iterator to generate training batches:
        train_iterators = [
            iter(train_data_loaders[ii]['train']) for ii in range(n_tasks)
        ]

        # The task order to take batches from:
        task_order = []
        task_ids_list = list(range(n_tasks))
        for i_batch in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list

        # each meta-batch includes several tasks
        # we take a grad step with theta after each meta-batch
        meta_batch_starts = list(range(0, len(task_order),
                                       prm.meta_batch_size))
        n_meta_batches = len(meta_batch_starts)

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        for i_meta_batch in range(n_meta_batches):

            meta_batch_start = meta_batch_starts[i_meta_batch]
            task_ids_in_meta_batch = task_order[meta_batch_start:(
                meta_batch_start + prm.meta_batch_size)]
            n_tasks_in_batch = len(
                task_ids_in_meta_batch
            )  # it may be less than  prm.meta_batch_size at the last one
            # note: it is OK if some tasks appear several times in the meta-batch

            mb_data_loaders = [
                train_data_loaders[task_id]
                for task_id in task_ids_in_meta_batch
            ]
            mb_iterators = [
                train_iterators[task_id] for task_id in task_ids_in_meta_batch
            ]

            # Get objective based on tasks in meta-batch:
            total_objective, info = meta_step(prm, model, mb_data_loaders,
                                              mb_iterators, loss_criterion)

            # Take gradient step with the meta-parameters (theta) based on validation data:
            grad_step(total_objective, meta_optimizer, lr_schedule, prm.lr,
                      i_epoch)

            # Print status:
            log_interval = 200
            if i_meta_batch % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, num_epochs, i_meta_batch,
                                      n_meta_batches, batch_acc,
                                      total_objective.item()))
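
The `meta_step` helper used in this example is not shown on the page. As an illustration only, the sketch below computes a first-order MAML-style meta-objective for a single `torch.nn.Linear` model; the original `meta_step` (its adaptation rule, model class, and return values) may differ.

import torch
import torch.nn.functional as F

def meta_step_sketch(model, task_batches, inner_lr=0.1, n_inner_steps=1):
    # model: a torch.nn.Linear shared across tasks (the meta-parameters theta).
    # task_batches: list of ((x_tr, y_tr), (x_val, y_val)) tuples, one per task.
    meta_objective = 0.0
    for (x_tr, y_tr), (x_val, y_val) in task_batches:
        # Clone theta for task-specific (fast) adaptation.
        fast_weights = [p.clone() for p in model.parameters()]
        for _ in range(n_inner_steps):
            preds = F.linear(x_tr, fast_weights[0], fast_weights[1])
            inner_loss = F.mse_loss(preds, y_tr)
            # First-order update: the inner gradients are not differentiated through.
            grads = torch.autograd.grad(inner_loss, fast_weights)
            fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
        # The meta-objective is the validation loss of the adapted parameters;
        # backpropagating it reaches theta through the cloned weights.
        val_preds = F.linear(x_val, fast_weights[0], fast_weights[1])
        meta_objective = meta_objective + F.mse_loss(val_preds, y_val)
    return meta_objective / len(task_batches)

# Usage with hypothetical toy data:
# model = torch.nn.Linear(4, 1)
# x, y = torch.randn(8, 4), torch.randn(8, 1)
# meta_step_sketch(model, [((x, y), (x, y))]).backward()  # grads land on theta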
Example #8
    def run_train_epoch(i_epoch, log_mat):

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # get batch data:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            batch_size = inputs.shape[0]

            # Monte-Carlo iterations:
            avg_empiric_loss = torch.zeros(1, device=prm.device)
            n_MC = prm.n_MC

            for i_MC in range(n_MC):

                # calculate objective:
                outputs = post_model(inputs)
                avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                    outputs, targets)
                avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr

            # complexity/prior term:
            if prior_model:
                complexity_term = get_task_complexity(prm, prior_model,
                                                      post_model,
                                                      n_train_samples,
                                                      avg_empiric_loss)
            else:
                complexity_term = torch.zeros(1, device=prm.device)

            # Total objective:
            objective = avg_empiric_loss + complexity_term

            # Take gradient step:
            grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            log_interval = 1000
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc, objective.item()) +
                    ' Loss: {:.4}\t Comp.: {:.4}'.format(
                        avg_empiric_loss.item(), complexity_term.item()))

        # End batch loop

        # save results for epochs-figure:
        if figure_flag and (i_epoch % prm.log_figure['interval_epochs'] == 0):
            save_result_for_figure(post_model, prior_model, data_loader, prm,
                                   log_mat, i_epoch)
Example #9
    def run_train_epoch(i_epoch):
        log_interval = 500

        model.train()
        for batch_idx, batch_data in enumerate(train_loader):

            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Calculate loss:
            outputs = model(inputs)
            loss = loss_criterion(outputs, targets)

            # Take gradient step:
            grad_step(loss, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_rate(outputs, targets)
                print(
                    cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                      n_batches, batch_acc, get_value(loss)))
Example #10
    def run_meta_test_learning(task_model, train_loader):

        task_model.train()
        train_loader_iter = iter(train_loader)

        # Gradient steps (training) loop
        for i_grad_step in range(prm.n_meta_test_grad_steps):
            # get batch:
            batch_data = data_gen.get_next_batch_cyclic(
                train_loader_iter, train_loader)
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Calculate empirical loss:
            outputs = task_model(inputs)
            task_objective = loss_criterion(outputs, targets)

            # Take gradient step with the task weights:
            grad_step(task_objective, task_optimizer)

        # end gradient step loop

        return task_model
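
After the adaptation loop above, the returned `task_model` would typically be evaluated on the task's held-out data. A minimal evaluation sketch for a standard classification loader of (inputs, targets) batches follows; the evaluation code actually used with this example is not shown here.

import torch

def evaluate_sketch(task_model, test_loader, device='cpu'):
    task_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = task_model(inputs)
            predictions = outputs.argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.size(0)
    return correct / total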
Example #11
    def run_train_epoch(i_epoch):

        prior_model.train()

        for batch_idx, batch_data in enumerate(data_prior_loader):

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = prior_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:

            total_objective = task_empirical_loss

            # Take gradient step with the posterior:
            grad_step(total_objective, all_optimizer, lr_schedule, prm.lr,
                      i_epoch)
            log_interval = 20
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    'number meta batch:{} \t avg_empiric_loss:{:.3f} \t batch accuracy:{:.3f}'
                    .format(batch_idx, total_objective, batch_acc))
Example #12
    async def adapt(self, train_futures, first_order=None):
        # if first_order is None:
        #     first_order = self.first_order
        # Loop over the number of adaptation steps
        params = None
        # `await` on a future suspends this coroutine until the future's result is ready;
        # we wait for the futures to finish before computing reinforce_loss.

        params_show_maml_trpo = self.policy.state_dict()

        for train_future in train_futures:
            """

            """
            train_loss = reinforce_loss(self.original_policy,
                                        await train_future)
            lr = 1e-3
            self.original_policy.train()
            optimizer = optim.Adam(self.original_policy.parameters(), lr)
            # Take gradient step:
            # compute the gradients and update the parameters
            grad_step(train_loss, optimizer)

            """
            The original algorithm:
            """
            # inner_loss = reinforce_loss(self.policy,
            #                             await futures)
            # # Compute the updated params; they do not seem to be written back into the network? self.policy.state_dict() still returns the old parameters
            # params = self.policy.update_params(inner_loss,
            #                                    params=params,
            #                                    step_size=self.fast_lr,
            #                                    first_order=first_order)

            # params_show_maml_trpo_test = self.policy.state_dict()

        return train_loss
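
`reinforce_loss` above is assumed to implement the standard REINFORCE policy-gradient surrogate over a batch of episodes. A minimal sketch under that assumption is shown below; the attribute names (`observations`, `actions`, `returns`) are hypothetical and the original episode class may expose different fields.

import torch

def reinforce_loss_sketch(policy, episodes):
    # Policy-gradient surrogate: minimize -E[log pi(a|s) * return].
    pi = policy(episodes.observations)      # assumed to return a torch.distributions object
    log_probs = pi.log_prob(episodes.actions)
    if log_probs.dim() > 2:
        # Sum over action dimensions when the action space is multi-dimensional.
        log_probs = log_probs.sum(dim=-1)
    return -(log_probs * episodes.returns).mean()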
Example #13
    def run_train_epoch(i_epoch):
        log_interval = 500

        post_model.train()

        for batch_idx, batch_data in enumerate(train_loader):

            # get batch data:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            batch_size = inputs.shape[0]

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            avg_empiric_loss = 0
            complexity_term = 0

            for i_MC in range(n_MC):

                # Calculate empirical loss:
                outputs = post_model(inputs)
                avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                    outputs, targets)

                # complexity_curr = get_task_complexity(prm, prior_model, post_model,
                #                                            n_train_samples, avg_empiric_loss_curr)

                avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
                # complexity_term += (1 / n_MC) * complexity_curr

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)
            # end monte-carlo loop

            complexity_term = get_task_complexity(prm, prior_model, post_model,
                                                  n_train_samples,
                                                  avg_empiric_loss)

            # Approximated total objective (for current batch):
            if prm.complexity_type == 'Variational_Bayes':
                # note that avg_empiric_loss_per_task is estimated by an average over batch samples,
                #  but its weight in the objective should reflect the total number of samples in the task
                total_objective = avg_empiric_loss * (
                    n_train_samples) + complexity_term
            else:
                total_objective = avg_empiric_loss + complexity_term
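            # Illustrative note (numbers assumed, not from the original example): if a
            # batch holds 128 samples but the task holds n_train_samples = 60000, the
            # Variational_Bayes branch weights the batch-averaged loss by 60000 before
            # adding the complexity term, while the other branch simply adds the two.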

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                print(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.item()) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                        avg_empiric_loss.item(), complexity_term.item()))
Example #14
    def run_train_epoch(i_epoch):

        # optim_args['L'] = L-5*i_epoch

        # For each task, prepare an iterator to generate training batches:
        train_iterators = [
            iter(data_loaders[ii]['train']) for ii in range(n_train_tasks)
        ]

        # The task order to take batches from:
        # The meta-batch will be balanced - i.e., each task will appear roughly the same number of times
        # note: if some tasks have less data than others, their samples may be drawn more than once in an epoch
        task_order = []
        task_ids_list = list(range(n_train_tasks))
        for i_batch in range(n_batches_per_task):
            random.shuffle(task_ids_list)
            task_order += task_ids_list
        # Note: this method ensures each training sample in each task is drawn in each epoch.
        # If all the tasks have the same number of samples, then each sample is drawn exactly once in an epoch.

        # ----------- meta-batches loop (batches of tasks) -----------------------------------#
        # each meta-batch includes several tasks
        # we take a grad step with theta after each meta-batch
        # meta_batch_starts = list(range(0, len(task_order), n_train_tasks))
        meta_batch_starts = list(range(0, len(task_order),
                                       prm.meta_batch_size))
        n_meta_batches = len(meta_batch_starts)

        for i_meta_batch in range(n_meta_batches):

            meta_batch_start = meta_batch_starts[i_meta_batch]
            task_ids_in_meta_batch = task_order[meta_batch_start:(
                meta_batch_start + prm.meta_batch_size)]
            # meta-batch size may be less than  prm.meta_batch_size at the last one
            # note: it is OK if some tasks appear several times in the meta-batch

            mb_data_loaders = [
                data_loaders[task_id] for task_id in task_ids_in_meta_batch
            ]
            mb_iterators = [
                train_iterators[task_id] for task_id in task_ids_in_meta_batch
            ]
            mb_posteriors_models = [
                posteriors_models[task_id]
                for task_id in task_ids_in_meta_batch
            ]

            #task_loss_list, info = get_objective(prior_model, prm, mb_data_loaders, object.feval,
            #                                      mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks, task_ids_in_meta_batch)

            # Take gradient step with the shared prior and all tasks' posteriors:
            # for i_task in range(n_train_tasks):
            # prior_optimizer.zero_grad()
            #for i_task in range(n_train_tasks):
            #    if isinstance(task_loss_list[i_task], int):
            #        continue
            #    grad_step(task_loss_list[i_task], posteriors_models[i_task], loss_criterion, all_posterior_optimizers[i_task], prm,
            #              train_iterators[i_task], data_loaders[i_task]['train'], lr_schedule, prm.lr, i_epoch)

            # Get objective based on tasks in meta-batch:
            task_loss_list, info = get_objective(prior_model, prm,
                                                 mb_data_loaders, object.feval,
                                                 mb_iterators,
                                                 mb_posteriors_models,
                                                 loss_criterion, n_train_tasks,
                                                 task_ids_in_meta_batch)

            # Take gradient step with the shared prior and all tasks' posteriors:
            # for i_task in range(n_train_tasks):
            prior_optimizer.zero_grad()
            for i_task in range(n_train_tasks):
                if isinstance(task_loss_list[i_task], int):
                    continue
                grad_step(task_loss_list[i_task], posteriors_models[i_task],
                          loss_criterion, all_posterior_optimizers[i_task],
                          prm, train_iterators[i_task],
                          data_loaders[i_task]['train'], lr_schedule, prm.lr,
                          i_epoch)
                # if i_meta_batch==n_meta_batches-1:
                # if i_epoch == prm.n_meta_train_epochs-1:
                prior_get_grad(prior_optimizer,
                               all_posterior_optimizers[i_task])

            # task_loss_list, info = get_objective(prior_model, prm, mb_data_loaders, object.feval,
            #                                       mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks, task_ids_in_meta_batch)

            # prior_grad_step(prior_optimizer, prm.meta_batch_size, prm,prm.prior_lr_schedule, prm.prior_lr, i_epoch)
            prior_updates(prior_optimizer, n_train_tasks, prm)
            for post_model in posteriors_models:
                post_model.load_state_dict(prior_model.state_dict())

            # Print status:
            log_interval = 10
            if (i_meta_batch) % log_interval == 0:
                batch_acc = info['correct_count'] / info['sample_count']
                print(
                    cmn.status_string(i_epoch, prm.n_meta_train_epochs,
                                      i_meta_batch, n_meta_batches, batch_acc)
                    +
                    ' Empiric-Loss: {:.4f}'.format(info['avg_empirical_loss']))
Example #15
def run_test(task,
             prior_policy,
             post_policy,
             baseline,
             args,
             env_name,
             env_kwargs,
             batch_size,
             observation_space,
             action_space,
             n_train_tasks,
             num_test_batches=10):
    optim_func, optim_args, lr_schedule =\
        args.optim_func, args.optim_args, args.lr_schedule
    #  Get optimizer:
    optimizer = optim_func(post_policy.parameters(), **optim_args)

    # *******************************************************************
    # Train: post_policy
    # *******************************************************************
    for batch in range(num_test_batches):
        # Hyper-prior term:
        # Compute the divergence between the hyper-prior and the hyper-posterior
        hyper_dvrg = get_hyper_divergnce(kappa_prior=args.kappa_prior,
                                         kappa_post=args.kappa_post,
                                         divergence_type=args.divergence_type,
                                         device=args.device,
                                         prior_model=prior_policy)
        # Compute the corresponding meta term from the hyper-divergence hyper_dvrg (the arguments could also be passed positionally)
        meta_complex_term = get_meta_complexity_term(
            hyper_kl=hyper_dvrg,
            delta=args.delta,
            complexity_type=args.complexity_type,
            n_train_tasks=n_train_tasks)

        sampler = SampleTest(env_name,
                             env_kwargs,
                             batch_size=batch_size,
                             observation_space=observation_space,
                             action_space=action_space,
                             policy=post_policy,
                             baseline=baseline,
                             seed=args.seed,
                             prior_policy=prior_policy,
                             task=task)
        # calculate the empirical error for this task
        loss_per_task, avg_reward, last_reward, train_episodes = sampler.sample(
        )

        complexity = get_task_complexity(delta=args.delta,
                                         complexity_type=args.complexity_type,
                                         device=args.device,
                                         divergence_type=args.divergence_type,
                                         kappa_post=args.kappa_post,
                                         prior_model=prior_policy,
                                         post_model=post_policy,
                                         n_samples=batch_size,
                                         avg_empiric_loss=loss_per_task,
                                         hyper_dvrg=hyper_dvrg,
                                         n_train_tasks=n_train_tasks,
                                         noised_prior=True)

        if args.complexity_type == 'Variational_Bayes':
            # note that avg_empiric_loss_per_task is estimated by an average over batch samples,
            #  but its weight in the objective should reflect the total number of samples in the task
            n_train_samples = 1
            total_objective = loss_per_task * (n_train_samples) + complexity
        else:
            # This term is analogous to the PAC-Bayes bound
            total_objective = loss_per_task + complexity

        # Take gradient step with the shared prior and all tasks' posteriors:
        grad_step(total_objective, optimizer, lr_schedule, args.lr)

    # *******************************************************************
    # Test: post_policy
    # *******************************************************************

    # test_acc, test_loss = run_eval_max_posterior(post_model, test_loader, prm)

    sampler = SampleTest(env_name,
                         env_kwargs,
                         batch_size=batch_size,
                         observation_space=observation_space,
                         action_space=action_space,
                         policy=post_policy,
                         baseline=baseline,
                         seed=args.seed,
                         task=task)
    # calculate the empirical error for this task
    test_loss_per_task, test_avg_reward, test_last_reward, train_episodes = sampler.sample(
    )

    Data_post_Trajectory = train_episodes[0].observations.numpy()
    task = task[0]
    task = task['goal']
    plt.plot(Data_post_Trajectory[:, 0, 0], Data_post_Trajectory[:, 0, 1])
    plt.plot(task[0], task[1], 'g^')
    plt.savefig('Trajectories.pdf')
    plt.show()
    return test_loss_per_task, test_avg_reward, test_last_reward
Example #16
def run_meta_iteration(i_iter, prior_model, task_generator, prm):
    # In each meta-iteration we draw a meta-batch of several tasks
    # Then we take a grad step with prior.

    # Unpack parameters:
    optim_func, optim_args, lr_schedule = \
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)
    meta_batch_size = prm.meta_batch_size
    n_inner_steps = prm.n_inner_steps
    n_meta_iterations = prm.n_meta_train_epochs

    # Generate the data sets of the training-tasks for meta-batch:
    mb_data_loaders = task_generator.create_meta_batch(prm,
                                                       meta_batch_size,
                                                       meta_split='meta_train')

    # For each task, prepare an iterator to generate training batches:
    mb_iterators = [
        iter(mb_data_loaders[ii]['train']) for ii in range(meta_batch_size)
    ]

    # The posterior models will adapt to the new tasks in each meta-batch
    # Create posterior models for each task:
    posteriors_models = [get_model(prm) for _ in range(meta_batch_size)]
    init_from_prior = True
    if init_from_prior:
        for post_model in posteriors_models:
            post_model.load_state_dict(prior_model.state_dict())

    # Gather all tasks posterior params:
    all_post_param = sum([
        list(posterior_model.parameters())
        for posterior_model in posteriors_models
    ], [])

    # Create optimizer for all parameters (posteriors + prior)
    prior_params = list(prior_model.parameters())
    all_params = all_post_param + prior_params
    all_optimizer = optim_func(all_params, **optim_args)
    # all_optimizer = optim_func(prior_params, **optim_args) ## DeBUG

    test_acc_avg = 0.0
    for i_inner_step in range(n_inner_steps):
        # Get objective based on tasks in meta-batch:
        total_objective, info = get_objective(prior_model, prm,
                                              mb_data_loaders, mb_iterators,
                                              posteriors_models,
                                              loss_criterion,
                                              prm.n_train_tasks)

        # Take gradient step with the meta-parameters (theta) based on validation data:
        grad_step(total_objective, all_optimizer, lr_schedule, prm.lr, i_iter)

        # Print status:
        log_interval = 20
        if (i_inner_step) % log_interval == 0:
            batch_acc = info['correct_count'] / info['sample_count']
            print(
                cmn.status_string(i_iter, n_meta_iterations, i_inner_step,
                                  n_inner_steps, batch_acc,
                                  total_objective.data[0]) +
                ' Empiric-Loss: {:.4}\t Task-Comp. {:.4}\t'.format(
                    info['avg_empirical_loss'], info['avg_intra_task_comp']))

    # Print status on the test set of the meta-batch:
    log_interval_eval = 10
    if (i_iter) % log_interval_eval == 0 and i_iter > 0:
        test_acc_avg = run_test(mb_data_loaders, posteriors_models,
                                loss_criterion, prm)
        print('Meta-iter: {} \t Meta-Batch Test Acc: {:1.3}\t'.format(
            i_iter, test_acc_avg))
    # End of inner steps
    return prior_model, posteriors_models, test_acc_avg
Example #17
    def sample(self,
               index,
               num_steps=1,
               fast_lr=0.5,
               gamma=0.95,
               gae_lambda=1.0,
               device='cpu'):
        """
        Sample the training trajectories with the initial policy and adapt the
        policy to the task, based on the REINFORCE loss computed on the
        training trajectories. The gradient update in the fast adaptation uses
        `first_order=True` no matter if the second order version of MAML is
        applied since this is only used for sampling trajectories, and not
        for optimization.
        """
        """
        训练阶段:
            采样训练轨迹数据 train_episodes,计算loss,更新原有网络参数
            采样验证轨迹数据 valid_episodes
        MAML 内部循环更新num_steps次 inner loop / fast adaptation
        """
        # params is set to None here, so the policy falls back to its own OrderedDict() of parameters
        """
        ******************************************************************
        """
        # params = None

        # params_show_multi_task_sampler = self.policy.state_dict()

        for step in range(num_steps):
            # Collect all trajectory data for this batch and store it in train_episodes
            train_episodes = self.create_episodes(gamma=gamma,
                                                  gae_lambda=gae_lambda,
                                                  device=device)
            train_episodes.log('_enqueueAt', datetime.now(timezone.utc))
            # QKFIX: Deep copy the episodes before sending them to their
            # respective queues, to avoid a race condition. This issue would
            # cause the policy pi = policy(observations) to be miscomputed for
            # some timesteps, which in turns makes the loss explode.
            self.train_queue.put((index, step, deepcopy(train_episodes)))
            """
                Compute the reinforce loss and update the network parameters (params)
            """
            # Safe use of shared mutable state in a multi-threaded program:
            # `with lock:` guarantees that only one thread executes the block below at a time;
            # the with statement acquires the lock before the block and releases it automatically afterwards.
            with self.policy_lock:
                """
                ******************************************************************
                """
                loss = reinforce_loss(self.policy, train_episodes)
                lr = 1e-3
                self.policy.train()
                optimizer = optim.Adam(self.policy.parameters(), lr)
                # Take gradient step:
                # compute the gradients and update the parameters
                grad_step(loss, optimizer)
                # params = self.policy.update_params(loss,
                #                                    params=params,
                #                                    step_size=fast_lr,
                #                                    first_order=True)
                """
                ******************************************************************
                """
                # params_show_multi_task_sampler_test = self.policy.state_dict()

        # Sample the validation trajectories with the adapted policy
        valid_episodes = self.create_episodes(gamma=gamma,
                                              gae_lambda=gae_lambda,
                                              device=device)
        valid_episodes.log('_enqueueAt', datetime.now(timezone.utc))
        self.valid_queue.put((index, None, deepcopy(valid_episodes)))
Example #18
    def run_train_epoch(i_epoch):
        log_interval = 500

        post_model.train()

        train_info = {}
        train_info["task_comp"] = 0.0
        train_info["total_loss"] = 0.0

        cnt = 0
        for batch_idx, batch_data in enumerate(train_loader):
            cnt += 1

            correct_count = 0
            sample_count = 0

            # Monte-Carlo iterations:
            n_MC = prm.n_MC
            task_empirical_loss = 0
            task_complexity = 0
            for i_MC in range(n_MC):
                # get batch:
                inputs, targets = data_gen.get_batch_vars(batch_data, prm)

                # Calculate empirical loss:
                outputs = post_model(inputs)
                curr_empirical_loss = loss_criterion(outputs, targets)

                #hyper_kl = 0 when testing
                curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                    prm,
                    prior_model,
                    post_model,
                    n_train_samples,
                    curr_empirical_loss,
                    noised_prior=False)

                task_empirical_loss += (1 / n_MC) * curr_empirical_loss
                task_complexity += (1 / n_MC) * curr_complexity

                correct_count += count_correct(outputs, targets)
                sample_count += inputs.size(0)

            # Total objective:

            total_objective = task_empirical_loss + prm.task_complex_w * task_complexity

            train_info["task_comp"] += task_complexity.data[0]
            train_info["total_loss"] += total_objective.data[0]

            # Take gradient step with the posterior:
            grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

            # Print status:
            if batch_idx % log_interval == 0:
                batch_acc = correct_count / sample_count
                write_to_log(
                    cmn.status_string(i_epoch, prm.n_meta_test_epochs,
                                      batch_idx, n_batches, batch_acc,
                                      total_objective.data[0]) +
                    ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'
                    .format(task_empirical_loss.data[0],
                            task_complexity.data[0], task_info["w_kld"],
                            task_info["b_kld"]), prm)

        train_info["task_comp"] /= cnt
        train_info["total_loss"] /= cnt
        return train_info
def run_meta_iteration(i_iter, prior_model, task_generator, prm):
    # In each meta-iteration we draw a meta-batch of several tasks
    # Then we take a grad step with prior.

    # Unpack parameters:
    optim_func, optim_args, lr_schedule = \
        prm.optim_func, prm.optim_args, prm.lr_schedule

    # Loss criterion
    loss_criterion = get_loss_criterion(prm.loss_type)
    meta_batch_size = prm.meta_batch_size
    n_inner_steps = prm.n_inner_steps
    n_meta_iterations = prm.n_meta_train_epochs

    # Generate the data sets of the training-tasks for meta-batch:
    mb_data_loaders = task_generator.create_meta_batch(prm,
                                                       meta_batch_size,
                                                       meta_split='meta_train')

    # For each task, prepare an iterator to generate training batches:
    mb_iterators = [
        iter(mb_data_loaders[ii]['train']) for ii in range(meta_batch_size)
    ]

    # The posterior models will adapt to the new tasks in each meta-batch
    # Create posterior models for each task:
    posteriors_models = [get_model(prm) for _ in range(meta_batch_size)]
    init_from_prior = True
    if init_from_prior:
        for post_model in posteriors_models:
            post_model.load_state_dict(prior_model.state_dict())

    # # Gather all tasks posterior params:
    # all_post_param = sum([list(posterior_model.parameters()) for posterior_model in posteriors_models], [])
    #
    # # Create optimizer for all parameters (posteriors + prior)
    # prior_params = list(prior_model.parameters())
    # all_params = all_post_param + prior_params
    # all_optimizer = optim_func(all_params, **optim_args)
    # # all_optimizer = optim_func(prior_params, **optim_args) ## DeBUG
    prior_params = filter(lambda p: p.requires_grad, prior_model.parameters())
    prior_optimizer = optim_func(prior_params, optim_args)
    if prior_optimizer.param_groups[0]['params'][0].grad is None:
        object.grad_init(prm, prior_model, loss_criterion,
                         iter(mb_data_loaders[0]['train']),
                         mb_data_loaders[0]['train'], prior_optimizer)
    # all_params = all_post_param + prior_params
    all_posterior_optimizers = [
        optim_func(
            filter(lambda p: p.requires_grad,
                   posteriors_models[i].parameters()), optim_args)
        for i in range(meta_batch_size)
    ]

    test_acc_avg = 0.0
    for i_inner_step in range(n_inner_steps):
        # Get objective based on tasks in meta-batch:
        # total_objective, info = get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
        #                                       posteriors_models, loss_criterion, prm.n_train_tasks)
        task_loss_list, info = get_objective(prior_model, prm, mb_data_loaders,
                                             object.feval, mb_iterators,
                                             posteriors_models, loss_criterion,
                                             meta_batch_size,
                                             range(meta_batch_size))

        # Take gradient step with the meta-parameters (theta) based on validation data:
        # grad_step(total_objective, all_optimizer, lr_schedule, prm.lr, i_iter)
        prior_optimizer.zero_grad()
        for i_task in range(meta_batch_size):
            if isinstance(task_loss_list[i_task], int):
                continue
            grad_step(task_loss_list[i_task], posteriors_models[i_task],
                      loss_criterion, all_posterior_optimizers[i_task], prm,
                      mb_iterators[i_task], mb_data_loaders[i_task]['train'],
                      lr_schedule, prm.lr, i_iter)
            # if i_meta_batch==n_meta_batches-1:
            prior_get_grad(prior_optimizer, all_posterior_optimizers[i_task])
            # prior_grad_step(prior_model, prior_optimizer, all_posterior_optimizers, prm.meta_batch_size, prm,
            #                 lr_schedule, prm.prior_lr, i_epoch)
        prior_grad_step(prior_optimizer, prm.meta_batch_size, prm,
                        prm.prior_lr_schedule, prm.prior_lr, i_iter)

        # Print status:
        log_interval = 1
        # if (i_inner_step) % log_interval == 0:
        #     batch_acc = info['correct_count'] / info['sample_count']
        #     print(cmn.status_string(i_iter, n_meta_iterations, i_inner_step, n_inner_steps, batch_acc, total_objective.data[0]) +
        #           ' Empiric-Loss: {:.4}\t Task-Comp. {:.4}\t'.
        #           format(info['avg_empirical_loss'], info['avg_intra_task_comp']))
        if (i_inner_step) % log_interval == 0:
            batch_acc = info['correct_count'] / info['sample_count']
            print(
                cmn.status_string(i_iter, prm.n_meta_train_epochs,
                                  i_inner_step, n_inner_steps, batch_acc) +
                ' Empiric-Loss: {:.4f}'.format(info['avg_empirical_loss']))

    # Print status on the test set of the meta-batch:
    log_interval_eval = 1
    if (i_iter) % log_interval_eval == 0:
        test_acc_avg = run_test(mb_data_loaders, posteriors_models,
                                loss_criterion, prm)
        print('Meta-iter: {} \t Meta-Batch Test Acc: {:1.3}\t'.format(
            i_iter, test_acc_avg))
    # End of inner steps
    return prior_model, posteriors_models, test_acc_avg
def main(args, prior_policy=None, init_from_prior=True):

    # *******************************************************************
    # config log filename
    #    'r': read;  'w': write
    # *******************************************************************
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    if args.output_folder is not None:
        # Create the output folder if it does not exist
        if not os.path.exists(args.output_folder):
            os.makedirs(args.output_folder)
        # Output folder path and file names
        policy_filename = os.path.join(args.output_folder,
                                       'policy_2d_PAC_Bayes.th')
        config_filename = os.path.join(args.output_folder,
                                       'config_2d_PAC_Bayes.json')

        # with open(config_filename, 'w') as f:
        #     config.update(vars(args))
        #     json.dump(config, f, indent=2)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    env = gym.make(config['env-name'], **config['env-kwargs'])
    # to be tested
    env.seed(args.seed)
    env.close()
    """
    ************************************************************
    Newly added parameters: the sizes of the environment's action and observation spaces, used to build the stochastic Bayesian network
    output_size = reduce(mul, env.action_space.shape, 1)
    input_size = reduce(mul, env.observation_space.shape, 1)
    ************************************************************
    """
    observation_space = env.observation_space
    action_space = env.action_space
    args.output_size = reduce(mul, env.action_space.shape, 1)
    args.input_size = reduce(mul, env.observation_space.shape, 1)
    """
    ************************************************************
    Newly added model: the stochastic network
    device = ('cuda' if (torch.cuda.is_available()
                   and args.use_cuda) else 'cpu')
    log_var_init = {'mean': -10, 'std': 0.1}
    ************************************************************
    """
    if prior_policy and init_from_prior:
        # init from prior model:
        # deepcopy makes an independent copy; a shallow copy would still follow changes to the original object
        prior_policy = deepcopy(prior_policy).to(args.device)
    else:
        # otherwise load a fresh model
        prior_policy = get_policy_for_env(args.device,
                                          args.log_var_init,
                                          env,
                                          hidden_sizes=config['hidden-sizes'])

    # the parameters can be used without copying (shared memory)
    # prior_policy.share_memory()
    """
    ************************************************************
    The prior and posterior policies and their corresponding parameters:
        prior_policy  posteriors_policies
        prior_params  all_post_param
        all_params
    ************************************************************
    """
    num_tasks = config['meta-batch-size']
    batch_size = config['fast-batch-size']

    # Unpack parameters:
    # Extract the optimizer function, optimizer arguments, learning-rate schedule, etc.
    optim_func, optim_args, lr_schedule =\
        args.optim_func, args.optim_args, args.lr_schedule

    posteriors_policies = [
        get_policy_for_env(args.device,
                           args.log_var_init,
                           env,
                           hidden_sizes=config['hidden-sizes'])
        for _ in range(num_tasks)
    ]
    all_post_param = sum([
        list(posterior_policy.parameters())
        for posterior_policy in posteriors_policies
    ], [])

    # Create optimizer for all parameters (posteriors + prior)
    # Create a single optimizer over all parameters, prior as well as posteriors
    prior_params = list(prior_policy.parameters())
    all_params = all_post_param + prior_params
    all_optimizer = optim_func(all_params, **optim_args)
    """生成固定的 tasks
        随机数问题尚未解决,可重复性不行
    """
    # Baseline
    baseline = LinearFeatureBaseline(get_input_size(env))

    # Sample 'meta-batch-size' tasks
    # for task in enumerate(tasks):
    tasks = env.unwrapped.sample_tasks(num_tasks)

    # meta-batch-size:Number of tasks in each batch of tasks
    # Number of tasks in each batch. The PAC-Bayes approach is used here, so the set and number of tasks are fixed;
    # i.e. in the 2D-navigation task the goals are fixed and different trajectories are sampled for training each time.
    # tasks = sampler.sample_tasks(num_tasks=config['meta-batch-size'])

    avg_empiric_loss_per_task = torch.zeros(num_tasks, device=args.device)
    avg_reward_per_task = torch.zeros(num_tasks, device=args.device)
    complexity_per_task = torch.zeros(num_tasks, device=args.device)
    # This handles the case where different tasks have different numbers of training samples
    n_samples_per_task = torch.zeros(num_tasks, device=args.device)

    Info_avg_reward = []
    Info_total_objective = []
    Info_last_reward = []
    Info_train_trajectories = []

    # Train for 'num-batches' batches
    for batch in range(config['num-batches']):
        print(batch)

        # params_show_train = prior_policy.state_dict()

        # Hyper-prior term:
        # Compute the divergence between the hyper-prior and the hyper-posterior
        hyper_dvrg = get_hyper_divergnce(kappa_prior=args.kappa_prior,
                                         kappa_post=args.kappa_post,
                                         divergence_type=args.divergence_type,
                                         device=args.device,
                                         prior_model=prior_policy)
        # Compute the corresponding meta term from the hyper-divergence hyper_dvrg (the arguments could also be passed positionally)
        meta_complex_term = get_meta_complexity_term(
            hyper_kl=hyper_dvrg,
            delta=args.delta,
            complexity_type=args.complexity_type,
            n_train_tasks=num_tasks)

        for i_task in range(num_tasks):
            sampler = SampleTest(config['env-name'],
                                 env_kwargs=config['env-kwargs'],
                                 batch_size=batch_size,
                                 observation_space=observation_space,
                                 action_space=action_space,
                                 policy=posteriors_policies[i_task],
                                 baseline=baseline,
                                 seed=args.seed,
                                 prior_policy=prior_policy,
                                 task=tasks[i_task])
            # calculate the empirical error for this task
            loss_per_task, avg_reward, last_reward, train_episodes = sampler.sample(
            )

            complexity = get_task_complexity(
                delta=args.delta,
                complexity_type=args.complexity_type,
                device=args.device,
                divergence_type=args.divergence_type,
                kappa_post=args.kappa_post,
                prior_model=prior_policy,
                post_model=posteriors_policies[i_task],
                n_samples=batch_size,
                avg_empiric_loss=loss_per_task,
                hyper_dvrg=hyper_dvrg,
                n_train_tasks=num_tasks,
                noised_prior=True)

            avg_empiric_loss_per_task[i_task] = loss_per_task
            avg_reward_per_task[i_task] = avg_reward
            complexity_per_task[i_task] = complexity
            n_samples_per_task[i_task] = batch_size

        # Approximated total objective:
        if args.complexity_type == 'Variational_Bayes':
            # note that avg_empiric_loss_per_task is estimated by an average over batch samples,
            #  but its weight in the objective should reflect the total number of samples in the task
            total_objective = \
                (avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() * num_tasks \
                + meta_complex_term
            # total_objective = ( avg_empiric_loss_per_task * n_samples_per_task
            # + complexity_per_task).mean() + meta_complex_term

        else:
            total_objective = \
                avg_empiric_loss_per_task.mean() + complexity_per_task.mean() + meta_complex_term

        # Take gradient step with the shared prior and all tasks' posteriors:
        grad_step(total_objective, all_optimizer, lr_schedule, args.lr)

        Info_avg_reward.append(avg_reward_per_task.mean())
        Info_total_objective.append(total_objective)
        Info_last_reward.append(last_reward)

    # *******************************************************************
    # Save policy
    # *******************************************************************
    # Save the model parameters to the .th file at policy_filename
    if args.output_folder is not None:
        with open(policy_filename, 'wb') as f:
            # Save the network parameters to the file opened as f
            torch.save(prior_policy.state_dict(), f)

    # *******************************************************************
    # Test
    # learned policy   : prior_policy
    # saved parameters : 'policy_2d_PAC_Bayes.th'
    # *******************************************************************
    env_name = config['env-name']
    env_kwargs = config['env-kwargs']
    test_num = 10

    Info_test_loss = []
    Info_test_avg_reward = []
    Info_test_last_reward = []

    for test_batch in range(test_num):
        # Sample a new task, train a posterior on it, and evaluate the validation error
        test_task = env.unwrapped.sample_tasks(1)
        post_policy = get_policy_for_env(args.device,
                                         args.log_var_init,
                                         env,
                                         hidden_sizes=config['hidden-sizes'])
        post_policy.load_state_dict(prior_policy.state_dict())

        # based on the prior_policy, train post_policy; then test learned post_policy
        test_loss_per_task, test_avg_reward, test_last_reward = run_test(
            task=test_task,
            prior_policy=prior_policy,
            post_policy=post_policy,
            baseline=baseline,
            args=args,
            env_name=env_name,
            env_kwargs=env_kwargs,
            batch_size=batch_size,
            observation_space=observation_space,
            action_space=action_space,
            n_train_tasks=num_tasks)

        Info_test_loss.append(test_loss_per_task)
        Info_test_avg_reward.append(test_avg_reward)
        Info_test_last_reward.append(test_last_reward)