示例#1
0
def training_workflow(config, reporter):
    """Roll out a random policy per agent and store transitions in replay buffers.

    Builds the environment through ``get_env_maker(GazeboEnv)``, creates one
    ``ReplayBuffer`` and one ``RandomPolicy`` spec per agent, then samples
    from a single ``RolloutWorker`` for ``config['num_iters']`` iterations,
    adding every sampled row to the buffer of the agent it belongs to.

    Args:
        config: Mapping that must provide ``'env_config'`` (forwarded to the
            env maker) and ``'num_iters'`` (number of sampling iterations).
            ``'buffer_size'`` is optional and defaults to 1000.
        reporter: Accepted for signature compatibility; not used here.

    Returns:
        None. The replay buffers are local to this call and nothing is
        reported back through ``reporter``.
    """
    env_maker = get_env_maker(GazeboEnv)
    env, agents = env_maker(config['env_config'], return_agents=True)

    # Resolve the capacity once instead of once per agent.
    buffer_size = config.get('buffer_size', 1000)
    replay_buffers = {
        agent_id: ReplayBuffer(buffer_size) for agent_id in agents
    }

    # One policy spec per agent, keyed by agent id; each uses that agent's
    # own observation and action spaces.
    policy = {
        k: (RandomPolicy, a.observation_space, a.action_space, {})
        for k, a in agents.items()
    }

    # The env creator hands back the already-built env, and the identity
    # mapping ties every agent id to the policy stored under the same key.
    worker = RolloutWorker(lambda x: env,
                           policy=policy,
                           batch_steps=32,
                           policy_mapping_fn=lambda x: x,
                           episode_horizon=20)

    for _ in range(config['num_iters']):
        T1 = SampleBatch.concat_samples([worker.sample()])
        for agent_id, batch in T1.policy_batches.items():
            for row in batch.rows():
                replay_buffers[agent_id].add(row['obs'],
                                             row['actions'],
                                             row['rewards'],
                                             row['new_obs'],
                                             row['dones'],
                                             weight=None)
    # NOTE(review): no debugger breakpoint here on purpose -- a blocking
    # pdb prompt cannot be answered when this runs without an interactive
    # stdin. Inspect the buffers from the caller or with logging instead.
示例#2
0
文件: runner.py 项目: aclyde11/RLDock
    get_dock_marks = []

    workers = RolloutWorker(env_creator,
                            ppo.PPOTFPolicy,
                            env_config=envconf,
                            policy_config=d)
    with open(checkpoint, 'rb') as c:
        c = c.read()
        c = pickle.loads(c)
        print(list(c.keys()))
        workers.restore(c['worker'])
    fp_path = "/Users/austin/PycharmProjects/RLDock/"
    with open("log.pml", 'w') as fp:
        with open("test.pml", 'w') as f:
            for j in range(1):
                rs = workers.sample()
                print(rs)
                print(list(rs.keys()))
                ls = rs['actions'].shape[0]
                for i in range(ls):
                    i += j * ls
                    ligand_pdb = rs['infos'][i - j * ls]['atom']
                    protein_pdb_link = rs['infos'][i - j * ls]['protein']

                    with open(fp_path + 'pdbs_traj/test' + str(i) + '.pdb',
                              'w') as f:
                        f.write(ligand_pdb)
                    shutil.copyfile(
                        protein_pdb_link,
                        fp_path + 'pdbs_traj/test_p' + str(i) + '.pdb')