def testAddIfExist(self):
    """A second 'now' object added with add_if_exist=False is not kept."""
    object_set = sg.ObjectSet(n_epoch=1)
    for _attempt in range(2):
        object_set.add(sg.Object(when='now'), epoch_now=0, add_if_exist=False)
    self.assertEqual(len(object_set), 1)
def testAddLast1(self):
    """After adding 'now', deletable 'now', then 'last1' objects, one remains."""
    object_set = sg.ObjectSet(n_epoch=100)
    object_set.add(sg.Object(when='now'), epoch_now=0)
    object_set.add(sg.Object(when='now', deletable=True), epoch_now=1)
    # add_if_exist=False: the 'last1' object must not grow the set further.
    object_set.add(
        sg.Object(when='last1'), epoch_now=2, add_if_exist=False)
    self.assertEqual(len(object_set), 1)
def testDeletable(self):
    """Forcing a new object in replaces an existing deletable one."""
    object_set = sg.ObjectSet(n_epoch=1)
    epoch = 0
    object_set.add(sg.Object(when='now', deletable=True), epoch_now=epoch)
    object_set.add(sg.Object(when='now'), epoch_now=epoch, add_if_exist=True)
    # The deletable object should be gone, leaving only the second one.
    self.assertEqual(len(object_set), 1)
def testSelectLast(self):
    """At epoch 1, 'latest' yields two objects and 'last1' yields one."""
    object_set = sg.ObjectSet(n_epoch=2)
    object_set.add(sg.Object(when='now'), epoch_now=0)
    object_set.add(sg.Object(when='now'), epoch_now=1)
    # Force a second object into epoch 1.
    object_set.add(sg.Object(when='now'), epoch_now=1, add_if_exist=True)
    current = 1
    newest = object_set.select(current, when='latest')
    self.assertEqual(2, len(newest))
    previous = object_set.select(current, when='last1')
    self.assertEqual(1, len(previous))
def testSelectBackTrack(self):
    """An object added at epoch 5 is visible at epoch 10 only if n_backtrack >= 5."""
    for n_backtrack, expected_count in ((4, 0), (5, 1)):
        object_set = sg.ObjectSet(n_epoch=100)
        object_set.add(sg.Object(when='now'), epoch_now=5)
        found = object_set.select(10, when='latest', n_backtrack=n_backtrack)
        self.assertEqual(len(found), expected_count)
def testMergeObject(self):
    """merge() succeeds for shape+color and is rejected when both carry a Shape."""
    circle = sg.Object([sg.Shape('circle')])
    red = sg.Object([sg.Color('red')])
    self.assertTrue(circle.merge(red))

    circle = sg.Object([sg.Shape('circle')])
    other_shape = sg.Object([sg.Shape('a')])
    self.assertFalse(circle.merge(other_shape))
def testAddObjectInSpace(self):
    """Adding an object whose space equals an existing one leaves the count unchanged."""
    object_set = sg.ObjectSet(n_epoch=1)
    spaces = [
        sg.Space([(0, 1), (0, 0.5)]),
        sg.Space([(0, 1), (0.5, 1)]),
        sg.Space([(0, 1), (0, 0.5)]),  # same region as the first space
    ]
    expected_sizes = [1, 2, 2]
    for space, expected in zip(spaces, expected_sizes):
        object_set.add(sg.Object([space], when='now'), epoch_now=0)
        self.assertEqual(len(object_set), expected)
def testDecodeRenderPool(self):
    """Decode an object's location from an average-pooled render.

    A blue square is rendered at a known location. The frames are summed over
    color channels, average-pooled down to a grid_size x grid_size map,
    normalized per epoch and projected onto the grid preferences. The location
    decoded for epoch 0 must lie within 0.1 (Euclidean) of the true location,
    i.e. the squared distance must be below 0.01.
    """
    img_size = 112
    grid_size = 7
    # One pooled cell per grid cell (112 // 7 == 16), so the pooled map lines
    # up with the preferences returned for grid_size.
    pool_size = img_size // grid_size
    prefs = constants.get_prefs(grid_size)
    loc_xy = sg.Loc([0.8, 0.3])
    n_epoch = 5
    objset = sg.ObjectSet(n_epoch=n_epoch)
    obj = sg.Object([loc_xy, sg.Shape('square'), sg.Color('blue')], when='now')
    objset.add(obj, epoch_now=0)
    movie = sg.render(objset, img_size=img_size)
    frame = movie.sum(axis=-1, keepdims=True)  # collapse color channels

    in_imgs = tf.placeholder('float', [None, img_size, img_size, 1])
    out = tf.contrib.layers.avg_pool2d(in_imgs, pool_size, pool_size)
    with tf.Session() as sess:
        out_ = sess.run(out, feed_dict={in_imgs: frame})
    out_ = np.reshape(out_, (n_epoch, -1))
    # Normalize each epoch's activity to sum to 1. The epsilon prevents a
    # division by zero for epochs where nothing was drawn.
    out_ = (out_.T / (1e-7 + out_.sum(axis=1))).T
    loc_decoded = np.dot(out_, prefs)[0]
    print('Input loc ' + str(loc_xy))
    print('Decoded loc ' + str(loc_decoded))
    dist = ((loc_decoded[0] - loc_xy.value[0])**2 +
            (loc_decoded[1] - loc_xy.value[1])**2)
    self.assertLess(dist, 0.01)
def testSelectNow(self):
    """A 'now' object can be selected with when='now' at the epoch it was added."""
    object_set = sg.ObjectSet(n_epoch=1)
    object_set.add(sg.Object(when='now'), epoch_now=0)
    matches = object_set.select(0, when='now')
    self.assertEqual(len(matches), 1)
def testAndGuessObjset(self):
    """The object guessed at epoch 2 copies color and shape of the epoch-1 object."""
    previous = tg.Select(when='last1')
    current = tg.Select(when='now')
    shape_prev = tg.GetShape(previous)
    shape_now = tg.GetShape(current)
    color_prev = tg.GetColor(previous)
    color_now = tg.GetColor(current)
    task = tg.Task(tg.And(tg.IsSame(shape_prev, shape_now),
                          tg.IsSame(color_prev, color_now)))

    object_set = sg.ObjectSet(n_epoch=10)
    green_square = sg.Object([sg.Color('green'), sg.Shape('square')], when='now')
    red_circle = sg.Object([sg.Color('red'), sg.Shape('circle')], when='now')
    object_set.add(green_square, epoch_now=0)
    object_set.add(red_circle, epoch_now=1)

    object_set = task.guess_objset(object_set, epoch_now=2, should_be=True)
    guessed = object_set.last_added_obj
    self.assertEqual(red_circle.color.value, guessed.color.value)
    self.assertEqual(red_circle.shape.value, guessed.shape.value)
def testSelectGetExpectedInputShouldBeTrue2(self):
    """get_expected_input yields a selectable object with the requested shape."""
    object_set = sg.ObjectSet(n_epoch=10)
    select = tg.Select(color=sg.Color('red'), when='now')
    should_be = [sg.Object([sg.Shape('circle')])]
    object_set, _, _, _ = select.get_expected_input(
        should_be, object_set, epoch_now=1)
    selected = select(object_set, epoch_now=1)
    self.assertEqual(selected[0].shape, sg.Shape('circle'))
def testSelectGetExpectedInputShouldBeTrue1(self):
    """get_expected_input yields a selectable object at the requested location."""
    object_set = sg.ObjectSet(n_epoch=10)
    select = tg.Select(color=sg.Color('red'), when='now')
    should_be = [sg.Object([sg.Loc([0.5, 0.5])])]
    object_set, _, _, _ = select.get_expected_input(
        should_be, object_set, epoch_now=1)
    selected = select(object_set, epoch_now=1)
    self.assertTupleEqual(selected[0].loc.value, (0.5, 0.5))
def generate_objset(task, n_epoch=30, distractor=True):
    """Build an ObjectSet by letting `task` supply its expected input per epoch.

    Args:
      task: object whose get_expected_input(objset, epoch_now) returns an
        updated ObjectSet.
      n_epoch: number of epochs in the generated set.
      distractor: when True, a deletable 'now' object is added at every epoch
        before the task is consulted.

    Returns:
      The ObjectSet returned by the task for the final epoch.
    """
    object_set = sg.ObjectSet(n_epoch=n_epoch)
    # Walk the epochs in order, guessing objects as we go.
    for epoch in range(n_epoch):
        if distractor:
            object_set.add(sg.Object(when='now', deletable=True), epoch)
        object_set = task.get_expected_input(object_set, epoch)
    return object_set
def generate_objset(self, n_epoch, n_distractor=1, average_memory_span=2):
    """Generate object set.

    Layout: a sample object at epoch 1, one distractor at epoch 2 and two test
    objects at epoch 3. The sample and the first test object share a sampled
    shape. Both test objects use self._color2. The distractor uses
    self._color3 with a random shape.

    Args:
      n_epoch: int, total number of epochs; must be at least 4.
      n_distractor: not used by this implementation.
      average_memory_span: not used by this implementation.

    Returns:
      objset: ObjectSet instance.

    Raises:
      ValueError: when n_epoch is less than 4.
    """
    if n_epoch < 4:
        raise ValueError(
            'Number of epoch {:d} is less than 4'.format(n_epoch))
    matched_shape, other_shape = sg.sample_shape(2)
    distractor_shape = sg.random_shape()
    objset = sg.ObjectSet(n_epoch=n_epoch)
    schedule = [
        (sg.Object([self._color1, matched_shape], when='now'), 1),  # sample
        (sg.Object([self._color3, distractor_shape], when='now'), 2),  # delay
        (sg.Object([self._color2, matched_shape], when='now'), 3),  # test
        (sg.Object([self._color2, other_shape], when='now'), 3),  # test
    ]
    for obj, epoch in schedule:
        objset.add(obj, epoch_now=epoch)
    return objset
def testAddWhenNone(self):
    """An object added with when=None has epoch range [0, 100] and is always selectable."""
    object_set = sg.ObjectSet(n_epoch=100, n_max_backtrack=5)
    object_set.add(sg.Object(when=None), epoch_now=0)
    self.assertEqual(object_set.last_added_obj.epoch, [0, 100])
    counts = [len(object_set.select(epoch_now=10, when=mode))
              for mode in ('now', 'latest', 'last1')]
    for count in counts:
        self.assertEqual(count, 1)
def testSelectNowLoc(self):
    """An object at (0.3, 0.3) is found in a space covering it and not in one excluding it."""
    object_set = sg.ObjectSet(n_epoch=1)
    epoch = 0
    object_set.add(sg.Object([sg.Loc([0.3, 0.3])], when='now'), epoch_now=epoch)
    covering = sg.Space([(0.2, 0.4), (0.1, 0.5)])
    excluding = sg.Space([(0.5, 0.7), (0.1, 0.5)])
    hits_covering = object_set.select(epoch, space=covering, when='now')
    hits_excluding = object_set.select(epoch, space=excluding, when='now')
    self.assertEqual(len(hits_covering), 1)
    self.assertEqual(len(hits_excluding), 0)
def testGetCall(self):
    """GetColor is INVALID before a red object exists and red afterwards."""
    object_set = sg.ObjectSet(n_epoch=10)
    get_color = tg.GetColor(tg.Select(color=sg.Color('red'), when='now'))
    epoch = 1
    self.assertEqual(get_color(object_set, epoch), const.INVALID)
    object_set.add(sg.Object([sg.Color('red')], when='now'), epoch)
    self.assertEqual(get_color(object_set, epoch), sg.Color('red'))
def testSelectGetExpectedInputShouldBeTrue3(self):
    """Calling get_expected_input twice still leaves exactly one selected object."""
    object_set = sg.ObjectSet(n_epoch=10)
    select = tg.Select(color=sg.Color('red'), when='now')
    should_be = [sg.Object(when='now')]
    # Two calls on purpose: the repeat must not add a duplicate object.
    for _attempt in range(2):
        object_set, _, _, _ = select.get_expected_input(
            should_be, object_set, epoch_now=1)
    selected = select(object_set, epoch_now=1)
    self.assertEqual(len(selected), 1)
def testSelectGetExpectedInputShouldBeTrue4(self):
    """A 'left of (0.5, 0.5)' select yields an object with x < 0.5 and a Skip loc."""
    object_set = sg.ObjectSet(n_epoch=10)
    select = tg.Select(loc=sg.Loc([0.5, 0.5]), when='now', space_type='left')
    should_be = [sg.Object(when='now')]
    object_set, returned_loc, _, _ = select.get_expected_input(
        should_be, object_set, epoch_now=1)
    selected = select(object_set, epoch_now=1)
    self.assertLess(selected[0].loc.value[0], 0.5)
    self.assertIsInstance(returned_loc, tg.Skip)
def testShiftObjset(self):
    """shift(1) moves objects later in time; a further shift(-2) empties the set."""
    object_set = sg.ObjectSet(n_epoch=2)
    object_set.add(sg.Object(when='now'), epoch_now=0)

    object_set.shift(1)
    self.assertEqual(object_set.n_epoch, 3)
    self.assertEqual(object_set.set[0].epoch, [1, 2])
    self.assertEqual(len(object_set.select(1, when='now')), 1)

    object_set.shift(-2)
    self.assertEqual(len(object_set), 0)
def generate_objset(self, n_epoch, n_distractor=1, average_memory_span=2):
    """Generate object set.

    Epoch 1 shows two objects of color self._color1, using the shapes returned
    by sg.sample_shape(2). It also shows a distractor whose color comes from
    sg.another_color(self._color1). Epoch 3 shows a single self._color1 test
    object. With probability 0.5 its shape is one of the sample shapes;
    otherwise it comes from sg.another_shape(sample_shapes).

    Args:
      n_epoch: int, total number of epochs; must be at least 4.
      n_distractor: not used by this implementation.
      average_memory_span: not used by this implementation.

    Returns:
      objset: ObjectSet instance.

    Raises:
      ValueError: when n_epoch is less than 4.
    """
    if n_epoch < 4:
        raise ValueError(
            'Number of epoch {:d} is less than 4'.format(n_epoch))
    objset = sg.ObjectSet(n_epoch=n_epoch)
    n_sample = 2
    sample_shapes = sg.sample_shape(n_sample)
    for shape in sample_shapes[:n_sample]:
        objset.add(sg.Object([self._color1, shape], when='now'), epoch_now=1)
    distractor = sg.Object([sg.another_color(self._color1)], when='now')
    objset.add(distractor, epoch_now=1)
    # Coin flip: the test shape either repeats a sample shape or is new.
    test_shape = (random.choice(sample_shapes) if random.random() < 0.5
                  else sg.another_shape(sample_shapes))
    test1 = sg.Object([self._color1, test_shape], when='now')
    objset.add(test1, epoch_now=3)  # test epoch
    return objset
def testIsSameGuessObjsetWithDistractors(self):
    """After a distractor at epoch 1, the guessed set still satisfies IsSame(colors)."""
    square_sel = tg.Select(shape=sg.Shape('square'), when='last1')
    circle_sel = tg.Select(shape=sg.Shape('circle'), when='last1')
    task = tg.Task(tg.IsSame(tg.GetColor(square_sel), tg.GetColor(circle_sel)))

    object_set = sg.ObjectSet(n_epoch=10)
    green_square = sg.Object(
        [sg.Color('green'), sg.Shape('square')], when='now', deletable=True)
    object_set.add(green_square, 0, add_if_exist=True)
    red_circle = sg.Object(
        [sg.Color('red'), sg.Shape('circle')], when='now', deletable=True)
    object_set.add(red_circle, 0, add_if_exist=True)

    object_set = task.guess_objset(object_set, 0, should_be=True)
    object_set.add_distractor(1)
    object_set = task.guess_objset(object_set, 1, should_be=True)
    self.assertTrue(task(object_set, 1))
def testGetSpaceCall(self):
    """A point at (0.2, 0.3) is left/top of the red reference at (0.5, 0.5), not right/bottom."""
    reference = tg.Select(color=sg.Color('red'), when='last1')
    tasks = {}
    for side in ('left', 'right', 'top', 'bottom'):
        neighbor = tg.Select(
            loc=tg.GetLoc(reference), when='now', space_type=side)
        tasks[side] = tg.Task(tg.Exist(neighbor))

    object_set = sg.ObjectSet(n_epoch=2)
    anchor = sg.Object([sg.Loc([0.5, 0.5]), sg.Color('red')], when='now')
    object_set.add(anchor, epoch_now=0)
    probe = sg.Object([sg.Loc([0.2, 0.3])], when='now')
    object_set.add(probe, epoch_now=1)

    self.assertTrue(tasks['left'](object_set, epoch_now=1))
    self.assertFalse(tasks['right'](object_set, epoch_now=1))
    self.assertTrue(tasks['top'](object_set, epoch_now=1))
    self.assertFalse(tasks['bottom'](object_set, epoch_now=1))
def generate_objset(self, n_epoch, n_distractor=1, average_memory_span=2):
    """Manually generate an object set.

    Between one and four sample objects are shown at epoch 1, each built from
    a color/shape combination drawn by sg.sample_colorshape. At epoch 3 a
    single test object is shown. With probability 0.5 it reuses one sample's
    attributes; otherwise it uses the extra combination drawn alongside the
    samples.

    By design this set is not balanced, because the network always answers
    False during the sample epoch.

    Args:
      n_epoch: int, total number of epochs; must be at least 4.
      n_distractor: not used by this implementation.
      average_memory_span: not used by this implementation.

    Returns:
      objset: ObjectSet instance.

    Raises:
      ValueError: when n_epoch is less than 4.
    """
    if n_epoch < 4:
        raise ValueError(
            'Number of epoch {:d} is less than 4'.format(n_epoch))
    objset = sg.ObjectSet(n_epoch=n_epoch)
    n_sample = random.choice([1, 2, 3, 4])
    drawn = sg.sample_colorshape(n_sample + 1)
    samples, extra = drawn[:n_sample], drawn[-1]
    for attrs in samples:
        objset.add(sg.Object(attrs, when='now'), epoch_now=1)
    test_attrs = random.choice(samples) if random.random() < 0.5 else extra
    objset.add(sg.Object(test_attrs, when='now'), epoch_now=3)  # test epoch
    return objset
def get_expected_input(self, should_be):
    """Turn `should_be` into the list of objects the Exist check expects.

    Args:
      should_be: truthy if an object should exist, falsy if not, or None to
        decide at random (random.random() > 0.5).

    Returns:
      [sg.Object()] when an object should exist, otherwise [].

    Raises:
      ValueError: if self.objs.when is anything other than 'now'.
    """
    if self.objs.when != 'now':
        raise ValueError(""" Guess objset is not supported for the Exist class for when other than now""")
    if should_be is None:
        should_be = random.random() > 0.5
    return [sg.Object()] if should_be else []
def generate_objset(self, n_epoch, n_distractor=1, average_memory_span=2):
    """Generate object set.

    The task has 4 epochs: Fixation, Sample, Delay and Test. One sample object
    (self._shape1) is shown at epoch 1 and a distractor (self._shape3) at
    epoch 2. Two test objects (both self._shape2) are shown at epoch 3. The
    first test object has the sample's color; the second has another sampled
    color.

    Args:
      n_epoch: int, total number of epochs; must be at least 4.
      n_distractor: not used by this implementation.
      average_memory_span: not used by this implementation.

    Returns:
      objset: ObjectSet instance.

    Raises:
      ValueError: when n_epoch is less than 4, the minimum epoch number for
        this task.
    """
    if n_epoch < 4:
        raise ValueError(
            'Number of epoch {:d} is less than 4'.format(n_epoch))
    sample_color, other_color = sg.sample_color(2)
    distractor_color = sg.random_color()
    objset = sg.ObjectSet(n_epoch=n_epoch)
    schedule = [
        (sg.Object([sample_color, self._shape1], when='now'), 1),  # sample
        (sg.Object([distractor_color, self._shape3], when='now'), 2),  # delay
        (sg.Object([sample_color, self._shape2], when='now'), 3),  # test
        (sg.Object([other_color, self._shape2], when='now'), 3),  # test
    ]
    for obj, epoch in schedule:
        objset.add(obj, epoch_now=epoch)
    return objset
def testSelectCall(self):
    """Select filters by color, shape, relative location and time."""
    object_set = sg.ObjectSet(n_epoch=10)
    object_set.add(
        sg.Object([sg.Loc([0.5, 0.5]), sg.Shape('circle'), sg.Color('red')],
                  when='now'),
        epoch_now=1)

    # Color: found at the epoch the object was added, not one epoch later.
    by_red = tg.Select(color=sg.Color('red'), when='now')
    self.assertTrue(by_red(object_set, epoch_now=1))
    self.assertFalse(by_red(object_set, epoch_now=2))
    by_blue = tg.Select(color=sg.Color('blue'), when='now')
    self.assertFalse(by_blue(object_set, epoch_now=1))

    # Shape.
    by_circle = tg.Select(shape=sg.Shape('circle'), when='now')
    self.assertTrue(by_circle(object_set, epoch_now=1))
    self.assertFalse(by_circle(object_set, epoch_now=2))

    # Relative location: (0.5, 0.5) counts as 'left' and 'top' of (0.6, 0.6).
    left_of = tg.Select(loc=sg.Loc([0.6, 0.6]), when='now', space_type='left')
    self.assertTrue(left_of(object_set, epoch_now=1))
    top_of = tg.Select(loc=sg.Loc([0.6, 0.6]), when='now', space_type='top')
    self.assertTrue(top_of(object_set, epoch_now=1))

    # Time: 'last1' matches at epoch 2 but not 1; 'latest' matches at both.
    red_last1 = tg.Select(color=sg.Color('red'), when='last1')
    self.assertFalse(red_last1(object_set, epoch_now=1))
    self.assertTrue(red_last1(object_set, epoch_now=2))
    red_latest = tg.Select(color=sg.Color('red'), when='latest')
    self.assertTrue(red_latest(object_set, epoch_now=1))
    self.assertTrue(red_latest(object_set, epoch_now=2))

    # A second red object at the same epoch is also returned by 'latest'.
    object_set.add(
        sg.Object([sg.Loc([0.7, 0.7]), sg.Shape('square'), sg.Color('red')],
                  when='now'),
        epoch_now=1)
    red_latest = tg.Select(color=sg.Color('red'), when='latest')
    self.assertEqual(len(red_latest(object_set, epoch_now=1)), 2)
def testGetTimeCall(self):
    """GetTime reports the epoch at which a matching object was added.

    Before any red object exists, GetTime returns const.INVALID. After a red
    object is added at ``epoch_add``, it must return ``epoch_add`` at that
    epoch and at every later epoch of the set. n_max_backtrack=100 exceeds
    n_epoch, presumably so the 'latest' selection can always look back far
    enough.
    """
    select_red = tg.Select(color=sg.Color('red'), when='latest')
    get_time = tg.GetTime(select_red)
    n_epoch = 10
    objset = sg.ObjectSet(n_epoch=n_epoch, n_max_backtrack=100)
    epoch_add = 1
    self.assertEqual(get_time(objset, epoch_add), const.INVALID)
    objset.add(sg.Object([sg.Color('red')], when='now'), epoch_add)
    # Check through the last valid epoch (n_epoch - 1). The former bound of
    # range(..., n_epoch - 1) silently skipped that final epoch.
    for epoch_now in range(epoch_add, n_epoch):
        self.assertEqual(get_time(objset, epoch_now), epoch_add)
def testDecodeRender(self):
    """Decode an object's location directly from a full-resolution render.

    A blue square rendered at (0.2, 0.8) in a single-epoch set is summed over
    color channels and normalized to unit mass. It is then projected onto the
    per-pixel preferences. The decoded location must lie within 0.1
    (Euclidean) of the true one, i.e. the squared distance must be below 0.01.
    """
    img_size = 112
    prefs = constants.get_prefs(img_size)
    loc_xy = sg.Loc([0.2, 0.8])
    objset = sg.ObjectSet(n_epoch=1)
    obj = sg.Object([loc_xy, sg.Shape('square'), sg.Color('blue')], when='now')
    objset.add(obj, epoch_now=0)
    movie = sg.render(objset, img_size=img_size)
    self.assertEqual(list(movie.shape), [1, img_size, img_size, 3])
    movie = movie.sum(axis=-1)  # sum across color
    # Out-of-place division: in-place true division (`/=`) raises in NumPy if
    # render() ever returns an integer array.
    movie = movie / movie.sum()
    movie = np.reshape(movie, (1, -1))
    loc_decoded = np.dot(movie, prefs)[0]
    dist = ((loc_decoded[0] - loc_xy.value[0])**2 +
            (loc_decoded[1] - loc_xy.value[1])**2)
    self.assertLess(dist, 0.01)
def testAddFixLoc(self):
    """Adding an object keeps the location it was constructed with."""
    coords = [0.2, 0.8]
    object_set = sg.ObjectSet(n_epoch=1)
    pinned = sg.Object([sg.Loc(coords)], when='now')
    object_set.add(pinned, epoch_now=0)
    self.assertEqual(list(pinned.loc.value), coords)