Beispiel #1
0
def capture_print(log_file_path = 'logs/dump/%T-log.txt', print_to_console=True):
    """
    Redirect both stdout and stderr into a log file for the rest of the process.

    :param log_file_path: Path of the file to log to.  If it does not start with a "/", it is resolved relative to
        the data directory (via get_artemis_data_path).  Placeholders such as %T, %R, ... may be used in the path
        name (see format_filename).
    :param print_to_console: If True, captured output is also still printed to the console.
    :return: The resolved path of the log file, as returned by get_artemis_data_path.

    Note: the capture object is entered but never exited here, and sys.stdout / sys.stderr are not restored, so
    capturing stays active unless the caller undoes it.
    """
    resolved_path = get_artemis_data_path(log_file_path)
    capture = CaptureStdOut(log_file_path=resolved_path, print_to_console=print_to_console)
    capture.__enter__()
    # Chained assignment binds left-to-right: sys.stdout first, then sys.stderr.
    sys.stdout = sys.stderr = capture
    return resolved_path
Beispiel #2
0
def record_experiment(identifier='%T-%N',
                      name='unnamed',
                      print_to_console=True,
                      show_figs=None,
                      save_figs=True,
                      saved_figure_ext='.fig.pkl',
                      use_temp_dir=False,
                      date=None,
                      prefix=None):
    """
    Create an experiment record and yield it while print output (and, if available, matplotlib figures) is captured.

    This is a generator that yields exactly once.  It is presumably wrapped with ``contextlib.contextmanager``
    (no decorator is visible here -- confirm at the definition site) and used as
    ``with record_experiment(...) as record: ...``.

    :param identifier: Format string that uniquely identifies this experiment record.  Placeholders (e.g. the
        default '%T-%N') are expanded by ``format_filename`` using ``name`` and ``date``.
    :param name: Base-name of the experiment, substituted into ``identifier``.
    :param print_to_console: If True, print statements still go to console - if False, they're just rerouted to file.
    :param show_figs: Show figures when the experiment produces them (only applied if matplotlib is already
        imported).  Can be:
        'hang': Show and hang
        'draw': Show but keep on going
        False: Don't show figures
        None (default): 'draw' in test mode, otherwise 'hang'.
    :param save_figs: If True (and matplotlib is already imported), figures are saved into the experiment
        directory whenever they are shown, named 'fig-%T-%L' + saved_figure_ext.
    :param saved_figure_ext: File extension appended to saved figure file names.
    :param use_temp_dir: If True, the record lives in a fresh temporary directory that is removed at interpreter
        exit, instead of the local experiment path derived from ``identifier``.
    :param date: Time used when formatting ``identifier``; defaults to ``datetime.now()``.
    :param prefix: Passed through to ``CaptureStdOut`` (presumably a prefix for captured output -- confirm).
    :yield: The ``ExperimentRecord`` for the experiment directory.
    """
    if date is None:
        date = datetime.now()
    identifier = format_filename(file_string=identifier,
                                 base_name=name,
                                 current_time=date)

    if show_figs is None:
        show_figs = 'draw' if is_test_mode() else 'hang'

    assert show_figs in ('hang', 'draw', False), \
        "show_figs must be 'hang', 'draw' or False, got {!r}".format(show_figs)

    if use_temp_dir:
        experiment_directory = tempfile.mkdtemp()
        # Best-effort cleanup at exit: ignore_errors avoids an atexit traceback if the
        # directory was already removed (or partially removed) by someone else.
        atexit.register(shutil.rmtree, experiment_directory, ignore_errors=True)
    else:
        experiment_directory = get_local_experiment_path(identifier)

    make_dir(experiment_directory)
    this_record = ExperimentRecord(experiment_directory)

    # Create context that sets the current experiment record
    # and the context which captures stdout (print statements) and logs them.
    contexts = [
        hold_current_experiment_record(this_record),
        CaptureStdOut(log_file_path=os.path.join(experiment_directory,
                                                 'output.txt'),
                      print_to_console=print_to_console,
                      prefix=prefix)
    ]

    # Note: matplotlib imports are internal in order to avoid trouble for people who may import this module without
    # having a working matplotlib (which can occasionally be tricky to install).
    if is_matplotlib_imported():
        from artemis.plotting.manage_plotting import WhatToDoOnShow
        # Add context that modifies how matplotlib figures are shown.
        contexts.append(WhatToDoOnShow(show_figs))
        if save_figs:
            from artemis.plotting.saving_plots import SaveFiguresOnShow
            # Add context that saves figures when show is called.
            contexts.append(
                SaveFiguresOnShow(
                    path=os.path.join(experiment_directory, 'fig-%T-%L' +
                                      saved_figure_ext)))

    # Contexts are entered in list order and exited in reverse.
    with nested(*contexts):
        yield this_record
Beispiel #3
0
def cifar10(epochs, log_interval, pretrain_path, restore_path, use_batchnorm,
            quantize_activations, quantize_weights, _seed, _run):
    """
    Train and evaluate a model on CIFAR-10, logging to TensorBoard and writing checkpoints.

    The ``_seed`` / ``_run`` parameters suggest a sacred captured function (presumably injected by sacred --
    confirm).  ``use_batchnorm``, ``quantize_activations``, ``quantize_weights`` and ``_seed`` are not read in
    this body; presumably they are consumed by the config-injected helpers (``get_model`` etc.) -- confirm.

    :param epochs: Controls the epoch loop; note it runs ``epochs + 1`` iterations (epochs 0..epochs inclusive).
    :param log_interval: Forwarded to ``train_epoch`` and ``TBHook``.
    :param pretrain_path: Only checked here: at most one of ``pretrain_path`` / ``restore_path`` may be given.
    :param restore_path: Only checked here: at most one of ``pretrain_path`` / ``restore_path`` may be given.
    :param _run: Run object used with ``ex.path`` to locate the experiment directory.
    :raises AssertionError: If both ``pretrain_path`` and ``restore_path`` are given.

    All stdout is captured to ``<exp_dir>/output.txt``.  Checkpoints written:
      - ``<exp_dir>/best_model.pt`` whenever validation accuracy improves,
      - ``<exp_dir>/model.pt`` after every epoch,
      - ``mpdnn_models/model_<epoch>.pt`` (relative to the working directory) every 5 epochs.
    On any exception the error trace is written via ``write_error_trace`` and the exception is re-raised.
    """
    print('get_local_dir', get_local_dir("data/cifar10"))

    # At least one of the two must be None, i.e. they may not both be set.
    assert pretrain_path is None or restore_path is None, "Only pretrain_path or restore_path"

    exp_dir = get_experiment_dir(ex.path, _run)

    print("Starting Experiment in {}".format(exp_dir))
    # (A former `... if not False else "val_output.txt"` fallback was unreachable and has been dropped.)
    with CaptureStdOut(log_file_path=os.path.join(exp_dir, "output.txt")):
        # Defined before `try` so `finally` can always close it and flush buffered events.
        train_writer = None
        try:
            # Data
            train_loader, test_loader = get_data_train_test()

            # Model
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
            model = get_model()

            # Configure
            model, optimizer, best_val_acc, start_epoch, best_val_epoch = configure_starting_point(
                model=model)
            model = model.to(device)

            # Misc
            train_writer = SummaryWriter(log_dir=exp_dir)
            # Not used below; bound to a name in case the hook object must stay alive for the run.
            hooks = TBHook(model, train_writer,
                           start_epoch * len(train_loader),
                           torch.cuda.device_count(), log_interval)

            scheduler = get_lr_scheduler(optimizer=optimizer)

            gc.collect()
            model = torch.nn.DataParallel(model)
            gc.collect()

            # NOTE(review): this (and `start_epoch = 0` below) discards the best_val_acc / start_epoch
            # returned by configure_starting_point, so a restored run restarts at epoch 0 with
            # best_val_acc = -inf.  Behaviour kept as-is; confirm whether resuming should continue instead.
            best_epoch = 0
            best_val_acc = -np.inf
            criterion = get_loss_criterion()

            if torch.cuda.is_available():
                _, test_acc = test(model, test_loader)
                print('Test acc before training ', test_acc)

            # Folder for the periodic (every 5 epochs) checkpoints; created once, relative to the CWD.
            folder_to_save = 'mpdnn_models'
            os.makedirs(folder_to_save, exist_ok=True)

            ##########################################################################################
            start_epoch = 0  # see NOTE(review) above
            # NOTE(review): runs epochs + 1 iterations (0..epochs inclusive) -- confirm this is intended.
            for epoch in range(start_epoch, start_epoch + epochs + 1):
                print('EPOCH: ', epoch)
                # TensorBoard global step shared by all per-epoch scalars below.
                step = epoch * len(train_loader)

                ##########################################################################################
                print('Training')
                train_loss, train_acc = train_epoch(model=model,
                                                    train_loader=train_loader,
                                                    optimizer=optimizer,
                                                    epoch=epoch,
                                                    train_writer=train_writer,
                                                    log_interval=log_interval,
                                                    criterion=criterion)
                ##########################################################################################

                train_loss_eval, train_acc_eval = test(model, train_loader)
                train_writer.add_scalar("Validation/TrainLoss", train_loss_eval, step)
                train_writer.add_scalar("Validation/TrainAccuracy", train_acc_eval, step)
                print(
                    "Epoch {}, Training Eval Loss: {:.4f}, Training Eval Acc: {:.4f}"
                    .format(epoch, train_loss_eval, train_acc_eval))

                val_loss, val_acc = test(model, test_loader)

                ##########################################################################################

                # Some schedulers accept an explicit epoch, others do not.
                try:
                    scheduler.step(epoch=epoch)
                except TypeError:
                    scheduler.step()

                ##########################################################################################

                train_writer.add_scalar("Validation/Loss", val_loss, step)
                train_writer.add_scalar("Validation/Accuracy", val_acc, step)
                train_writer.add_scalar("Others/LearningRate",
                                        optimizer.param_groups[0]["lr"], step)
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_epoch = epoch
                    save_state(model=model.state_dict(),
                               optimizer=optimizer.state_dict(),
                               epoch=epoch,
                               best_val_acc=best_val_acc,
                               best_epoch=best_epoch,
                               save_path=os.path.join(exp_dir,
                                                      "best_model.pt"))
                print(
                    "Epoch {}, Validation Loss: {:.4f},\033[1m Validation Acc: {:.4f}\033[0m , Best Val Acc: {:.4f} at EP {}"
                    .format(epoch, val_loss, val_acc, best_val_acc,
                            best_epoch))

                # saving the last model
                save_state(model=model.state_dict(),
                           optimizer=optimizer.state_dict(),
                           epoch=epoch,
                           best_val_acc=best_val_acc,
                           best_epoch=best_epoch,
                           save_path=os.path.join(exp_dir, "model.pt"))

                # save a model snapshot every 5 epochs
                name_to_save = 'model_' + str(epoch) + '.pt'
                print('Epoch: ', epoch)
                print('Val ACC ', val_acc)
                if epoch % 5 == 0:
                    print('Plot weights')
                    plot_weights(model, epoch)
                    print('Saving a model ')
                    save_state(model=model.state_dict(),
                               optimizer=optimizer.state_dict(),
                               epoch=epoch,
                               best_val_acc=best_val_acc,
                               best_epoch=best_epoch,
                               save_path=os.path.join(folder_to_save, name_to_save))

            # NOTE: no early stopping actually happens; this reports the best validation epoch.
            print("Early Stopping Epoch {} with Val Acc {:.4f} ".format(
                best_epoch, best_val_acc))

        except Exception:
            write_error_trace(exp_dir)
            raise
        finally:
            # Flush and release the TensorBoard event file on both success and failure.
            if train_writer is not None:
                train_writer.close()