Example #1
def main():
    parser = base_config.get_base_config()
    parser = rs_config.get_rs_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mbrl-rs-' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
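    # modules are imported lazily so that only the components this
    # algorithm actually uses get loaded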
    from mbbl.trainer import shooting_trainer
    from mbbl.sampler import singletask_sampler
    from mbbl.worker import rs_worker
    from mbbl.network.policy.random_policy import policy_network

    # use the simulator's ground-truth dynamics, or learn a deterministic model
    if args.gt_dynamics:
        from mbbl.network.dynamics.groundtruth_forward_dynamics import \
            dynamics_network
    else:
        from mbbl.network.dynamics.deterministic_forward_dynamics import \
            dynamics_network

    if args.gt_reward:
        from mbbl.network.reward.groundtruth_reward import reward_network
    else:
        from mbbl.network.reward.deterministic_reward import reward_network

    train(shooting_trainer, singletask_sampler, rs_worker, dynamics_network,
          policy_network, reward_network, args)
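
Example #1 (and every entry point below) builds its command-line flags the same way: a base parser is threaded through one or more algorithm config modules before parsing. A self-contained sketch of that chaining pattern (the flag names and defaults here are illustrative, not mbbl's actual configs):

import argparse


def get_base_config():
    # flags shared by every algorithm (names/defaults illustrative)
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='gym_cheetah')
    parser.add_argument('--max_timesteps', type=int, default=20000)
    return parser


def get_rs_config(parser):
    # each algorithm's config module only appends its own flags
    parser.add_argument('--num_planning_traj', type=int, default=10)
    parser.add_argument('--planning_depth', type=int, default=10)
    return parser


if __name__ == '__main__':
    args = get_rs_config(get_base_config()).parse_args()
    print(args.task, args.planning_depth)
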
Example #2
def main():
    parser = base_config.get_base_config()
    parser = cem_config.get_cem_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        if args.output_dir is None:
            args.output_dir = "log"
        log_path = (str(args.output_dir) + '/pets-cem-' + str(args.task) +
                    '/seed-' + str(args.seed) +
                    '/num_planning_traj-' + str(args.num_planning_traj) +
                    '/planning_depth-' + str(args.planning_depth) +
                    '/timesteps_per_batch-' + str(args.timesteps_per_batch) +
                    '/random_timesteps-' + str(args.random_timesteps) +
                    '/max_timesteps-' + str(args.max_timesteps) + '/')
        logger.set_file_handler(path=log_path, prefix='', time_str="0")

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import shooting_trainer
    from mbbl.sampler import singletask_pets_sampler
    from mbbl.worker import cem_worker
    from mbbl.network.policy.random_policy import policy_network

    if args.gt_dynamics:
        from mbbl.network.dynamics.groundtruth_forward_dynamics import \
            dynamics_network
    else:
        from mbbl.network.dynamics.deterministic_forward_dynamics import \
            dynamics_network

    if args.gt_reward:
        from mbbl.network.reward.groundtruth_reward import reward_network
    else:
        from mbbl.network.reward.deterministic_reward import reward_network

    train(shooting_trainer, singletask_pets_sampler, cem_worker,
          dynamics_network, policy_network, reward_network, args)
Example #3
def main():
    parser = base_config.get_base_config()
    parser = mf_config.get_mf_config(parser)
    parser = il_config.get_il_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='gail-mfrl-mf-' + args.task,
                                time_str=args.exp_id)

    # model-free RL does not use a random warm-up policy
    assert args.random_timesteps == 0

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import gail_trainer
    from mbbl.sampler import singletask_sampler
    from mbbl.worker import mf_worker
    import mbbl.network.policy.trpo_policy
    import mbbl.network.policy.ppo_policy

    # pick the policy class that matches the requested trust-region method
    policy_network = {
        'ppo': mbbl.network.policy.ppo_policy.policy_network,
        'trpo': mbbl.network.policy.trpo_policy.policy_network
    }[args.trust_region_method]

    # here the dynamics and reward networks are placeholders; they cannot be
    # called to predict the next state or the reward
    from mbbl.network.dynamics.base_dynamics import base_dynamics_network
    from mbbl.network.reward.GAN_reward import reward_network

    train(gail_trainer, singletask_sampler, mf_worker, base_dynamics_network,
          policy_network, reward_network, args)
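
In Example #3 the reward slot is filled by a GAN discriminator: GAIL trains it to tell expert transitions from policy transitions, and the policy is optimized against a surrogate reward derived from the discriminator's output. A minimal sketch of one common convention, assuming a hypothetical d_prob = P(expert | s, a) interface (not mbbl's actual GAN_reward API):

import math


def gail_surrogate_reward(d_prob):
    # d_prob: discriminator's estimate that (s, a) came from the expert
    # (a hypothetical interface, not mbbl's GAN_reward API)
    eps = 1e-8
    # reward grows as the policy gets better at fooling the discriminator
    return -math.log(max(1.0 - d_prob, eps))


print(gail_surrogate_reward(0.9))  # a policy transition judged expert-like
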
Example #4
def main():
    parser = base_config.get_base_config()
    parser = ilqr_config.get_ilqr_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mbrl-ilqr-' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import shooting_trainer
    from mbbl.sampler import singletask_ilqr_sampler
    from mbbl.worker import model_worker
    from mbbl.network.policy.random_policy import policy_network

    if args.gt_dynamics:
        from mbbl.network.dynamics.groundtruth_forward_dynamics import \
            dynamics_network
    else:
        from mbbl.network.dynamics.deterministic_forward_dynamics import \
            dynamics_network

    if args.gt_reward:
        from mbbl.network.reward.groundtruth_reward import reward_network
    else:
        from mbbl.network.reward.deterministic_reward import reward_network

    if (not args.gt_reward) or (not args.gt_dynamics):
        raise NotImplementedError(
            'iLQR with a learned dynamics or reward model is not supported yet')

    train(shooting_trainer, singletask_ilqr_sampler, model_worker,
          dynamics_network, policy_network, reward_network, args)
Example #5
def main():
    parser = base_config.get_base_config()
    parser = metrpo_config.get_metrpo_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mbrl-metrpo-' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import metrpo_trainer
    from mbbl.sampler import singletask_metrpo_sampler
    from mbbl.worker import metrpo_worker
    from mbbl.network.dynamics.deterministic_forward_dynamics import \
        dynamics_network
    from mbbl.network.policy.trpo_policy import policy_network
    from mbbl.network.reward.groundtruth_reward import reward_network

    train(metrpo_trainer, singletask_metrpo_sampler, metrpo_worker,
          dynamics_network, policy_network, reward_network, args)
Example #6
def main():
    parser = base_config.get_base_config()
    parser = ilqr_config.get_ilqr_config(parser)
    parser = gps_config.get_gps_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mbrl-gps-' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import gps_trainer
    from mbbl.sampler import singletask_sampler
    from mbbl.worker import mf_worker
    from mbbl.network.policy.gps_policy_gmm_refit import policy_network

    # GPS fits linear-Gaussian dynamics with a GMM prior and requires the
    # ground-truth reward
    assert (not args.gt_dynamics) and args.gt_reward
    from mbbl.network.dynamics.linear_stochastic_forward_dynamics_gmm_prior \
        import dynamics_network
    from mbbl.network.reward.groundtruth_reward import reward_network

    train(gps_trainer, singletask_sampler, mf_worker,
          dynamics_network, policy_network, reward_network, args)
Example #7
def main():
    parser = base_config.get_base_config()
    parser = rs_config.get_rs_config(parser)
    parser = il_config.get_il_config(parser)
    args = base_config.make_parser(parser)
    args = il_config.post_process_config(args)

    if args.write_log:
        args.log_path = logger.set_file_handler(
            path=args.output_dir,
            prefix='inverse_dynamics-' + args.task,
            time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))

    train(args)
Example #8
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))


if __name__ == '__main__':
    parser = base_config.get_base_config()
    parser = mbmf_config.get_mbmf_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        if args.output_dir is None:
            args.output_dir = "log"
        log_path = (str(args.output_dir) + '/mb-mf-' + str(args.task) +
                    '/seed-' + str(args.seed) +
                    '/num_planning_traj-' + str(args.num_planning_traj) +
                    '/planning_depth-' + str(args.planning_depth) +
                    '/timesteps_per_batch-' + str(args.timesteps_per_batch) +
                    '/random_timesteps-' + str(args.random_timesteps) +
                    '/max_timesteps-' + str(args.max_timesteps) + '/')
        logger.set_file_handler(path=log_path, prefix='', time_str="0")

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import mbmf_trainer
    from mbbl.sampler import mbmf_sampler
    from mbbl.worker import mbmf_worker
    from mbbl.network.policy.mbmf_policy import mbmf_policy_network

    if args.gt_dynamics:
        from mbbl.network.dynamics.groundtruth_forward_dynamics import \
            dynamics_network
    else:
        from mbbl.network.dynamics.deterministic_forward_dynamics import \
            dynamics_network

    if args.gt_reward:
        from mbbl.network.reward.groundtruth_reward import reward_network
    else:
        from mbbl.network.reward.deterministic_reward import reward_network

    train(mbmf_trainer, mbmf_sampler, mbmf_worker, dynamics_network,
          mbmf_policy_network, reward_network, args)
Example #9
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))


if __name__ == '__main__':
    parser = base_config.get_base_config()
    parser = mbmf_config.get_mbmf_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mbmfrl-rs-' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from mbbl.trainer import mbmf_trainer
    from mbbl.sampler import mbmf_sampler
    from mbbl.worker import mbmf_worker
    from mbbl.network.policy.mbmf_policy import mbmf_policy_network

    if args.gt_dynamics:
        from mbbl.network.dynamics.groundtruth_forward_dynamics import \
            dynamics_network
    else:
        from mbbl.network.dynamics.deterministic_forward_dynamics import \
            dynamics_network
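
Examples #8 and #9 shut the run down by pushing parallel_util.END_SIGNAL onto the trainer's task queue and stopping the sampler agent. A self-contained sketch of that queue-based shutdown protocol (the signal constants and the trainer loop here are illustrative, not mbbl's actual parallel_util):

import multiprocessing

# illustrative stand-ins for parallel_util's signal constants
TRAIN_SIGNAL = 1
END_SIGNAL = 0


def trainer_loop(task_queue):
    # block on the queue until the main process sends END_SIGNAL
    while True:
        signal, payload = task_queue.get()
        if signal == END_SIGNAL:
            break
        # ... otherwise train on the rollout data in `payload` ...


if __name__ == '__main__':
    trainer_tasks = multiprocessing.Queue()
    trainer = multiprocessing.Process(target=trainer_loop,
                                      args=(trainer_tasks,))
    trainer.start()
    trainer_tasks.put((TRAIN_SIGNAL, {'rollout': 'data'}))
    trainer_tasks.put((END_SIGNAL, None))
    trainer.join()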