Esempio n. 1
0
def run_task(snapshot_config, *_):
    """Set up and train the comp/mech policy with CMAES.

    Everything is built inside a ``LocalTFRunner`` context: the mass-spring
    environment, the combined comp/mech policy, a linear feature baseline
    and the ``CMAES`` algorithm. Hyperparameters are read from the
    externally defined ``params`` object.

    Args:
        snapshot_config: Passed straight through to ``LocalTFRunner``.
        *_: Any further positional arguments are accepted and ignored.

    """
    with LocalTFRunner(snapshot_config=snapshot_config) as runner:
        # The environment is wrapped in TfEnv only; no normalize(...)
        # wrapper is applied.
        env = TfEnv(MassSpringEnv_OptK_HwAsAction(params))

        # zip_project is pointed at the runner's snapshot directory
        # (presumably to archive the code next to the results -- confirm
        # against zip_project). This happens before any model is built.
        zip_project(log_dir=runner._snapshotter._snapshot_dir)

        # Comp policy network: a single output, tanh on the hidden layers
        # and on the output.
        comp_net = MLPModel(output_dim=1,
                            hidden_sizes=params.comp_policy_network_size,
                            hidden_nonlinearity=tf.nn.tanh,
                            output_nonlinearity=tf.nn.tanh)

        # Mech policy model, configured entirely from params.
        mech_net = MechPolicyModel_OptK_HwAsAction(params)

        combined_policy = CompMechPolicy_OptK_HwAsAction(
            name='comp_mech_policy',
            env_spec=env.spec,
            comp_policy_model=comp_net,
            mech_policy_model=mech_net,
        )

        # A linear feature baseline is used (not a GaussianMLPBaseline).
        value_baseline = LinearFeatureBaseline(env_spec=env.spec)

        # Remaining algorithm settings come from params.cmaes_algo_kwargs.
        cmaes_algo = CMAES(env_spec=env.spec,
                           policy=combined_policy,
                           baseline=value_baseline,
                           **params.cmaes_algo_kwargs)

        runner.setup(cmaes_algo, env)

        # Training arguments are supplied by params.cmaes_train_kwargs.
        runner.train(**params.cmaes_train_kwargs)
    # NOTE(review): this section reads like a script entry point (it parses
    # the command line), yet it sits at function-body indentation with no
    # visible header of its own -- confirm the intended enclosing scope.
    # `now`, `params`, `cmaes_obj_fcn` and `zip_project` are expected to be
    # defined elsewhere.

    # Command-line options; both defaults are derived from `now`.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--seed',
        type=int,
        default=int(now.timestamp()),
        help='seed',
    )
    parser.add_argument(
        '--exp_id',
        default=now.strftime("%Y_%m_%d_%H_%M_%S"),
        help='experiment id (suffix to data directory name)',
    )
    args = parser.parse_args()

    # Experiment prefix built from the experiment id, the number of springs
    # and the seed.
    exp_prefix = 'cmaes_ppo_opt_k_{0}_{1}_params/seed_{2}'.format(
        args.exp_id, params.n_springs, args.seed)
    # Directory names use dashes wherever the prefix has underscores.
    dashed_prefix = exp_prefix.replace('_', '-')

    # CMA-ES global optimization.
    # `options` is bound directly to params.cmaes_options (no copy is made
    # here), so the two assignments below write into that object.
    options = params.cmaes_options
    options['seed'] = args.seed
    # The CMA-ES output file prefix is the experiment's data directory
    # joined with a final '-' component.
    options['verb_filenameprefix'] = os.path.join(
        os.environ['PROJECTDIR'], 'data/local', dashed_prefix, '-')

    strategy = cma.CMAEvolutionStrategy(params.cmaes_x0,
                                        params.cmaes_sigma0,
                                        options)
    # The experiment prefix is handed to the objective as its extra argument.
    strategy.optimize(cmaes_obj_fcn, args=[exp_prefix])
    strategy.result_pretty()

    # After the optimization finishes, zip_project is called on the
    # experiment's data directory.
    zip_project(log_dir=os.path.join(
        os.environ['PROJECTDIR'], 'data/local', dashed_prefix))