def testSequence(self): sprites = [ sprite_lib.Sprite(x=np.random.rand(), y=np.random.rand()) for _ in range(5) ] renderer = handcrafted.SpriteFactors() renderer.render(sprites=sprites)
def testAttributesSingleton(self, x, y, shape, c0, c1, c2, scale, angle): sprite = sprite_lib.Sprite( x=x, y=y, shape=shape, c0=c0, c1=c1, c2=c2, scale=scale, angle=angle) renderer = handcrafted.SpriteFactors() outputs = renderer.render(sprites=[sprite])[0] self.assertEqual(outputs['shape'], const.ShapeType[shape].value) for (name, value) in (('x', x), ('y', y), ('c0', c0), ('c1', c1), ('c2', c2), ('scale', scale), ('angle', angle)): self.assertAlmostEqual(outputs[name], value, delta=1e-4)
def testObservationSpec(self, num_sprites, factors): sprites = [sprite_lib.Sprite() for _ in range(num_sprites)] renderer = handcrafted.SpriteFactors(factors=factors) renderer.render(sprites=sprites) obs_spec = renderer.observation_spec() for v in obs_spec[0].values(): self.assertEqual(v.shape, ()) obs_spec_keys = [set(x) for x in obs_spec] self.assertSequenceEqual(obs_spec_keys, num_sprites * [set(factors)])
def testAttributesTwoSprites(self): x = [0.5, 0.3] y = [0.4, 0.8] shape = ['square', 'spoke_4'] c0 = [0, 200] c1 = [255, 100] c2 = [0, 200] scale = [0.2, 0.3] angle = [0, 120] x_vel = [0.0, 0.1] y_vel = [-0.2, 0.05] sprites = [] for i in range(2): sprites.append( sprite_lib.Sprite( x=x[i], y=y[i], shape=shape[i], c0=c0[i], c1=c1[i], c2=c2[i], scale=scale[i], angle=angle[i], x_vel=x_vel[i], y_vel=y_vel[i])) renderer = handcrafted.SpriteFactors() outputs = renderer.render(sprites=sprites) for i in range(2): self.assertEqual(outputs[i]['shape'], const.ShapeType[shape[i]].value) for (name, value) in (('x', x), ('y', y), ('c0', c0), ('c1', c1), ('c2', c2), ('scale', scale), ('angle', angle), ('x_vel', x_vel), ('y_vel', y_vel)): self.assertAlmostEqual(outputs[i][name], value[i], delta=1e-4)
def testFactorSubset(self, num_sprites, factors): sprites = [sprite_lib.Sprite() for _ in range(num_sprites)] renderer = handcrafted.SpriteFactors(factors=factors) outputs = renderer.render(sprites=sprites) output_keys = [set(x) for x in outputs] self.assertSequenceEqual(output_keys, num_sprites * [set(factors)])
def testOutputLength(self, num_sprites): sprites = [sprite_lib.Sprite() for _ in range(num_sprites)] renderer = handcrafted.SpriteFactors() outputs = renderer.render(sprites=sprites) self.assertLen(outputs, num_sprites)
def testSingleton(self): sprite = sprite_lib.Sprite( x=0.1, y=0.3, shape='square', scale=0.5, c0=0, c1=0, c2=255) renderer = handcrafted.SpriteFactors() renderer.render(sprites=[sprite])
def testWrongFactors(self): handcrafted.SpriteFactors(factors=('x', 'y', 'scale')) with self.assertRaises(ValueError): handcrafted.SpriteFactors(factors=('position', 'scale')) with self.assertRaises(ValueError): handcrafted.SpriteFactors(factors=('x', 'y', 'size'))