Ejemplo n.º 1
0
    # NOTE(review): control_data_package is constructed but never used in this
    # script -- the run below is driven by co_ord_conv_data_package instead.
    # It is kept only in case DataPackage() has side effects the script relies
    # on; confirm and delete if it does not.
    control_data_package = DataPackage(regular_invaders, AutoEncodeSelect())

    # Presumably resumes the run saved under this directory, paired with the
    # co_ord_conv data package (defined outside this fragment).
    run_fac = SimpleRunFac.resume(r'C:\data\runs\489',
                                  co_ord_conv_data_package)

    batch_size = 64
    epochs = 100

    # run_fac yields one tuple of training components per run.
    for model, opt, loss_fn, data_package, trainer, tester, run in run_fac:
        # Only the train/test loaders are used here; the dev loader is ignored.
        _dev, train, test, selector = data_package.loaders(
            batch_size=batch_size)

        # Visualisation hooks on the model input, the decoder's input
        # (presumably the latent code -- confirm), and the model output.
        model.register_forward_hook(input_viewer.view_input)
        model.decoder.register_forward_hook(latent_viewer.view_input)
        model.register_forward_hook(output_viewer.view_output)
        # Presumably wires TensorBoard logging to this run -- confirm.
        register_tb(run)

        # tqdm shows a progress bar labelled 'epochs' with total=epochs.
        for epoch in tqdm(run.for_epochs(epochs), 'epochs', epochs):

            #epoch.register_after_hook(write_histogram)
            epoch.execute_before(epoch)

            trainer.train(model, opt, loss_fn, train, selector, run, epoch)

            # The test-loss hook must only be attached while testing. Remove
            # it in `finally` so a failing test pass cannot leave it attached
            # to loss_fn.
            handles = Handles()
            handles += loss_fn.register_hook(tb_test_loss_term)
            try:
                tester.test(model, loss_fn, test, selector, run, epoch)
            finally:
                handles.remove()

            epoch.execute_after(epoch)
            # Persist the run state after every epoch.
            run.save()
Ejemplo n.º 2
0
            Params(MseKldLoss),
            co_ord_conv_data_package,
            run_name='full_v1'))

    #run_fac = SimpleRunFac.resume(r'C:\data\runs\549', co_ord_conv_data_package)
    batch_size = 64
    epochs = 30

    for model, opt, loss_fn, data_package, trainer, tester, run in run_fac:
        dev, train, test, selector = data_package.loaders(
            batch_size=batch_size)

        model.register_forward_hook(input_viewer.view_input)
        model.decoder.register_forward_hook(latent_viewer.view_input)
        model.register_forward_hook(output_viewer.view_output)
        register_tb(run, config)
        gui = GUIProgressMeter(description='training')
        trainer.register_after_hook(gui.update_train)
        tester.register_after_hook(gui.update_test)

        for epoch in run.for_epochs(epochs):

            #epoch.register_after_hook(write_histogram)
            epoch.register_after_hook(gui.end_epoch)
            epoch.execute_before(epoch)

            handles = Handles()

            trainer.train(model, opt, loss_fn, dev, selector, run, epoch)

            handles += loss_fn.register_hook(TB().tb_test_loss_term)