def test_experiments_function_additions():

    with experiment_testing_context():

        for rec in my_xxxyyy_test_experiment.get_variant_records(flat=True):
            rec.delete()

        r1=my_xxxyyy_test_experiment.run()
        r2=my_xxxyyy_test_experiment.get_variant('a2').run()
        with pytest.raises(Exception):
            my_xxxyyy_test_experiment.get_variant(b=17).run()
        r3 = my_xxxyyy_test_experiment.get_variant(b=17).get_latest_record()

        assert r1.get_log() == 'xxx\n'
        assert r2.get_log() == 'yyy\n'

        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_latest_record()) == '3bbb'
        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_variant('a2').get_latest_record()) == '4bbb'
        assert get_oneline_result_string(my_xxxyyy_test_experiment.get_variant(b=17).get_latest_record()) == '<No result has been saved>'

        with CaptureStdOut() as cap:
            my_xxxyyy_test_experiment.show(my_xxxyyy_test_experiment.get_latest_record())
        assert cap.read() == '3aaa\n'

        with CaptureStdOut() as cap:
            my_xxxyyy_test_experiment.compare([r1, r2])
        assert cap.read() == 'my_xxxyyy_test_experiment: 3, my_xxxyyy_test_experiment.a2: 4\n'

        print('='*100+'\n ARGTABLE \n'+'='*100)
        print_experiment_record_argtable([r1, r2, r3])

        print('='*100+'\n SHOW \n'+'='*100)
        compare_experiment_records([r1, r2, r3])
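
The experiment fixture used above is defined elsewhere in the test module. A hypothetical definition that is consistent with the assertions (logs 'xxx'/'yyy', results 3 and 4, and a b=17 variant that fails when run), but is not taken from the original source, could look like this; the show/compare/one-liner formatting functions referenced by the asserts are omitted:

# Hypothetical fixture for illustration only; the real definition lives
# elsewhere in the test module. It prints 'xxx' for the base case, 'yyy' for
# the 'a2' variant, and raises for the b=17 variant, matching the asserts above.
@experiment_function
def my_xxxyyy_test_experiment(a=1, b=2):
    if b == 17:
        raise Exception('b must not be 17')
    print('xxx' if a == 1 else 'yyy')
    return a + b

my_xxxyyy_test_experiment.add_variant('a2', a=2)
my_xxxyyy_test_experiment.add_variant(b=17)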
Example 2
def test_nested_capture():

    with CaptureStdOut() as cap1:
        print('a')
        with CaptureStdOut() as cap2:
            print('b')
        print('c')

    assert cap2.read()=='b\n'
    assert cap1.read()=='a\nb\nc\n'
Example 3
def test_nested_capture():

    with CaptureStdOut() as cap1:
        print('a')
        with CaptureStdOut() as cap2:
            print('b')
        print('c')

    assert cap2.read() == 'b\n'
    assert cap1.read() == 'a\nb\nc\n'
Example 4
def test_capture_prefix():

    with CaptureStdOut() as cap1:
        print('a')
        with CaptureStdOut(prefix='abc:') as cap2:
            print('b')
        print('c')

    print('Done')
    assert cap2.read()=='b\n'
    assert cap1.read()=='a\nabc:b\nc\n'
Example 5
def record_experiment(identifier='%T-%N',
                      name='unnamed',
                      print_to_console=True,
                      show_figs=None,
                      save_figs=True,
                      saved_figure_ext='.pdf',
                      use_temp_dir=False,
                      date=None):
    """
    :param identifier: The string that uniquely identifies this experiment record.  Convention is that it should be in
        the format
    :param name: Base-name of the experiment
    :param print_to_console: If True, print statements still go to console - if False, they're just rerouted to file.
    :param show_figs: Show figures when the experiment produces them.  Can be:
        'hang': Show and hang
        'draw': Show but keep on going
        False: Don't show figures
    """
    # Note: matplotlib imports are internal in order to avoid trouble for people who may import this module without having
    # a working matplotlib (which can occasionally be tricky to install).
    if date is None:
        date = datetime.now()
    identifier = format_filename(file_string=identifier,
                                 base_name=name,
                                 current_time=date)

    if show_figs is None:
        show_figs = 'draw' if is_test_mode() else 'hang'

    assert show_figs in ('hang', 'draw', False)

    if use_temp_dir:
        experiment_directory = tempfile.mkdtemp()
        atexit.register(lambda: shutil.rmtree(experiment_directory))
    else:
        experiment_directory = get_local_path(
            'experiments/{identifier}'.format(identifier=identifier))

    make_dir(experiment_directory)
    from artemis.plotting.manage_plotting import WhatToDoOnShow
    global _CURRENT_EXPERIMENT_RECORD  # Register
    _CURRENT_EXPERIMENT_RECORD = ExperimentRecord(experiment_directory)
    capture_context = CaptureStdOut(log_file_path=os.path.join(
        experiment_directory, 'output.txt'),
                                    print_to_console=print_to_console)
    show_context = WhatToDoOnShow(show_figs)
    if save_figs:
        from artemis.plotting.saving_plots import SaveFiguresOnShow
        save_figs_context = SaveFiguresOnShow(
            path=os.path.join(experiment_directory, 'fig-%T-%L' +
                              saved_figure_ext))
        with capture_context, show_context, save_figs_context:
            yield _CURRENT_EXPERIMENT_RECORD
    else:
        with capture_context, show_context:
            yield _CURRENT_EXPERIMENT_RECORD
    _CURRENT_EXPERIMENT_RECORD = None  # Deregister
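
Since record_experiment yields, it is presumably wrapped with contextlib.contextmanager in the full module (the decorator is not shown in this excerpt). A minimal usage sketch under that assumption, with a made-up experiment name:

# Usage sketch. Assumes record_experiment is decorated with @contextmanager in
# the full module; the experiment name here is invented for illustration.
with record_experiment(name='capture_demo', use_temp_dir=True, show_figs=False) as record:
    print('This line is captured into output.txt inside the record directory.')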
Example 6
def test_proper_persistent_print_usage():
    """
    Here is the best, cleanest way to use persistent print
    :return:
    """
    print('ddd')
    with CaptureStdOut() as ps:
        print('fff')
        print('ggg')
    print('hhh')
    assert ps.read() == 'fff\nggg\n'
Example 7
def test_proper_persistent_print_file_logging():

    log_file_path = get_local_path('tests/test_log.txt')
    with CaptureStdOut(log_file_path) as ps:
        print('fff')
        print('ggg')
    print('hhh')
    assert ps.read() == 'fff\nggg\n'

    # You can verify that the log has also been written.
    log_path = ps.get_log_file_path()
    with open(log_path) as f:
        txt = f.read()
    assert txt == 'fff\nggg\n'
Example 8
def record_experiment(identifier='%T-%N', name='unnamed', print_to_console=True, show_figs=None,
                      save_figs=True, saved_figure_ext='.fig.pkl', use_temp_dir=False, date=None, prefix=None):
    """
    :param identifier: The string that uniquely identifies this experiment record.  Convention is that it should be in
        the format
    :param name: Base-name of the experiment
    :param print_to_console: If True, print statements still go to console - if False, they're just rerouted to file.
    :param show_figs: Show figures when the experiment produces them.  Can be:
        'hang': Show and hang
        'draw': Show but keep on going
        False: Don't show figures
    """
    # Note: matplotlib imports are internal in order to avoid trouble for people who may import this module without having
    # a working matplotlib (which can occasionally be tricky to install).
    if date is None:
        date = datetime.now()
    identifier = format_filename(file_string=identifier, base_name=name, current_time=date)

    if show_figs is None:
        show_figs = 'draw' if is_test_mode() else 'hang'

    assert show_figs in ('hang', 'draw', False)

    if use_temp_dir:
        experiment_directory = tempfile.mkdtemp()
        atexit.register(lambda: shutil.rmtree(experiment_directory))
    else:
        experiment_directory = get_local_experiment_path(identifier)

    make_dir(experiment_directory)
    this_record = ExperimentRecord(experiment_directory)

    # Create context that sets the current experiment record
    # and the context which captures stdout (print statements) and logs them.
    contexts = [
        hold_current_experiment_record(this_record),
        CaptureStdOut(log_file_path=os.path.join(experiment_directory, 'output.txt'), print_to_console=print_to_console, prefix=prefix)
        ]

    if is_matplotlib_imported():
        from artemis.plotting.manage_plotting import WhatToDoOnShow
        # Add context that modifies how matplotlib figures are shown.
        contexts.append(WhatToDoOnShow(show_figs))
        if save_figs:
            from artemis.plotting.saving_plots import SaveFiguresOnShow
            # Add context that saves figures when show is called.
            contexts.append(SaveFiguresOnShow(path=os.path.join(experiment_directory, 'fig-%T-%L' + saved_figure_ext)))

    with nested(*contexts):
        yield this_record
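
Note that contextlib.nested does not exist in Python 3, so the final `with nested(*contexts)` block only works on a Python 2 runtime or with a compatibility shim supplied elsewhere. A minimal sketch of an equivalent helper built on contextlib.ExitStack (the name enter_all is made up for illustration); with it, the last two lines would become `with enter_all(contexts): yield this_record`:

from contextlib import ExitStack, contextmanager

@contextmanager
def enter_all(contexts):
    # Enter each context manager in order; ExitStack exits them in reverse
    # order, matching the behaviour of the old contextlib.nested(*contexts).
    with ExitStack() as stack:
        for c in contexts:
            stack.enter_context(c)
        yield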
Example 9
def capture_print(log_file_path='logs/dump/%T-log.txt', print_to_console=True):
    """
    :param log_file_path: Path of the file to log to.  If the path does not start with a "/", it will
        be relative to the data directory.  You can use placeholders such as %T, %R, ... in the path name (see
        format_filename).
    :param print_to_console: Also continue printing to console.
    :return: The absolute path to the log file.
    """
    local_log_file_path = get_artemis_data_path(log_file_path)
    logger = CaptureStdOut(log_file_path=local_log_file_path,
                           print_to_console=print_to_console)
    logger.__enter__()
    sys.stdout = logger
    sys.stderr = logger
    return local_log_file_path
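
Under the hood, capture_print swaps sys.stdout and sys.stderr for an object that writes to both the console and a log file. A self-contained sketch of that tee behaviour follows; it is an illustration of the idea, not the CaptureStdOut implementation, and the class name and log path are made up:

import sys

class TeeOutput:
    """Write-through stream: everything written goes to a log file and,
    optionally, to the original console stream as well."""

    def __init__(self, log_file_path, print_to_console=True):
        self._log = open(log_file_path, 'w')
        self._console = sys.stdout
        self._print_to_console = print_to_console

    def write(self, text):
        self._log.write(text)
        if self._print_to_console:
            self._console.write(text)

    def flush(self):
        self._log.flush()
        if self._print_to_console:
            self._console.flush()

# Usage sketch:
# sys.stdout = TeeOutput('/tmp/demo-log.txt')
# print('this goes to both the console and /tmp/demo-log.txt')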
Example 10
def show_experiment_records(records, parallel_text=None, hang_notice = None, show_logs=True, truncate_logs=None, truncate_result=10000, header_width=100, show_result ='deep', hang=True):
    """
    Show the console logs, figures, and results of a collection of experiments.

    :param records:
    :param parallel_text:
    :param hang_notice:
    :return:
    """
    if isinstance(records, ExperimentRecord):
        records = [records]
    if parallel_text is None:
        parallel_text = len(records)>1
    if len(records)==0:
        print('... No records to show ...')
        strings = []  # Keep strings defined so the printing below does not fail on an empty list.
    else:
        strings = [get_record_full_string(rec, show_logs=show_logs, show_result=show_result, truncate_logs=truncate_logs,
                    truncate_result=truncate_result, header_width=header_width, include_bottom_border=False) for rec in records]
    has_matplotlib_figures = any(loc.endswith('.pkl') for rec in records for loc in rec.get_figure_locs())
    if has_matplotlib_figures:
        from matplotlib import pyplot as plt
        from artemis.plotting.saving_plots import interactive_matplotlib_context
        for rec in records:
            rec.show_figures(hang=False)
        if hang_notice is not None:
            print(hang_notice)

        with interactive_matplotlib_context(not hang):
            plt.show()

    if any(rec.get_experiment().display_function is not None for rec in records):
        from artemis.plotting.saving_plots import interactive_matplotlib_context
        with interactive_matplotlib_context():
            for i, rec in enumerate(records):
                with CaptureStdOut(print_to_console=False) as cap:
                    display_experiment_record(rec)
                displayed = cap.read()
                if displayed != '':  # Only add the section if the display function actually printed something.
                    strings[i] += section_with_header('Result Display', displayed, width=header_width, bottom_char='=')

    if parallel_text:
        print(side_by_side(strings, max_linewidth=128))
    else:
        for string in strings:
            print(string)

    return has_matplotlib_figures
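
When parallel_text is True, the record summaries are printed in columns via Artemis' side_by_side helper. A self-contained sketch of that layout idea (not the actual helper):

def side_by_side_sketch(strings, gap=4):
    # Pad each block of lines to a common height and width, then join row-wise.
    blocks = [s.splitlines() for s in strings]
    height = max(len(b) for b in blocks)
    widths = [max((len(line) for line in b), default=0) for b in blocks]
    rows = []
    for i in range(height):
        cells = [(b[i] if i < len(b) else '').ljust(w) for b, w in zip(blocks, widths)]
        rows.append((' ' * gap).join(cells).rstrip())
    return '\n'.join(rows)

print(side_by_side_sketch(['record 1\nresult: 3', 'record 2\nresult: 4']))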
Example 11
def test_indent_print():

    with CaptureStdOut() as cap:
        print('aaa')
        print('bbb')
        with IndentPrint():
            print('ccc')
            print('ddd')
            with IndentPrint():
                print('eee')
                print('fff')
            print('ggg')
            print('hhh')
        print('iii')
        print('jjj')

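    # _desired is a module-level constant (not shown in this excerpt) containing the expected indented output.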
    assert '\n'+cap.read() == _desired
Example 12
def test_indent_print():

    with CaptureStdOut() as cap:
        print('aaa')
        print('bbb')
        with IndentPrint():
            print('ccc')
            print('ddd')
            with IndentPrint():
                print('eee')
                print('fff')
            print('ggg')
            print('hhh')
        print('iii')
        print('jjj')

    assert '\n' + cap.read() == _desired
Example 13
def test_duplicate_headers_when_no_records_bug_is_gone():
    # There was a bug that duplicated the headers when there were no records.  This test verifies that it's gone.
    @experiment_function
    def my_simdfdscds(a=1):

        print('xxxxx')
        print('yyyyy')
        return a + 2

    for r in my_simdfdscds.get_records():
        r.delete()

    with CaptureStdOut() as cap:
        my_simdfdscds.browse(command='q')

    string = cap.read()

    assert string.count('Start Time') == 1
Example 14
def cifar10(epochs, log_interval, pretrain_path, restore_path, use_batchnorm,
            quantize_activations, quantize_weights, _seed, _run):

    print('get_local_dir', get_local_dir("data/cifar10"))

    assert pretrain_path is None or restore_path is None, \
        "Provide at most one of pretrain_path and restore_path"

    exp_dir = get_experiment_dir(ex.path, _run)

    print("Starting Experiment in {}".format(exp_dir))
    with CaptureStdOut(log_file_path=os.path.join(exp_dir, "output.txt") if
                       not False else os.path.join(exp_dir, "val_output.txt")):
        try:
            # Data
            train_loader, test_loader = get_data_train_test()

            # Model
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
            model = get_model()

            # Configure

            model, optimizer, best_val_acc, start_epoch, best_val_epoch = configure_starting_point(
                model=model)
            model = model.to(device)

            # Misc
            train_writer = SummaryWriter(log_dir=exp_dir)
            hooks = TBHook(model, train_writer,
                           start_epoch * len(train_loader),
                           torch.cuda.device_count(), log_interval)

            scheduler = get_lr_scheduler(optimizer=optimizer)

            gc.collect()
            model = torch.nn.DataParallel(model)
            gc.collect()

            best_epoch = 0
            best_val_acc = -np.inf
            criterion = get_loss_criterion()

            if torch.cuda.is_available():
                _, test_acc = test(model, test_loader)
                print('Test acc before training ', test_acc)

            ##########################################################################################
            start_epoch = 0
            val_list = []
            for epoch in range(start_epoch, start_epoch + epochs + 1):

                ##########################################################################################

                print('EPOCH: ', epoch)

                ##########################################################################################
                print('Training')
                train_loss, train_acc = train_epoch(model=model,
                                                    train_loader=train_loader,
                                                    optimizer=optimizer,
                                                    epoch=epoch,
                                                    train_writer=train_writer,
                                                    log_interval=log_interval,
                                                    criterion=criterion)
                ##########################################################################################

                train_loss_eval, train_acc_eval = test(model, train_loader)
                train_writer.add_scalar("Validation/TrainLoss",
                                        train_loss_eval,
                                        epoch * len(train_loader))
                train_writer.add_scalar("Validation/TrainAccuracy",
                                        train_acc_eval,
                                        epoch * len(train_loader))
                print(
                    "Epoch {}, Training Eval Loss: {:.4f}, Training Eval Acc: {:.4f}"
                    .format(epoch, train_loss_eval, train_acc_eval))

                val_loss, val_acc = test(model, test_loader)

                ##########################################################################################

                try:
                    scheduler.step(epoch=epoch)
                except TypeError:
                    scheduler.step()

                ##########################################################################################

                train_writer.add_scalar("Validation/Loss", val_loss,
                                        epoch * len(train_loader))
                train_writer.add_scalar("Validation/Accuracy", val_acc,
                                        epoch * len(train_loader))
                train_writer.add_scalar("Others/LearningRate",
                                        optimizer.param_groups[0]["lr"],
                                        epoch * len(train_loader))
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_epoch = epoch
                    save_state(model=model.state_dict(),
                               optimizer=optimizer.state_dict(),
                               epoch=epoch,
                               best_val_acc=best_val_acc,
                               best_epoch=best_epoch,
                               save_path=os.path.join(exp_dir,
                                                      "best_model.pt"))
                print(
                    "Epoch {}, Validation Loss: {:.4f},\033[1m Validation Acc: {:.4f}\033[0m , Best Val Acc: {:.4f} at EP {}"
                    .format(epoch, val_loss, val_acc, best_val_acc,
                            best_epoch))

                # saving the last model
                save_state(model=model.state_dict(),
                           optimizer=optimizer.state_dict(),
                           epoch=epoch,
                           best_val_acc=best_val_acc,
                           best_epoch=best_epoch,
                           save_path=os.path.join(exp_dir, "model.pt"))

                # save all models, print real bops
                folder_to_save = 'mpdnn_models'
                if not os.path.exists(folder_to_save):
                    os.makedirs(folder_to_save)

                name_to_save = 'model_' + str(epoch) + '.pt'
                print('Epoch: ', epoch)
                print('Val ACC ', val_acc)
                if epoch % 5 == 0:
                    print('Plot weights')
                    plot_weights(model, epoch)
                    print('Saving a model ')
                    save_state(model=model.state_dict(),
                               optimizer=optimizer.state_dict(),
                               epoch=epoch,
                               best_val_acc=best_val_acc,
                               best_epoch=best_epoch,
                               save_path=folder_to_save + '/' + name_to_save)

            print("Early Stopping Epoch {} with Val Acc {:.4f} ".format(
                best_epoch, best_val_acc))

        except Exception:
            write_error_trace(exp_dir)
            raise