def test_stability_tower_is_constructible():
    """Exercise ``TowerPlanner.tower_is_constructible`` on small hand-built towers.

    Three blocks of different sizes and masses are stacked with ``ZERO_ROT``.
    Only their positions change between checks, and each assertion pins
    whether the planner accepts the current stack as constructible.
    """
    planner = TowerPlanner()
    a = Object('a', Dimensions(0.1, 0.1, 0.1), 1, Position(0, 0, 0), Color(0, 1, 1))
    b = Object('b', Dimensions(0.3, 0.1, 0.1), 3, Position(0, 0, 0), Color(1, 0, 1))
    c = Object('c', Dimensions(0.1, 0.1, 0.2), 2, Position(0, 0, 0), Color(1, 1, 0))

    def place(block, x, y, z):
        # Every block in this test is unrotated; only its position varies.
        block.pose = Pose(Position(x, y, z), ZERO_ROT)

    # A lone block is trivially constructible.
    place(a, 0, 0, 0.05)
    assert planner.tower_is_constructible([a])

    # b centred on a: constructible.
    place(b, 0, 0, 0.15)
    assert planner.tower_is_constructible([a, b])

    # b shifted off a's centre: not constructible.
    place(b, 0.06, 0, 0.15)
    assert not planner.tower_is_constructible([a, b])

    # Adding c makes the stack stable, but the build order is still unconstructible.
    place(c, 0.0, 0, 0.3)
    assert not planner.tower_is_constructible([a, b, c])

    # Constructible, although the finished tower is not stable.
    place(b, 0, 0.04, 0.15)
    place(c, 0, 0.08, 0.3)
    assert planner.tower_is_constructible([a, b, c])
def evaluate_predictions(fname):
    """Print prediction accuracy broken down by tower stability category.

    ``fname`` holds a pickled sequence of ``(towers, labels, preds)`` triples.
    Entry ``ix`` is reported under the heading ``"%d Towers" % (ix + 2)``;
    presumably this is the number of blocks per tower (TODO confirm against
    the code that writes the file).

    Every tower is classified with a ``'contains'``-mode ``TowerPlanner`` as:

    * ``stable``     -- ``tower_is_stable``
    * ``cog_stable`` -- ``tower_is_cog_stable``
    * ``pw_stable``  -- ``tower_is_constructible``

    A prediction counts as correct when ``(pred > 0.5) == label``. Accuracy is
    printed for each (stable, cog_stable, pw_stable) cell. Cells that contain
    no towers report ``nan`` instead of raising ``ZeroDivisionError``.

    :param fname: path to the pickled results file. ``pickle.load`` can execute
        arbitrary code, so only pass trusted files.
    """
    with open(fname, 'rb') as handle:
        results = pickle.load(handle)

    tp = TowerPlanner(stability_mode='contains')
    for ix, (towers, labels, preds) in enumerate(results):
        # Both tables are indexed as [stable][cog_stable][pw_stable].
        correct = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
        total = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]

        # Classify each tower and tally it (and its correctness) into its cell.
        for tower, label, pred in zip(towers, labels, preds):
            blocks = to_blocks(tower)
            cog_stable = tp.tower_is_cog_stable(blocks)
            pw_stable = tp.tower_is_constructible(blocks)
            stable = tp.tower_is_stable(blocks)
            # Disagreements between the planner and the stored label are
            # reported rather than asserted.
            if stable != label:
                print('WAT', stable, label)
            # int() keeps the flags valid list indices even if the planner
            # returns non-builtin boolean types.
            s, c, p = int(stable), int(cog_stable), int(pw_stable)
            total[s][c][p] += 1
            if (pred > 0.5) == label:
                correct[s][c][p] += 1

        print(total)
        print('%d Towers' % (ix + 2))
        for stable in [0, 1]:
            for cog_stable in [0, 1]:
                for pw_stable in [0, 1]:
                    # The first entry skips cells where pairwise and full
                    # stability disagree (presumably they coincide for the
                    # smallest towers -- TODO confirm).
                    if ix == 0 and pw_stable != stable:
                        continue
                    n = total[stable][cog_stable][pw_stable]
                    # Some categories may have no towers at all; report nan
                    # for those instead of dividing by zero.
                    if n:
                        acc = correct[stable][cog_stable][pw_stable] / n
                    else:
                        acc = float('nan')
                    print(
                        'Stable: %d\tCOG_Stable: %d\tPW_Stable: %d\tAcc: %f' %
                        (stable, cog_stable, pw_stable, acc))
def pairwise_stable(tower):
    """Return the pairwise-stability (constructibility) verdict for ``tower``.

    A fresh ``TowerPlanner`` in ``'contains'`` stability mode is built on every
    call. The result is exactly what its ``tower_is_constructible`` returns.
    """
    planner = TowerPlanner(stability_mode='contains')
    verdict = planner.tower_is_constructible(tower)
    return verdict