def test_rollout_gen(self):
    """Play SpaceInvaders-v4 with a saved visuals+controller policy and display it.

    Loads the visuals model (``GM53H301W5YS38XH``) and the controller
    checkpoint (``best_model68``) from the ``SpaceInvaders-v4/models``
    directory under ``config.basepath()``. It wraps them in a ``VCPolicy``
    and passes every screen yielded by ``RolloutGen`` to an ``ImageViewer``.

    This is a visual smoke test; it makes no assertions. The environment
    is always closed, even if loading a model or the rollout raises.
    """
    env = gym.make('SpaceInvaders-v4')
    try:
        models = config.basepath() / 'SpaceInvaders-v4' / 'models'
        visualsfile = models / 'GM53H301W5YS38XH'
        visuals = Storeable.load(str(visualsfile)).to(config.device())
        controllerfile = models / 'best_model68'
        # Load the controller onto the same device as the visuals model.
        # Without map_location, torch.load restores tensors to the device
        # they were saved from. That breaks on machines without that device
        # and can leave the two halves of VCPolicy on different devices.
        controller = torch.load(str(controllerfile), map_location=config.device())
        policy = VCPolicy(visuals, controller)
        viewer = ImageViewer('screen', (420, 360), 'numpyRGB')
        # RolloutGen yields 6-tuples; only the rendered screen is used here.
        for screen, _observation, _reward, _done, _info, _action in RolloutGen(env, policy):
            viewer.update(screen)
    finally:
        # Release the environment even when loading or the rollout fails.
        env.close()
select = tf.SelectChannels([3, 4, 5, 6]) observe = tf.ViewChannels('transform', (320, 480), channels=[0, 1, 2]) segmentor = TVT.Compose([ shots, player, cut_player, invader, cut_invader, barrier, select, TVT.ToTensor(), tf.CoordConv() ]) env = gym.make('SpaceInvaders-v4') env = gym_wrappers.StepReward(env, step_reward=1) policy_nets = [] cma_file = config.basepath( ) / 'SpaceInvaders-v4' / 'policy_runs' / '603' / 'cma_8' if cma_file.exists(): with Path(cma_file).open('rb') as f: cma = pickle.load(f) else: cma = CMA() episode_steps = [] sample_size = 1000 epochs = 200 rollouts = 1 z_size = 32