Example #1
    def _get_shape_sprites(self):
        """Renders one binary sprite mask per shape at sprite resolution."""
        shapes = AttrDict()
        canvas = np.zeros((self._sprite_res, self._sprite_res), np.uint8)
        # filled rectangle with a one-pixel border margin
        shapes.rectangle = cv2.rectangle(
            canvas.copy(), (1, 1),
            (self._sprite_res - 2, self._sprite_res - 2), 255, -1)
        # filled circle centered on the canvas
        shapes.circle = cv2.circle(
            canvas.copy(),
            (self._sprite_res // 2, self._sprite_res // 2),
            self._sprite_res // 3, 255, -1)
        # filled triangles, named by the direction their apex points
        shapes.tri_right = cv2.fillConvexPoly(
            canvas.copy(),
            np.array([[[1, 1], [1, self._sprite_res - 2],
                       [self._sprite_res - 2, self._sprite_res // 2]]]), 255)
        shapes.tri_bottom = cv2.fillConvexPoly(
            canvas.copy(),
            np.array([[[1, 1], [self._sprite_res - 2, 1],
                       [self._sprite_res // 2, self._sprite_res - 2]]]), 255)
        shapes.tri_left = cv2.fillConvexPoly(
            canvas.copy(),
            np.array([[[self._sprite_res - 2, 1],
                       [self._sprite_res - 2, self._sprite_res - 2],
                       [1, self._sprite_res // 2]]]), 255)
        shapes.tri_top = cv2.fillConvexPoly(
            canvas.copy(),
            np.array([[[1, self._sprite_res - 2],
                       [self._sprite_res - 2, self._sprite_res - 2],
                       [self._sprite_res // 2, 1]]]), 255)
        return shapes
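
As a standalone sanity check (not from the source), the same OpenCV calls can be run outside the class to preview the masks; sprite_res = 21 below is an illustrative value:

import cv2
import numpy as np

sprite_res = 21  # hypothetical resolution, just for the preview
canvas = np.zeros((sprite_res, sprite_res), np.uint8)
rect = cv2.rectangle(canvas.copy(), (1, 1),
                     (sprite_res - 2, sprite_res - 2), 255, -1)
circle = cv2.circle(canvas.copy(), (sprite_res // 2, sprite_res // 2),
                    sprite_res // 3, 255, -1)
tri_top = cv2.fillConvexPoly(
    canvas.copy(),
    np.array([[[1, sprite_res - 2], [sprite_res - 2, sprite_res - 2],
               [sprite_res // 2, 1]]]), 255)
# tile the three masks horizontally and save for visual inspection
cv2.imwrite("sprites_preview.png",
            np.concatenate([rect, circle, tri_top], axis=1))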
Example #2
    def __getitem__(self, item):
        traj = self._generator.gen_trajectory()

        data_dict = AttrDict()
        # [T, H, W] uint8 -> [T, 3, H, W] float32, rescaled from [0, 255] to [-1, 1]
        data_dict.images = traj.images[:, None].repeat(3, axis=1).astype(
            np.float32) / (255. / 2) - 1.0
        data_dict.states = traj.states
        data_dict.shape_idxs = traj.shape_idxs
        data_dict.rewards = traj.rewards

        return data_dict
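
The scaling above maps uint8 pixels from [0, 255] to [-1, 1]. For visualizing inputs or model reconstructions, the inverse mapping is a one-liner; a minimal sketch (the denormalize helper is hypothetical, not from the source):

import numpy as np

def denormalize(img):
    """Inverse of the __getitem__ scaling: [-1, 1] floats back to uint8 [0, 255]."""
    return ((img + 1.0) * (255. / 2)).clip(0, 255).astype(np.uint8)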
Example #3
    def gen_trajectory(self):
        """Samples trajectory with bouncing sprites."""
        output = AttrDict()

        # sample coordinate trajectories [T, n_shapes, state_dim]
        output.states = self._traj_gen.create(self._spec.max_seq_len,
                                              self._spec.shapes_per_traj)

        # sample shapes for trajectory
        output.shape_idxs = self._sample_shapes()
        shapes = np.asarray(self.SHAPES)[output.shape_idxs]

        # render images for trajectories + shapes
        output.images = self._render(output.states, shapes)

        # compute rewards for trajectories
        output.rewards = self._reward(output.states, shapes)

        return output
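
A minimal consumer of gen_trajectory might look like the following sketch; `generator` is a hypothetical instance of this class, and the frames are assumed to come back as a [T, H, W] uint8 array matching the uint8 canvases from Example #1:

import cv2

traj = generator.gen_trajectory()
print(traj.states.shape)  # [T, n_shapes, state_dim], per the comment above
for t, frame in enumerate(traj.images):
    cv2.imwrite(f"traj_frame_{t:02d}.png", frame)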
Example #4
    # # complete encoder spec
    # spec = AttrDict(
    #     resolution=64,
    #     max_seq_len=30,
    #     max_speed=0.05,
    #     obj_size=0.2,
    #     shapes_per_traj=3,
    #     rewards=[ZeroReward, VertPosReward, HorPosReward, AgentXReward, AgentYReward, TargetXReward, TargetYReward],
    # )

    # decoder training spec
    spec = AttrDict(
        resolution=64,
        max_seq_len=30,
        max_speed=0.05,
        obj_size=0.2,
        shapes_per_traj=1,
        rewards=[VertPosReward, HorPosReward],
    )

    # Generate a small decoder dataset for an overfitting test
    train_loader, val_loader, test_loader = loadVAEData(
        spec,
        path='./data/data_decoder',
        decoder=True,
        new=True,
        train_num=1000,
        val_num=200,
        test_num=200)
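
Assuming loadVAEData wraps the dataset from Example #2 in standard PyTorch DataLoaders, a quick smoke test can confirm the batch layout before any training run (a sketch; the expected shape follows from __getitem__ above):

    # pull one batch and check its layout: __getitem__ yields images normalized
    # to [-1, 1] with 3 channels, so we expect [batch, T, 3, H, W] with
    # T = max_seq_len and H = W = resolution
    batch = next(iter(train_loader))
    print(batch['images'].shape)
    print(batch['images'].min().item(), batch['images'].max().item())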
Example #5
def main(DEBUG=False, OVERFIT=False):
    args = get_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    rewards = {
        'rewards_class': [
            ZeroReward, VertPosReward, HorPosReward, AgentXReward,
            AgentYReward, TargetXReward, TargetYReward
        ],
        'rewards_name': [
            'zero', 'vertical_position', 'horizontal_position', 'agent_x',
            'agent_y', 'target_x', 'target_y'
        ]
    }

    # select reward heads; with no indices given, fall back to the full set
    reward_classes = np.array(rewards['rewards_class'])[args.reward_indices] \
        if args.reward_indices else np.array(rewards['rewards_class'])
    reward_names = np.array(rewards['rewards_name'])[args.reward_indices] \
        if args.reward_indices else np.array(rewards['rewards_name'])
    num_reward_heads = len(reward_classes)

    spec = AttrDict(resolution=args.resolution,
                    max_seq_len=args.max_seq_len,
                    max_speed=args.max_speed,
                    obj_size=args.obj_size,
                    shapes_per_traj=args.shapes_per_traj,
                    rewards=reward_classes)

    # train_loader, val_loader, test_loader = loadVAEData(spec, batch_size=args.batch_size)
    train_loader, val_loader, test_loader = loadVAEData(
        spec, save_to_disk=True, batch_size=args.batch_size, OVERFIT=OVERFIT)

    if args.reconstruction:
        model = VAEReconstructionModel()
    else:
        model = VAERewardPredictionModel(num_reward_heads, args.train_decoder)

    if args.load_model:
        model.load_state_dict(torch.load(args.model_path, map_location=device))
        if args.reconstruction:
            print('VAEReconstruction model loaded\n')
        else:
            print('VAERewardPrediction model loaded\n')

    if not DEBUG:
        wandb.init(project="impl_jiefan")
        wandb.config.update(args)
        wandb.watch(model)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.reconstruction:
        trainVAEReconstruction(model,
                               optimizer,
                               train_loader,
                               val_loader,
                               args.model_path,
                               device,
                               dtype,
                               epochs=args.epochs,
                               DEBUG=DEBUG)
    else:
        trainVAERewardPrediction(model,
                                 reward_names,
                                 args.train_decoder,
                                 optimizer,
                                 train_loader,
                                 val_loader,
                                 args.model_path,
                                 device,
                                 dtype,
                                 epochs=args.epochs,
                                 DEBUG=DEBUG)
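
get_args is not shown in this snippet; a hypothetical argparse version covering exactly the attributes main reads could look like the following (every default here is illustrative, not taken from the source):

import argparse

def get_args():
    # hypothetical parser covering only the attributes used in main()
    p = argparse.ArgumentParser()
    p.add_argument('--resolution', type=int, default=64)
    p.add_argument('--max_seq_len', type=int, default=30)
    p.add_argument('--max_speed', type=float, default=0.05)
    p.add_argument('--obj_size', type=float, default=0.2)
    p.add_argument('--shapes_per_traj', type=int, default=1)
    p.add_argument('--reward_indices', type=int, nargs='*', default=None)
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--learning_rate', type=float, default=1e-3)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--reconstruction', action='store_true')
    p.add_argument('--train_decoder', action='store_true')
    p.add_argument('--load_model', action='store_true')
    p.add_argument('--model_path', type=str, default='./models/vae.pt')
    return p.parse_args()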
Example #6
    def step(self, action):
        _, reward, done, info = super().step(action)
        # expose the flat position state instead of the rendered image
        return (self._state[:, :self._n_dim].copy().flatten(),
                reward, done, info)


class SpritesRepelEnv(SpritesEnv):
    def __init__(self, **kwargs):
        super().__init__(follow=False, **kwargs)


class SpritesRepelStateEnv(SpritesStateEnv):
    def __init__(self, **kwargs):
        super().__init__(follow=False, **kwargs)


if __name__ == '__main__':
    data_spec = AttrDict(
        resolution=64,
        max_ep_len=40,
        max_speed=0.05,  # speed in image coordinates; the full image spans [0, 1]
        obj_size=0.2,  # object size; the full image is 1.0
        follow=True,
    )
    env = SpritesEnv()
    env.set_config(data_spec)
    obs = env.reset()
    cv2.imwrite("test_rl.png", 255 * np.expand_dims(obs, -1))
    obs, reward, done, info = env.step([0, 0])
    cv2.imwrite("test_rl_1.png", 255 * np.expand_dims(obs, -1))