Example #1
0
 def testSequence(self):
   """Renders several randomly positioned sprites in a single call.

   The random generator is seeded so that sprite positions, and therefore any
   failure, are reproducible from run to run.
   """
   rng = np.random.RandomState(0)
   sprites = [
       sprite_lib.Sprite(x=rng.rand(), y=rng.rand()) for _ in range(5)
   ]
   renderer = handcrafted.SpriteFactors()
   outputs = renderer.render(sprites=sprites)
   # Beyond not raising, rendering must yield one output entry per sprite.
   self.assertLen(outputs, len(sprites))
Example #2
0
  def testAttributesSingleton(self, x, y, shape, c0, c1, c2, scale, angle):
    """Checks that a single sprite's attributes are reported by the renderer.

    The shape is reported as its `const.ShapeType` enum value. Every other
    attribute must match the value the sprite was built with to within 1e-4.
    """
    numeric_attributes = dict(
        x=x, y=y, c0=c0, c1=c1, c2=c2, scale=scale, angle=angle)
    sprite = sprite_lib.Sprite(
        x=x, y=y, shape=shape, c0=c0, c1=c1, c2=c2, scale=scale, angle=angle)
    factors = handcrafted.SpriteFactors().render(sprites=[sprite])[0]

    self.assertEqual(factors['shape'], const.ShapeType[shape].value)
    for attribute_name, expected in numeric_attributes.items():
      self.assertAlmostEqual(factors[attribute_name], expected, delta=1e-4)
Example #3
0
  def testObservationSpec(self, num_sprites, factors):
    """Checks the observation spec after rendering `num_sprites` sprites.

    Every per-sprite spec must contain exactly the requested `factors`, and
    every entry of every sprite's spec must be scalar-shaped.
    """
    sprites = [sprite_lib.Sprite() for _ in range(num_sprites)]
    renderer = handcrafted.SpriteFactors(factors=factors)
    renderer.render(sprites=sprites)
    obs_spec = renderer.observation_spec()

    # Check shapes for every sprite's spec, not only the first one. Iterating
    # also avoids an IndexError on obs_spec[0] when there are no sprites.
    for sprite_spec in obs_spec:
      for v in sprite_spec.values():
        self.assertEqual(v.shape, ())

    obs_spec_keys = [set(x) for x in obs_spec]
    self.assertSequenceEqual(obs_spec_keys, num_sprites * [set(factors)])
Example #4
0
  def testAttributesTwoSprites(self):
    """Checks that two distinct sprites each have their attributes reported.

    Each sprite's shape is reported as its `const.ShapeType` enum value. Every
    other attribute, including the velocities, must match to within 1e-4.
    """
    numeric_names = ('x', 'y', 'c0', 'c1', 'c2', 'scale', 'angle', 'x_vel',
                     'y_vel')
    sprite_specs = [
        dict(x=0.5, y=0.4, shape='square', c0=0, c1=255, c2=0, scale=0.2,
             angle=0, x_vel=0.0, y_vel=-0.2),
        dict(x=0.3, y=0.8, shape='spoke_4', c0=200, c1=100, c2=200,
             scale=0.3, angle=120, x_vel=0.1, y_vel=0.05),
    ]
    sprites = [sprite_lib.Sprite(**spec) for spec in sprite_specs]

    outputs = handcrafted.SpriteFactors().render(sprites=sprites)

    for index, spec in enumerate(sprite_specs):
      # Index into outputs (rather than zip) so a too-short result still fails.
      factors = outputs[index]
      self.assertEqual(factors['shape'], const.ShapeType[spec['shape']].value)
      for name in numeric_names:
        self.assertAlmostEqual(factors[name], spec[name], delta=1e-4)
Example #5
0
 def testFactorSubset(self, num_sprites, factors):
   """Each rendered sprite reports exactly the requested subset of factors."""
   sprites = [sprite_lib.Sprite() for _ in range(num_sprites)]
   outputs = handcrafted.SpriteFactors(factors=factors).render(sprites=sprites)
   expected_keys = [set(factors)] * num_sprites
   self.assertSequenceEqual([set(output) for output in outputs], expected_keys)
Example #6
0
 def testOutputLength(self, num_sprites):
   """Rendering `num_sprites` sprites yields exactly `num_sprites` outputs."""
   sprites = [sprite_lib.Sprite() for _ in range(num_sprites)]
   rendered = handcrafted.SpriteFactors().render(sprites=sprites)
   self.assertLen(rendered, num_sprites)
Example #7
0
  def testSingleton(self):
    """Renders one fully specified sprite and checks its reported factors."""
    sprite = sprite_lib.Sprite(
        x=0.1, y=0.3, shape='square', scale=0.5, c0=0, c1=0, c2=255)

    renderer = handcrafted.SpriteFactors()
    outputs = renderer.render(sprites=[sprite])

    # Beyond checking that render() does not raise, verify that the single
    # output carries the attributes the sprite was constructed with.
    self.assertLen(outputs, 1)
    factors = outputs[0]
    self.assertEqual(factors['shape'], const.ShapeType['square'].value)
    for name, value in (('x', 0.1), ('y', 0.3), ('scale', 0.5), ('c0', 0),
                        ('c1', 0), ('c2', 255)):
      self.assertAlmostEqual(factors[name], value, delta=1e-4)
Example #8
0
 def testWrongFactors(self):
   """Unknown factor names are rejected with ValueError at construction."""
   # A subset made only of valid factor names must construct without raising.
   handcrafted.SpriteFactors(factors=('x', 'y', 'scale'))

   invalid_factor_sets = (
       ('position', 'scale'),  # 'position' is not a valid factor name.
       ('x', 'y', 'size'),  # 'size' is not a valid factor name.
   )
   for bad_factors in invalid_factor_sets:
     with self.assertRaises(ValueError):
       handcrafted.SpriteFactors(factors=bad_factors)