Пример #1
0
def renderers():
    """Build the renderer mapping used by the environment.

    Returns:
        A dict with a single 'image' entry: a 64x64 PILRenderer with
        anti-aliasing factor 5 that uses `color_maps.hsv_to_rgb` as its
        colour converter.
    """
    image_renderer = spriteworld_renderers.PILRenderer(
        image_size=(64, 64),
        anti_aliasing=5,
        color_to_rgb=spriteworld_renderers.color_maps.hsv_to_rgb,
    )
    return {'image': image_renderer}
Пример #2
0
def get_config(mode=None):
    """Build the config for a colliding-sprites physics task.

    Each episode spawns between 4 and 7 sprites (circles, squares or
    triangles) with random position, colour, velocity and mass. Sprites
    interact through a SymmetricShellCollision force wired by a
    LowerTriangular graph generator, and they bounce off the walls.

    Args:
        mode: Unused task mode; it is only copied into the returned metadata.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to physics_environment.PhysicsEnvironment.
    """
    # Factor distributions for the sprites, grouped by role. The groups are
    # concatenated below in the original factor order.
    position_factors = [
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
    ]
    appearance_factors = [
        distribs.Discrete('shape', ['circle', 'square', 'triangle']),
        distribs.Discrete('scale', [0.1]),
        distribs.Continuous('c0', 0, 1),
        distribs.Continuous('c1', 0.5, 1.),
        distribs.Discrete('c2', [1.]),
    ]
    dynamics_factors = [
        distribs.Continuous('x_vel', -0.03, 0.03),
        distribs.Continuous('y_vel', -0.03, 0.03),
        distribs.Continuous('mass', 0.5, 2.0),
    ]
    factors = distribs.Product(
        position_factors + appearance_factors + dynamics_factors)

    def _num_sprites():
        # randint's upper bound is exclusive: 4..7 sprites per episode.
        return np.random.randint(4, 8)

    sprite_gen = generate_sprites.generate_sprites(
        factors, num_sprites=_num_sprites)

    # Collisions are simulated by an invisible rigid circular shell around
    # each sprite; the 0.08 radius was chosen by eye.
    collision_force = forces.SymmetricShellCollision(shell_radius=0.08)
    graph_generator = graph_generators.LowerTriangular(force=collision_force)

    image_renderer = spriteworld_renderers.PILRenderer(
        image_size=(64, 64), anti_aliasing=5)

    metadata = {'name': os.path.basename(__file__), 'mode': mode}
    return {
        'graph_generators': (graph_generator,),
        'renderers': {'image': image_renderer},
        'init_sprites': sprite_gen,
        'episode_length': 30,
        'bounce_off_walls': True,
        'physics_steps_per_env_step': 10,
        'metadata': metadata,
    }
Пример #3
0
def get_config(mode='train'):
    """Generate environment config.

    Sprites are circles with random position, hue-like c0 and velocity, unit
    mass and no interaction force between them. Sprites do not bounce off the
    walls.

    Args:
        mode: Task mode (default 'train'). It is used as the key into
            `_NUM_SPRITES` to choose the number of sprites per episode, and is
            recorded in the returned metadata. A mode that is not a key of
            `_NUM_SPRITES` fails at that lookup.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to physics_environment.PhysicsEnvironment.
    """

    # Factor distributions for the sprites.
    factors = distribs.Product([
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
        distribs.Discrete('shape', ['circle']),
        distribs.Discrete('scale', [0.1]),
        distribs.Continuous('c0', 0, 1),
        distribs.Discrete('c1', [1.]),
        distribs.Discrete('c2', [1.]),
        distribs.Continuous('x_vel', -0.03, 0.03),
        distribs.Continuous('y_vel', -0.03, 0.03),
        distribs.Discrete('mass', [1]),
    ])

    # The sprite count depends on the mode (e.g. a train/test split).
    num_sprites = _NUM_SPRITES[mode]
    sprite_gen = generate_sprites.generate_sprites(
        factors, num_sprites=num_sprites)

    # NOTE(review): forces.NoForce is passed without being called, unlike the
    # instantiated forces used by sibling configs (e.g. forces.Spring(...)) —
    # confirm FullyConnected accepts it as-is.
    graph_generator = graph_generators.FullyConnected(force=forces.NoForce)

    renderers = {
        'image':
            spriteworld_renderers.PILRenderer(
                image_size=(64, 64), anti_aliasing=5)
    }

    config = {
        'graph_generators': (graph_generator,),
        'renderers': renderers,
        'init_sprites': sprite_gen,
        'episode_length': 20,
        'bounce_off_walls': False,
        'metadata': {
            'name': os.path.basename(__file__),
            'mode': mode
        }
    }
    return config
Пример #4
0
def get_config(mode=None):
    """Build the config for a spring-coupled sprites physics task.

    Four sprites (circles, squares or triangles) with random position,
    colour, velocity and mass are linked by a Spring force
    (spring_constant=0.03, spring_equilibrium=0.25) through a FullyConnected
    graph generator. Walls do not bounce sprites.

    Args:
        mode: Unused task mode; it is only copied into the returned metadata.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to physics_environment.PhysicsEnvironment.
    """
    sprite_factor_list = [
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
        distribs.Discrete('shape', ['circle', 'square', 'triangle']),
        distribs.Discrete('scale', [0.1]),
        distribs.Continuous('c0', 0, 1),
        distribs.Continuous('c1', 0.5, 1.),
        distribs.Discrete('c2', [1.]),
        distribs.Continuous('x_vel', -0.03, 0.03),
        distribs.Continuous('y_vel', -0.03, 0.03),
        distribs.Continuous('mass', 0.5, 2.0),
    ]
    sprite_gen = generate_sprites.generate_sprites(
        distribs.Product(sprite_factor_list), num_sprites=4)

    spring = forces.Spring(spring_constant=0.03, spring_equilibrium=0.25)
    graph_generator = graph_generators.FullyConnected(force=spring)

    return {
        'graph_generators': (graph_generator,),
        'renderers': {
            'image': spriteworld_renderers.PILRenderer(
                image_size=(64, 64), anti_aliasing=5),
        },
        'init_sprites': sprite_gen,
        'episode_length': 30,
        'bounce_off_walls': False,
        'physics_steps_per_env_step': 10,
        'metadata': {
            'name': os.path.basename(__file__),
            'mode': mode,
        },
    }
def main(_):
    """Render FLAGS.num_episodes episodes of FLAGS.config and save them as a gif.

    The output path is FLAGS.gif_path_head joined with FLAGS.gif_path_tail,
    with a leading '~' expanded via os.path.expanduser. If that file already
    exists, the user is asked on stdin whether to overwrite it; any answer
    other than 'y' aborts without writing.

    The config's renderers are replaced by a single PIL 'image' renderer sized
    FLAGS.render_size, optionally converting HSV colours (FLAGS.hsv_colors).
    """
    logging.info('Generating gif for config {}'.format(FLAGS.config))

    # os.path.expanduser handles '~', '~/x' and '~user/x' correctly (the
    # previous manual slicing mangled '~user' paths and crashed on '').
    gif_path = os.path.expanduser(
        os.path.join(FLAGS.gif_path_head, FLAGS.gif_path_tail))
    if os.path.isfile(gif_path):
        should_continue = input('Path {} to write gif to already exists. '
                                'Overwrite the existing file? (y/n)'.format(
                                    gif_path))
        if should_continue != 'y':
            logging.info('You pressed {}, not "y", so terminating '
                         'program.'.format(should_continue))
            return
        else:
            logging.info('You pressed "y". Overwriting existing file.')

    # Load and adjust environment config
    config = importlib.import_module(FLAGS.config)
    config = config.get_config(FLAGS.mode)
    config['renderers'] = {
        'image':
        renderers.PILRenderer(image_size=(FLAGS.render_size,
                                          FLAGS.render_size),
                              color_to_rgb=renderers.color_maps.hsv_to_rgb
                              if FLAGS.hsv_colors else None,
                              anti_aliasing=FLAGS.anti_aliasing),
    }
    env = physics_environment.PhysicsEnvironment(**config)

    # Run the environment in a loop, recording one frame per timestep.
    duration_per_frame = 1. / FLAGS.fps
    timestep = env.reset()
    images = []
    episodes_generated = 0
    while episodes_generated < FLAGS.num_episodes:
        images.append(timestep.observation['image'])
        if timestep.last():
            episodes_generated += 1
            logging.info('Generated {} of {} episodes'.format(
                episodes_generated, FLAGS.num_episodes))
            if episodes_generated >= FLAGS.num_episodes:
                # Skip a final env.step() whose result would be discarded.
                break
        timestep = env.step()

    logging.info('Writing gif to file {}'.format(gif_path))
    imageio.mimsave(gif_path, images, duration=duration_per_frame)
Пример #6
0
def setup_run_ui(env_config, render_size, task_hsv_colors, anti_aliasing):
    """Start a Demo UI given an env_config.

    A SelectMove action space is replaced, in place in `env_config`, by a
    noisy DragAndDrop action space driven by a HumanDragAndDropAgent. An
    Embodied action space is driven by a HumanEmbodiedAgent. Any other action
    space raises ValueError. The 'renderers' entry of `env_config` is also
    overwritten. The function then runs the agent/environment loop forever,
    drawing every timestep in a MatplotlibUI.

    Args:
        env_config: Environment kwargs dict; modified in place.
        render_size: Side length in pixels of the rendered image.
        task_hsv_colors: If true, render colours through
            renderers.color_maps.hsv_to_rgb.
        anti_aliasing: Anti-aliasing factor passed to the PIL renderer.

    Raises:
        ValueError: If the action space is neither SelectMove nor Embodied.
    """
    space = env_config['action_space']
    if isinstance(space, action_spaces.SelectMove):
        # DragAndDrop is a bit easier to demo than the SelectMove action space
        space = action_spaces.DragAndDrop(
            scale=0.5,
            noise_scale=0.02,
            proportional_motion_noise=0.35,
            filter_distribs=env_config['metadata']['filter_distribs'])
        env_config['action_space'] = space
        agent = HumanDragAndDropAgent(space)
    elif isinstance(space, action_spaces.Embodied):
        agent = HumanEmbodiedAgent(space)
    else:
        raise ValueError(
            'Demo is not configured to run with action space {}.'.format(
                space))

    if task_hsv_colors:
        color_to_rgb = renderers.color_maps.hsv_to_rgb
    else:
        color_to_rgb = None
    env_config['renderers'] = {
        'image': renderers.PILRenderer(
            image_size=(render_size, render_size),
            color_to_rgb=color_to_rgb,
            anti_aliasing=anti_aliasing),
        'success': renderers.Success(),
    }
    env = environment.Environment(**env_config)
    ui = MatplotlibUI()
    agent.register_callbacks(ui)

    # Only drag-and-drop actions are passed to the UI to be drawn.
    draw_actions = isinstance(env_config['action_space'],
                              action_spaces.DragAndDrop)

    # Start RL loop
    timestep = env.reset()
    ui.update(timestep, action=None)
    while True:
        action = agent.step(timestep)
        timestep = env.step(action)
        ui.update(timestep, action if draw_actions else None)
Пример #7
0
def main(_):
    """Run an interactive DemoUI for the task config named by FLAGS.config.

    A SelectMove action space is swapped for DragAndDrop(scale=0.5) with a
    HumanDragAndDropAgent. An Embodied action space gets a
    HumanEmbodiedAgent. Any other action space raises ValueError. The
    config's renderers are replaced by a PIL 'image' renderer plus a
    'success' renderer, then the agent/environment loop runs forever.

    Raises:
        ValueError: If the config's action space is not supported.
    """
    config_module = importlib.import_module(FLAGS.config)
    config = config_module.get_config(FLAGS.mode)

    space = config['action_space']
    if isinstance(space, action_spaces.SelectMove):
        # DragAndDrop is a bit easier to demo than the SelectMove action space
        space = action_spaces.DragAndDrop(scale=0.5)
        config['action_space'] = space
        agent = HumanDragAndDropAgent(space)
    elif isinstance(space, action_spaces.Embodied):
        agent = HumanEmbodiedAgent(space)
    else:
        raise ValueError(
            'Demo is not configured to run with action space {}.'.format(
                space))

    size = (FLAGS.render_size, FLAGS.render_size)
    color_to_rgb = color_maps.hsv_to_rgb if FLAGS.task_hsv_colors else None
    config['renderers'] = {
        'image': renderers.PILRenderer(
            image_size=size,
            color_to_rgb=color_to_rgb,
            anti_aliasing=FLAGS.anti_aliasing),
        'success': renderers.Success(),
    }
    env = environment.Environment(**config)

    demo = DemoUI()
    for event_name, handler in agent.callbacks().items():
        demo.register_callback(event_name, handler)

    # Only drag-and-drop actions are passed to the UI to be drawn.
    draw_actions = isinstance(config['action_space'],
                              action_spaces.DragAndDrop)

    timestep = env.reset()
    demo.update(timestep, action=None)
    while True:
        action = agent.step(timestep)
        timestep = env.step(action)
        demo.update(timestep, action if draw_actions else None)
Пример #8
0
def setup_run_ui(env_config, render_size, task_hsv_colors, anti_aliasing):
  """Start a Demo UI given an env_config.

  SelectMove is replaced (in place in `env_config`) by DragAndDrop(scale=0.5)
  with a HumanDragAndDropAgent; Embodied uses a HumanEmbodiedAgent; any other
  action space raises ValueError. `env_config['renderers']` is overwritten,
  and the agent/environment loop then runs forever in a DemoUI.

  Args:
    env_config: Environment kwargs dict; modified in place.
    render_size: Side length in pixels of the rendered image.
    task_hsv_colors: If true, render colours through
      renderers.color_maps.hsv_to_rgb.
    anti_aliasing: Anti-aliasing factor passed to the PIL renderer.

  Raises:
    ValueError: If the action space is neither SelectMove nor Embodied.
  """
  space = env_config['action_space']
  if isinstance(space, action_spaces.SelectMove):
    # DragAndDrop is a bit easier to demo than the SelectMove action space
    space = action_spaces.DragAndDrop(scale=0.5)
    env_config['action_space'] = space
    agent = HumanDragAndDropAgent(space)
  elif isinstance(space, action_spaces.Embodied):
    agent = HumanEmbodiedAgent(space)
  else:
    raise ValueError(
        'Demo is not configured to run with action space {}.'.format(space))

  color_to_rgb = None
  if task_hsv_colors:
    color_to_rgb = renderers.color_maps.hsv_to_rgb
  env_config['renderers'] = {
      'image': renderers.PILRenderer(
          image_size=(render_size, render_size),
          color_to_rgb=color_to_rgb,
          anti_aliasing=anti_aliasing),
      'success': renderers.Success(),
  }
  env = environment.Environment(**env_config)

  demo = DemoUI()
  for event_name, handler in agent.callbacks().items():
    demo.register_callback(event_name, handler)

  # Only drag-and-drop actions are passed to the UI to be drawn.
  draw_actions = isinstance(env_config['action_space'],
                            action_spaces.DragAndDrop)

  # Start RL loop
  timestep = env.reset()
  demo.update(timestep, action=None)
  while True:
    action = agent.step(timestep)
    timestep = env.step(action)
    demo.update(timestep, action if draw_actions else None)
def get_config(mode=None):
    """Generate environment config.

    The scene contains NUM_TARGETS target sprites (c0, used as hue, drawn
    between 0 and 0.4) and NUM_DISTRACTORS distractor sprites (c0 between
    0.5 and 0.9) in shuffled order, followed by one small circular agent
    body. The task is FindGoalPosition filtered on the target hue
    distribution, driven by an Embodied action space.

    Args:
        mode: Unused task mode; it is only recorded in the returned metadata.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to environment.Environment.
    """
    # NOTE: `mode` is intentionally not deleted here: it is read below when
    # building the metadata, and `del mode` made that an UnboundLocalError.

    shared_factors = distribs.Product([
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
        distribs.Discrete('shape', ['square', 'triangle', 'circle']),
        distribs.Discrete('scale', [0.13]),
        distribs.Continuous('c1', 0.3, 1.),
        distribs.Continuous('c2', 0.9, 1.),
    ])
    # Targets and distractors differ only in their (disjoint) hue ranges.
    target_hue = distribs.Continuous('c0', 0., 0.4)
    distractor_hue = distribs.Continuous('c0', 0.5, 0.9)
    target_factors = distribs.Product([
        target_hue,
        shared_factors,
    ])
    distractor_factors = distribs.Product([
        distractor_hue,
        shared_factors,
    ])

    target_sprite_gen = sprite_generators.generate_sprites(
        target_factors, num_sprites=NUM_TARGETS)
    distractor_sprite_gen = sprite_generators.generate_sprites(
        distractor_factors, num_sprites=NUM_DISTRACTORS)
    sprite_gen = sprite_generators.chain_generators(target_sprite_gen,
                                                    distractor_sprite_gen)
    # Randomize sprite ordering to eliminate any task information from occlusions
    sprite_gen = sprite_generators.shuffle(sprite_gen)

    # Create the agent body; it is chained after the shuffle so it always
    # comes last.
    agent_body_factors = distribs.Product([
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
        distribs.Discrete('shape', ['circle']),
        distribs.Discrete('scale', [0.07]),
        distribs.Discrete('c0', [1.]),
        distribs.Discrete('c1', [0.]),
        distribs.Discrete('c2', [1.]),
    ])
    agent_body_gen = sprite_generators.generate_sprites(agent_body_factors,
                                                        num_sprites=1)
    sprite_gen = sprite_generators.chain_generators(sprite_gen, agent_body_gen)

    task = tasks.FindGoalPosition(filter_distrib=target_hue,
                                  terminate_distance=TERMINATE_DISTANCE)

    renderers = {
        'image':
        spriteworld_renderers.PILRenderer(image_size=(64, 64),
                                          anti_aliasing=5,
                                          color_to_rgb=color_maps.hsv_to_rgb)
    }

    config = {
        'task': task,
        'action_space': action_spaces.Embodied(step_size=0.05),
        'renderers': renderers,
        'init_sprites': sprite_gen,
        'max_episode_length': 50,
        'metadata': {
            'name': os.path.basename(__file__),
            'mode': mode
        }
    }
    return config
Пример #10
0
    def __init__(self,
                 n=3,
                 r=1.,
                 m=1.,
                 hw=10,
                 granularity=5,
                 res=32,
                 t=1.,
                 init_v_factor=0,
                 friction_coefficient=0.,
                 seed=None,
                 sprites=False):
        """Initialize a physics env with some general parameters.

        Args:
            n (int): Optional, number of objects in the scene.
            r (float)/list(float): Optional, radius of objects in the scene.
                A scalar is broadcast to an (n, 1) array.
            m (float)/list(float): Optional, mass of the objects in the scene.
                A scalar is broadcast to an (n, 1) array.
            hw (float): Optional, coordinate limits of the environment.
            granularity (int): Optional, number of internal simulation steps
                per time step; the internal step fraction `eps` is
                1 / granularity. Does not change speed of simulation.
            res (int): Optional, pixel resolution of the images.
            t (float): Optional, dt of the step() method. Speeds up or slows
                down the simulation.
            init_v_factor (float): Scaling factor for initial velocity. Used
                only in Gravity Environment.
            friction_coefficient (float): Friction slows down balls.
            seed (int): Seed passed to NumPy's global generator
                (np.random.seed) for reproducibility.
            sprites (bool): Render selection of sprites using spriteworld
                instead of balls. Three shapes are sampled (with replacement)
                from triangle/square/circle/star_4.

        """
        # Seeds NumPy's *global* RNG, so it also affects other users of it.
        np.random.seed(seed)

        self.n = n
        self.r = np.array([[r]] * n) if np.isscalar(r) else r
        self.m = np.array([[m]] * n) if np.isscalar(m) else m
        self.hw = hw
        self.internal_steps = granularity
        self.eps = 1 / granularity
        self.res = res
        self.t = t

        # Order matters: init_x/init_v run after the attributes above are set
        # and draw from the freshly seeded RNG.
        self.x = self.init_x()
        self.v = self.init_v(init_v_factor)
        self.a = np.zeros_like(self.v)

        self.fric_coeff = friction_coefficient
        self.v_rotation_angle = 2 * np.pi * 0.05

        self.use_colors = n > 3

        if sprites:
            self.renderer = spriteworld_renderers.PILRenderer(
                image_size=(self.res, self.res),
                anti_aliasing=10,
            )

            shapes = ['triangle', 'square', 'circle', 'star_4']

            if not np.isscalar(r):
                print("Scale elements according to radius of first element.")

            # empirical scaling rule, works for r = 1.2 and 2
            self.scale = self.r[0] / self.hw / 0.6
            # NOTE(review): exactly 3 shapes are drawn regardless of n;
            # presumably draw_sprites renders three objects — confirm.
            self.shapes = np.random.choice(shapes, 3)
            self.draw_image = self.draw_sprites

        else:
            self.draw_image = self.draw_balls
Пример #11
0
def get_config(mode='train'):
    """Generate environment config.

    Builds a combined clustering + goal-finding task on a SelectMove action
    space. Each scene contains two sprites of each cluster shape (triangle,
    square, pentagon), one or two goal-finding sprites (spoke_4/star_4) for
    each of two colour groups, and zero to two circle distractors, in
    shuffled order. The sub-tasks are combined with
    MetaAggregated(reward_aggregator='sum', termination_criterion='all').

    Args:
        mode: 'train' or 'test'. Selects a held-out split of goal-finding
            sprite scales (test: Continuous 0.08-0.12; train: Continuous
            0.05-0.15 minus the test range) and of cluster-sprite c0 values
            (train: 128-256; test: 0-128). It is also recorded in the
            returned metadata.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to environment.Environment.

    Raises:
        ValueError: If mode is neither 'train' nor 'test'.
    """

    # Factor distributions common to all objects.
    common_factors = distribs.Product([
        distribs.Continuous('x', 0.1, 0.9),
        distribs.Continuous('y', 0.1, 0.9),
        distribs.Continuous('angle', 0, 360, dtype='int32'),
    ])

    # train/test split for goal-finding object scales and clustering object colors
    goal_finding_scale_test = distribs.Continuous('scale', 0.08, 0.12)
    green_blue_colors = distribs.Product([
        distribs.Continuous('c1', 64, 256, dtype='int32'),
        distribs.Continuous('c2', 64, 256, dtype='int32'),
    ])
    if mode == 'train':
        goal_finding_scale = distribs.SetMinus(
            distribs.Continuous('scale', 0.05, 0.15),
            goal_finding_scale_test,
        )
        cluster_colors = distribs.Product([
            distribs.Continuous('c0', 128, 256, dtype='int32'),
            green_blue_colors
        ])
    elif mode == 'test':
        goal_finding_scale = goal_finding_scale_test
        cluster_colors = distribs.Product([
            distribs.Continuous('c0', 0, 128, dtype='int32'), green_blue_colors
        ])
    else:
        raise ValueError(
            'Invalid mode {}. Mode must be "train" or "test".'.format(mode))

    # Create clustering sprite generators (two sprites per shape).
    sprite_gen_list = []
    cluster_shapes = [
        distribs.Discrete('shape', [s])
        for s in ['triangle', 'square', 'pentagon']
    ]
    for shape in cluster_shapes:
        factors = distribs.Product([
            common_factors,
            cluster_colors,
            shape,
            distribs.Continuous('scale', 0.08, 0.12),
        ])
        sprite_gen_list.append(
            sprite_generators.generate_sprites(factors, num_sprites=2))

    # Create goal-finding sprite generators
    goal_finding_colors = [
        distribs.Product([
            distribs.Continuous('c0', 192, 256, dtype='int32'),
            distribs.Continuous('c1', 0, 128, dtype='int32'),
            distribs.Continuous('c2', 64, 128, dtype='int32'),
        ]),
        distribs.Product([
            distribs.Continuous('c0', 0, 128, dtype='int32'),
            distribs.Continuous('c1', 192, 256, dtype='int32'),
            distribs.Continuous('c2', 64, 128, dtype='int32'),
        ])
    ]
    # Goal positions corresponding to the colors in goal_finding_colors
    goal_finding_positions = [(0., 0.5), (1., 0.5)]
    goal_finding_shapes = distribs.Discrete('shape', ['spoke_4', 'star_4'])
    for colors in goal_finding_colors:
        factors = distribs.Product([
            common_factors,
            goal_finding_scale,
            goal_finding_shapes,
            colors,
        ])
        # randint's upper bound is exclusive: 1 or 2 sprites per color group.
        sprite_gen_list.append(
            sprite_generators.generate_sprites(
                factors, num_sprites=lambda: np.random.randint(1, 3)))

    # Create distractor sprite generator (0 to 2 circles).
    # NOTE(review): these colour factors use dtype='uint8' while every other
    # colour factor in this config uses 'int32' — confirm this is intentional.
    distractor_factors = distribs.Product([
        common_factors,
        distribs.Discrete('shape', ['circle']),
        distribs.Continuous('c0', 64, 256, dtype='uint8'),
        distribs.Continuous('c1', 64, 256, dtype='uint8'),
        distribs.Continuous('c2', 64, 256, dtype='uint8'),
        distribs.Continuous('scale', 0.08, 0.12),
    ])
    sprite_gen_list.append(
        sprite_generators.generate_sprites(
            distractor_factors, num_sprites=lambda: np.random.randint(0, 3)))

    # Concat clusters into single scene to generate
    sprite_gen = sprite_generators.chain_generators(*sprite_gen_list)
    # Randomize sprite ordering to eliminate any task information from occlusions
    sprite_gen = sprite_generators.shuffle(sprite_gen)

    # Create the combined task of goal-finding and clustering
    task_list = []
    task_list.append(
        tasks.Clustering(cluster_shapes, terminate_bonus=0., reward_range=10.))
    for colors, goal_pos in zip(goal_finding_colors, goal_finding_positions):
        goal_finding_task = tasks.FindGoalPosition(
            distribs.Product([colors, goal_finding_shapes]),
            goal_position=goal_pos,
            weights_dimensions=(1, 0),
            terminate_distance=0.15,
            raw_reward_multiplier=30)
        task_list.append(goal_finding_task)
    task = tasks.MetaAggregated(task_list,
                                reward_aggregator='sum',
                                termination_criterion='all')

    renderers = {
        'image':
        spriteworld_renderers.PILRenderer(image_size=(64, 64), anti_aliasing=5)
    }

    config = {
        'task': task,
        'action_space': action_spaces.SelectMove(scale=0.5),
        'renderers': renderers,
        'init_sprites': sprite_gen,
        'max_episode_length': 50,
        'metadata': {
            'name': os.path.basename(__file__),
            'mode': mode
        }
    }
    return config
Пример #12
0
 def load_data(self):
     """Load (downloading if needed) the CSV data and build factor metadata.

     If ``os.path.join(self.path, self.fname)`` does not exist, it is
     downloaded from ``self.url``. The file is parsed with ``load_csv`` into
     ``self.csv_dict``. Optionally, rows are filtered on their 'area' values
     (``self.area_filter``), and the 'x', 'y' and 'area' columns are snapped
     to histogram bin centres with fitted label encoders
     (``self.natural_discrete``). Finally ``factor_sizes`` and the factor
     index bookkeeping attributes are set.
     """
     # Download the data file if it is not available locally.
     file_path = os.path.join(self.path, self.fname)
     if not os.path.exists(file_path):
         os.makedirs(self.path, exist_ok=True)
         print(f'file not found, downloading from {self.url} ...')
         from urllib import request
         url = self.url
         request.urlretrieve(url, file_path)
     with open(file_path) as data:
         self.csv_dict = load_csv(data, sequence=self.sequence_len)
     # Number of values per factor: x, y, scale, angle, shape, c0, c1, c2.
     self.orig_num = [32, 32, 6, 40, 4, 1, 1, 1]
     self.dsprites = {
         'x':
         np.linspace(0.2, 0.8, self.orig_num[0]),
         'y':
         np.linspace(0.2, 0.8, self.orig_num[1]),
         # Drop the leading 0 so no scale value is zero.
         'scale':
         np.linspace(0, 0.5, self.orig_num[2] + 1)[1:],
         # `np.int` was removed in NumPy 1.24; the builtin `int` is the
         # equivalent dtype.
         'angle':
         np.linspace(0, 360, self.orig_num[3], dtype=int, endpoint=False),
         'shape': ['square', 'triangle', 'star_4', 'spoke_4'],
         'c0': [1.],
         'c1': [1.],
         'c2': [1.]
     }
     distributions = [
         distribs.Discrete(key, values)
         for key, values in self.dsprites.items()
     ]
     self.factor_dist = distribs.Product(distributions)
     self.renderer = spriteworld_renderers.PILRenderer(image_size=(64, 64),
                                                       anti_aliasing=5,
                                                       color_to_rgb=rgb)
     if self.area_filter:
         # Keep only rows whose 'area' lies within the
         # [area_filter / 2, 1 - area_filter / 2] quantile range at every
         # sequence step.
         keep_idxes = []
         print(len(self.csv_dict['x']))
         areas = np.array(self.csv_dict['area'])
         for i in range(self.sequence_len):
             x = pd.Series(areas[:, i])
             keep_idxes.append(
                 x.between(x.quantile(self.area_filter / 2),
                           x.quantile(1 - (self.area_filter / 2))))
         # np.logical_and is a binary ufunc: unpacking more than two masks
         # into it passes the third as `out`, and a single mask is a
         # TypeError. Reduce across all steps instead.
         keep = np.logical_and.reduce(keep_idxes)
         for k in self.csv_dict.keys():
             y = pd.Series(self.csv_dict[k])
             self.csv_dict[k] = np.array([x for x in y[keep]])
         print(len(self.csv_dict['x']))
     if self.natural_discrete:
         num_bins = self.orig_num[:3]
         self.lab_encs = {}
         print('num_bins', num_bins)
         for i, key in enumerate(['x', 'y', 'area']):
             values = np.array(self.csv_dict[key])
             _, bin_edges = np.histogram(values.flatten().tolist(),
                                         bins=num_bins[i])
             bin_left, bin_right = bin_edges[:-1], bin_edges[1:]
             bin_centers = bin_left + (bin_right - bin_left) / 2
             old_shape = values.shape
             lab_enc = preprocessing.LabelEncoder()
             # Encoders are fitted on normalised values: positions divided by
             # 64, areas mapped to sqrt(area / 64**2) and stored as 'scale'.
             # (64 presumably matches the 64x64 render size — confirm.)
             if key == 'area':
                 self.lab_encs['scale'] = lab_enc.fit(
                     np.sqrt(bin_centers / (64**2)))
             else:
                 self.lab_encs[key] = lab_enc.fit(bin_centers / 64)
             # Snap every value at every sequence step to its nearest bin
             # centre.
             new_data = []
             for j in range(self.sequence_len):
                 differences = (values[:, j].reshape(1, -1) -
                                bin_centers.reshape(-1, 1))
                 new_data.append([
                     bin_centers[idx]
                     for idx in np.abs(differences).argmin(axis=0)
                 ])
             self.csv_dict[key] = np.swapaxes(new_data, 0, 1)
             assert old_shape == np.array(self.csv_dict[key]).shape
             assert len(np.unique(np.array(
                 self.csv_dict[key]).flatten())) == num_bins[i]
         for key in ['angle', 'shape', 'c0', 'c1', 'c2']:
             lab_enc = preprocessing.LabelEncoder()
             self.lab_encs[key] = lab_enc.fit(self.dsprites[key])
         assert self.lab_encs.keys() == self.dsprites.keys()
     # x, y and area sizes come from the data; angle/shape/colour are fixed.
     self.factor_sizes = [
         len(np.unique(np.array(self.csv_dict['x']).flatten())),
         len(np.unique(np.array(self.csv_dict['y']).flatten())),
         len(np.unique(np.array(self.csv_dict['area']).flatten())), 40, 4,
         1, 1, 1
     ]
     print(self.factor_sizes)
     self.latent_factor_indices = list(range(5))
     self.num_factors = len(self.latent_factor_indices)
     self.observation_factor_indices = [
         i for i in range(self.num_factors)
         if i not in self.latent_factor_indices
     ]
     self.mapping = {'square': 0, 'triangle': 1, 'star_4': 2, 'spoke_4': 3}
Пример #13
0
def get_config(mode=None):
    """Build the config for an orbital-gravity physics task.

    Sprite 0 is a fixed-position circle at (0.5, 0.5) with random mass in
    [1, 3]. Sprites 1-4 are star-shaped, with random position, colour and
    velocity and unit mass. A Gravity force (gravity_constant=-0.0001,
    distance_for_max_force=0.05) is attached to the edges (0, 1) .. (0, 4)
    of a non-symmetric adjacency-matrix graph. Sprites bounce off the walls.

    Args:
        mode: Unused task mode; it is only copied into the returned metadata.

    Returns:
        config: Dictionary defining task/environment configuration. Can be fed
            as kwargs to physics_environment.PhysicsEnvironment.
    """
    num_orbiters = 4

    # Fixed centre "star".
    center_sprite_gen = generate_sprites.generate_sprites(
        distribs.Product([
            distribs.Discrete('x', [0.5]),
            distribs.Discrete('y', [0.5]),
            distribs.Discrete('shape', ['circle']),
            distribs.Discrete('scale', [0.15]),
            distribs.Discrete('c0', [3.]),
            distribs.Discrete('c1', [1.]),
            distribs.Discrete('c2', [1.]),
            distribs.Continuous('mass', 1, 3),
        ]),
        num_sprites=1)

    # Orbiting sprites.
    orbit_sprite_gen = generate_sprites.generate_sprites(
        distribs.Product([
            distribs.Continuous('x', 0.1, 0.9),
            distribs.Continuous('y', 0.1, 0.9),
            distribs.Discrete('shape', ['star_4', 'star_5', 'star_6']),
            distribs.Discrete('scale', [0.08]),
            distribs.Continuous('c0', 0, 1),
            distribs.Continuous('c1', 0.5, 1.),
            distribs.Discrete('c2', [1.]),
            distribs.Continuous('x_vel', -0.03, 0.03),
            distribs.Continuous('y_vel', -0.03, 0.03),
            distribs.Discrete('mass', [1]),
        ]),
        num_sprites=num_orbiters)

    # The centre sprite is generated first, so it has index 0.
    sprite_gen = sprite_generators.chain_generators(center_sprite_gen,
                                                    orbit_sprite_gen)

    gravity = forces.Gravity(gravity_constant=-0.0001,
                             distance_for_max_force=0.05)
    # Only centre-to-orbiter edges carry the force.
    edges = {(0, k): gravity for k in range(1, num_orbiters + 1)}
    graph_generator = graph_generators.AdjacencyMatrix(
        adjacency_matrix=edges, symmetric=False)

    return {
        'graph_generators': (graph_generator,),
        'renderers': {
            'image': spriteworld_renderers.PILRenderer(
                image_size=(64, 64), anti_aliasing=5),
        },
        'init_sprites': sprite_gen,
        'episode_length': 30,
        'bounce_off_walls': True,
        'physics_steps_per_env_step': 10,
        'metadata': {
            'name': os.path.basename(__file__),
            'mode': mode,
        },
    }