Example #1
0
def _mock_train(**args):
    """Run ``TrainLoop`` in a temporary directory and capture its printed output.

    Keyword arguments are applied as parser defaults on top of
    ``setup_args()``; ``model_file`` is pointed at ``<tmpdir>/model`` inside a
    freshly created temporary directory.

    Returns:
        A ``(stdout_text, valid, test)`` tuple: everything written to stdout
        while the loop was built and trained, followed by the two values
        returned by ``TrainLoop.train()``.

    The temporary directory is always removed, even if training raises.
    """
    outdir = tempfile.mkdtemp()
    try:
        parser = setup_args()
        parser.set_defaults(
            model_file=os.path.join(outdir, "model"),
            **args,
        )
        stdout = io.StringIO()
        with contextlib.redirect_stdout(stdout):
            tl = TrainLoop(parser.parse_args(print_args=False))
            valid, test = tl.train()
    finally:
        # Clean up on failure too, so a crashing test doesn't leak temp dirs.
        shutil.rmtree(outdir)
    return stdout.getvalue(), valid, test
Example #2
0
def _mock_train(outdir=None, keepoutdir=False, override=None, **args):
    """Run ``TrainLoop`` with a model file under ``outdir`` and capture stdout.

    Args:
        outdir: Directory in which ``model_file`` (``<outdir>/model``) is
            placed. When falsy, a fresh temporary directory is created.
        keepoutdir: If True, leave ``outdir`` on disk afterwards. Otherwise
            it is removed once training finishes or fails. This includes a
            caller-supplied ``outdir``.
        override: If truthy, stored as ``opt['override']`` after parsing.
        **args: Applied as parser defaults on top of ``setup_args()``.

    Returns:
        A ``(stdout_text, valid, test)`` tuple: everything written to stdout
        while the loop was built and trained, followed by the two values
        returned by ``TrainLoop.train()``.
    """
    if not outdir:
        outdir = tempfile.mkdtemp()
    try:
        parser = setup_args()
        parser.set_defaults(
            model_file=os.path.join(outdir, "model"),
            **args,
        )
        stdout = io.StringIO()
        with contextlib.redirect_stdout(stdout):
            opt = parser.parse_args(print_args=False)
            if override:
                opt['override'] = override
            tl = TrainLoop(opt)
            valid, test = tl.train()
    finally:
        # Honour keepoutdir on the failure path as well, so a crashing run
        # doesn't leave stray directories behind.
        if not keepoutdir:
            shutil.rmtree(outdir)
    return stdout.getvalue(), valid, test
Example #3
0
    def test_hogwild_train(self):
        """Test the trainer eval with numthreads > 1 and batchsize in [1,2,3].

        For every (numthreads, batchsize) combination, train with task
        ``RepeatTeacher:1`` and evaltask ``RepeatTeacher:NUM_EXS``, then check:

        - the final valid and test reports cover exactly NUM_EXS examples;
        - an uncapped ``run_eval`` (max_exs=-1) covers exactly NUM_EXS;
        - a capped ``run_eval`` (max_exs=NUM_EXS // 5) covers fewer than NUM_EXS.
        """
        parser = setup_args()
        NUM_EXS = 500
        parser.set_defaults(
            task='tasks.repeat:RepeatTeacher:{}'.format(1),
            evaltask='tasks.repeat:RepeatTeacher:{}'.format(NUM_EXS),
            model='repeat_label',
            num_examples=-1,
            display_examples=False,
            num_epochs=10,
        )

        # Route output to display_output() for the duration of the test;
        # the finally clause restores the real stdout even if an assert fails.
        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            for nt in [2, 5, 10]:
                parser.set_defaults(numthreads=nt)
                for bs in [1, 2, 3]:
                    parser.set_defaults(batchsize=bs)
                    # Enable batch sorting only for even batch sizes so both
                    # code paths are exercised.
                    parser.set_defaults(batch_sort=(bs % 2 == 0))
                    tl = TrainLoop(parser)
                    report_valid, report_test = tl.train()
                    # test final valid and test evals
                    self.assertEqual(report_valid['exs'], NUM_EXS)
                    self.assertEqual(report_test['exs'], NUM_EXS)

                    report_full, _world = run_eval(tl.agent,
                                                   tl.opt,
                                                   'valid',
                                                   max_exs=-1,
                                                   valid_world=tl.valid_world)
                    self.assertEqual(report_full['exs'], NUM_EXS)
                    # Integer division: an example cap should be an int
                    # (NUM_EXS / 5 would be the float 100.0 on Python 3).
                    report_part, _world = run_eval(tl.agent,
                                                   tl.opt,
                                                   'valid',
                                                   max_exs=NUM_EXS // 5,
                                                   valid_world=tl.valid_world)
                    self.assertLess(report_part['exs'], NUM_EXS)
        finally:
            # restore sys.stdout
            sys.stdout = old_out