Esempio n. 1
0
    def test_rollout_gen(self):
        """Drive a VCPolicy through SpaceInvaders-v4 and display every screen.

        Loads the saved visual model ``GM53H301W5YS38XH`` and the controller
        ``best_model68`` from the SpaceInvaders-v4 ``models`` directory, wraps
        them in a ``VCPolicy`` and feeds each screen yielded by ``RolloutGen``
        to an ``ImageViewer``.  The gym environment is closed afterwards, even
        if the rollout raises.
        """
        env = gym.make('SpaceInvaders-v4')
        try:
            models = config.basepath() / 'SpaceInvaders-v4' / 'models'
            device = config.device()

            visuals = Storeable.load(str(models / 'GM53H301W5YS38XH')).to(device)
            # Load the controller onto the same device as the visual model;
            # without map_location a checkpoint saved on a GPU keeps its
            # original device (or fails to load on a CPU-only machine).
            controller = torch.load(str(models / 'best_model68'),
                                    map_location=device)
            policy = VCPolicy(visuals, controller)

            viewer = ImageViewer('screen', (420, 360), 'numpyRGB')

            # Each RolloutGen item unpacks into six values; only the screen
            # is needed for display.
            for screen, _observation, _reward, _done, _info, _action in RolloutGen(
                    env, policy):
                viewer.update(screen)
        finally:
            env.close()
Esempio n. 2
0
# Transform that keeps channels 3, 4, 5 and 6 of its input.
select = tf.SelectChannels([3, 4, 5, 6])

# View of channels 0-2 of the 'transform' stage, size argument (320, 480).
# NOTE(review): presumably a debug viewer; not used in this section of the file.
observe = tf.ViewChannels('transform', (320, 480), channels=[0, 1, 2])

# Preprocessing pipeline: the object transforms, then channel selection,
# tensor conversion and finally tf.CoordConv.
segmentor = TVT.Compose(
    [
        shots,
        player,
        cut_player,
        invader,
        cut_invader,
        barrier,
        select,
        TVT.ToTensor(),
        tf.CoordConv(),
    ]
)

# SpaceInvaders environment wrapped with StepReward(step_reward=1).
env = gym_wrappers.StepReward(gym.make('SpaceInvaders-v4'), step_reward=1)

policy_nets = []

cma_file = (
    config.basepath() / 'SpaceInvaders-v4' / 'policy_runs' / '603' / 'cma_8'
)

# Reuse the pickled CMA object when the file exists; otherwise start a new one.
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# produced by this project.
if not cma_file.exists():
    cma = CMA()
else:
    with Path(cma_file).open('rb') as f:
        cma = pickle.load(f)

episode_steps = []

# Run settings (meanings inferred from names -- confirm against later code).
sample_size = 1000
epochs = 200
rollouts = 1

z_size = 32