def test_null_allowed(self):
     """Propagate an unconstrained int output through 'LIST|TAIL,0|ACCESS,1,0'.

     The expected result holds one constraint per program variable, in order:
     a ListConstraint with lmin=1 carrying one IntConstraint(0, length - 1)
     for every length in 0..constraint.L, then IntConstraint(0, 256), then a
     default IntConstraint().
     """
     program = Program.parse('LIST|TAIL,0|ACCESS,1,0')

     # NOTE(review): presumably these bound a valid index for each possible
     # list length -- confirm against constraint.ListConstraint.
     per_length_bounds = []
     for length in range(constraint.L + 1):
         per_length_bounds.append(constraint.IntConstraint(0, length - 1))

     expected = [
         constraint.ListConstraint(lmin=1, int_constraints=per_length_bounds),
         constraint.IntConstraint(0, 256),
         constraint.IntConstraint(),
     ]

     unconstrained_output = constraint.IntConstraint()
     actual = constraint.propagate_constraints(program, unconstrained_output)

     self.assertEqual(expected, actual)
    def test_propagate(self):
        """Check the input constraint propagated back through MAP/FILTER/MAP.

        The program takes one LIST input and applies MAP TIMES2, FILTER GT0
        and MAP MINUS1 in sequence. Given an output ListConstraint(1, 4, ...)
        whose elements are IntConstraint(-5, 3), the first propagated
        constraint must equal ListConstraint(1, 4, ...) with elements
        IntConstraint(-2, 2).
        """
        statements = [
            (impl.MAP, (impl.TIMES2, 0)),
            (impl.FILTER, (impl.GT0, 1)),
            (impl.MAP, (impl.MINUS1, 2)),
        ]
        program = Program([LIST], statements)

        # `* 5` repeats the same constraint object five times, exactly as
        # the original test did.
        output_elements = [constraint.IntConstraint(-5, 3)] * 5
        output_constraint = constraint.ListConstraint(1, 4, output_elements)
        propagated = constraint.propagate_constraints(
            program, output_constraint)
        actual = propagated[0]

        expected_elements = [constraint.IntConstraint(-2, 2)] * 5
        expected = constraint.ListConstraint(1, 4, expected_elements)

        self.assertEqual(expected, actual)
def main():
    """Convert a file of programs into a saved NumPy archive.

    Command-line arguments:
        --infile: text file read line by line; each line (right-stripped)
            is parsed with ``Program.parse``. Required.
        --outfile: destination handed to ``np.savez``. Required.
        --nb_examples: forwarded to ``get_program_row`` (default 5).
        --nb_inputs: forwarded to ``get_program_row`` (default 3).

    For every program, ``get_program_row`` yields a mapping whose values are
    grouped by key across all programs, and ``get_attribute_vec`` yields one
    entry of ``y``. The grouped columns and ``y`` are written with
    ``np.savez``.

    Raises:
        Any exception raised while building a row or attribute vector is
        re-raised after the offending program and its propagated constraints
        have been printed.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without them the run would only fail later
    # with an opaque TypeError from open(None).
    parser.add_argument('--infile', type=str, required=True)
    parser.add_argument('--outfile', type=str, required=True)
    parser.add_argument('--nb_examples', type=int, default=5)
    parser.add_argument('--nb_inputs', type=int, default=3)
    args = parser.parse_args()

    # First pass only counts lines so the progress bar has a total.
    with open(args.infile) as in_fh:
        line_count = sum(1 for _ in in_fh)

    xdata = []
    ydata = []
    # The progress bar is a context manager, so it is closed even when a
    # program fails and the exception propagates.
    with open(args.infile) as in_fh, tqdm.tqdm(total=line_count) as pbar:
        for line in in_fh:
            pbar.update(1)
            program = Program.parse(line.rstrip())
            try:
                xdata.append(
                    get_program_row(program, args.nb_examples, args.nb_inputs))
                ydata.append(get_attribute_vec(program))
            except Exception:
                # Print debugging context for the failing program, then let
                # the original error propagate. KeyboardInterrupt/SystemExit
                # are deliberately not intercepted.
                print('prog:', program)
                print('constraint:')
                for propagated in constraint.propagate_constraints(program):
                    print(propagated)
                raise

    # Turn the list of per-program rows into per-key columns.
    columns = collections.defaultdict(list)
    for row in xdata:
        for key, value in row.items():
            columns[key].append(value)
    np.savez(args.outfile, y=ydata, **columns)