Example #1
0
            job_name_lpgftw_seed + '/iterations/task_{}/'.format(t) +
            'baseline_0.pickle', 'rb')
        # Restore the task-`t` baseline from
        # '<job_name_lpgftw_seed>/iterations/task_{t}/baseline_0.pickle'
        # (the handle `f` is opened just before this line).
        baseline_mtl[t] = pickle.load(f)
        f.close()

        # If the policy model's theta is currently a list, replace it with an
        # empty tensor.
        # NOTE(review): presumably NPGFTW expects a tensor here and (re)builds
        # theta itself — confirm against NPGFTW.__init__.
        if isinstance(policy_mtl.model.theta, list):
            policy_mtl.model.theta = torch.autograd.Variable(torch.zeros(0))

        # Wrap the current multitask policy and baselines in an NPGFTW agent with
        # log saving disabled; it is used below only to run test rollouts.
        agent_mtl = NPGFTW(e,
                           policy_mtl,
                           baseline_mtl,
                           normalized_step_size=0.1,
                           seed=SEED,
                           save_logs=False,
                           new_col_mode='performance')

        # Evaluate with 10 test rollouts, restricted to task `t` via task_ids.
        mean_test_perf = agent_mtl.test_tasks(test_rollouts=10,
                                              num_cpu=num_cpu,
                                              task_ids=np.array([t]))

        # Merge this task's results into a new running dict; on a key collision
        # the value from mean_test_perf wins.
        forward_transfer_results = {
            **forward_transfer_results,
            **mean_test_perf
        }

    # Persist the per-task forward-transfer evaluation collected above.
    # Every file below is opened with `with`, so the handle is closed even if
    # the write/dump raises (the previous open/close pairs leaked on error).
    with open(job_name_lpgftw_seed + '/start_results.txt', 'w') as result_file:
        result_file.write(str(forward_transfer_results))

    # NOTE(review): SEED is advanced by 10 here and again at the end of this
    # section — confirm the double increment is intended.
    SEED += 10

    # Snapshot the trained multitask state. Order is preserved: each file is
    # opened before the object being dumped into it is read.
    # NOTE(review): `agent_mtl` is whichever agent the preceding per-task loop
    # assigned last; if that loop never ran, this raises NameError.
    with open(job_name_lpgftw_seed + '/trained_mtl_baseline.pickle', 'wb') as f:
        pickle.dump(baseline_mtl, f)
    # agent_mtl.theta is stored under the file name "alphas".
    with open(job_name_lpgftw_seed + '/trained_mtl_alphas.pickle', 'wb') as f:
        pickle.dump(agent_mtl.theta, f)
    with open(job_name_lpgftw_seed + '/trained_mtl_grads.pickle', 'wb') as f:
        pickle.dump(agent_mtl.grad, f)
    with open(job_name_lpgftw_seed + '/trained_mtl_hess.pickle', 'wb') as f:
        pickle.dump(agent_mtl.hess, f)
    with open(job_name_lpgftw_seed + '/task_order.pickle', 'wb') as f:
        pickle.dump(task_order, f)

    make_multitask_train_plots(loggers=agent_mtl.logger,
                               keys=['stoc_pol_mean'],
                               save_loc=job_name_lpgftw_seed + '/logs/')

    # Final evaluation, called without a task_ids filter (presumably all
    # tasks — confirm against NPGFTW.test_tasks). The mean over the per-task
    # values is printed; the full per-task mapping is plotted and saved.
    mean_test_perf = agent_mtl.test_tasks(test_rollouts=10, num_cpu=num_cpu)
    result = np.mean(list(mean_test_perf.values()))
    print(result)
    make_multitask_test_plots(mean_test_perf,
                              save_loc=job_name_lpgftw_seed + '/')
    with open(job_name_lpgftw_seed + '/results.txt', 'w') as result_file:
        result_file.write(str(mean_test_perf))

    SEED += 10