示例#1
0
def get_file(relative_name, url=None, data_transformation=None):
    """
    Return the local path of a file in the data directory, downloading it first if there is no local copy.

    :param relative_name: Path of the file, relative to the data directory (resolved with get_local_path).
    :param url: Url to fetch the file from when there is no local copy.  Required in that case.
    :param data_transformation: Optional function applied to the downloaded data before it is written.
    :return: The full path to the local file.
    """
    from contextlib import closing

    relative_folder, file_name = os.path.split(relative_name)
    local_folder = get_local_path(relative_folder)

    try:  # Best way to see if folder exists already - avoids race condition between processes
        os.makedirs(local_folder)
    except OSError:
        # "Already exists" (possibly created by another process) is fine; anything else is a real error.
        if not os.path.isdir(local_folder):
            raise

    full_filename = os.path.join(local_folder, file_name)

    if not os.path.exists(full_filename):
        assert url is not None, "No local copy of '%s' was found, and you didn't provide a URL to fetch it from" % (
            full_filename, )

        print('Downloading file from url: "%s"...' % (url, ))
        # closing() releases the connection even if read() fails.
        with closing(urllib2.urlopen(url)) as response:
            data = response.read()
        print('...Done.')

        if data_transformation is not None:
            print('Processing downloaded data...')
            data = data_transformation(data)
        # Binary mode: the downloaded data is raw bytes, and text mode would corrupt it on Windows.
        with open(full_filename, 'wb') as f:
            f.write(data)
    return full_filename
示例#2
0
def smart_file(location, use_cache=False, make_dir=False):
    """
    Generator (meant to be driven as a context manager) that yields a local path for ``location``.

    :param location: Specifies where the file is.
        If it's formatted as a url, it's downloaded.
        If it begins with a "/", it's assumed to be a local path.
        Otherwise, it is assumed to be referenced relative to the data directory.
    :param use_cache: If True, and the location is a url, make a local cache of the file for future use (note: if the
        file at this url changes, the cached file will not).
    :param make_dir: Make the directory for this file, if it does not exist.  Not allowed for urls.
    :yield: The local path to the file.  For an uncached url this is a temporary download, which is deleted
        when the generator finishes - including when the caller's code raises.
    """
    its_a_url = is_url(location)
    if its_a_url:
        assert not make_dir, "We cannot 'make the directory' for a URL"
        if use_cache:
            local_path = get_file_and_cache(location)
        else:
            local_path = get_temp_file(location)
    else:
        local_path = get_local_path(location)
        if make_dir:
            make_file_dir(local_path)

    try:
        yield local_path
    finally:
        # Without try/finally, an exception in the caller's body would skip this cleanup and leak the temp file.
        if its_a_url and not use_cache:
            os.remove(local_path)
示例#3
0
def show_saved_figure(relative_loc):
    """
    Display a saved figure.

    Behaviour: this simply opens a window with the figure, and then continues executing the code.
    Raster images (.jpg, .jpeg, .png, .tif, .tiff - extension matched case-insensitively) are shown with PIL;
    any other file (e.g. a .pdf) is opened with the system web browser.

    :param relative_loc: Relative path (within the data directory) to the figure.  Treated as an absolute path
        if it begins with "/"
    :raises AssertionError: If the resolved file does not exist.
    """
    _, ext = os.path.splitext(relative_loc)
    abs_loc = get_local_path(relative_loc)
    assert os.path.exists(
        abs_loc), '"%s" did not exist.  That is odd.' % (abs_loc, )
    if ext.lower() in ('.jpg', '.jpeg', '.png', '.tif', '.tiff'):
        try:
            from PIL import Image
        except ImportError:
            ARTEMIS_LOGGER.error(
                "Cannot display image '%s', because PIL is not installed.  Go pip install pillow to use this.  Currently it is a soft requirement."
            )
        else:
            Image.open(abs_loc).show()
    else:
        import webbrowser
        webbrowser.open('file://' + abs_loc)
示例#4
0
def smart_save(obj, relative_path, remove_file_after = False):
    """
    Save an object locally.  How you save it depends on its extension.
    Extensions currently supported:
        pkl: Pickle file.
        pdf: Calls obj.savefig(path), so obj must have a savefig method (e.g. a matplotlib figure).
    :param obj: Object to save
    :param relative_path: Path to save it, relative to "Data" directory.  The following placeholders can be used:
        %T - ISO time
        %R - Current Experiment Record Identifier (includes experiment time and experiment name)
    :param remove_file_after: If you're just running a test, it's good to verify that you can save, but you don't
        actually want to leave a file behind.  If that's the case, set this argument to True.
    :return: The local path the object was saved to.
    :raises Exception: If the extension is not supported.
    """
    if '%T' in relative_path:
        iso_time = datetime.now().isoformat().replace(':', '.').replace('-', '.')
        relative_path = relative_path.replace('%T', iso_time)
    if '%R' in relative_path:
        from artemis.fileman.experiment_record import get_current_experiment_id
        relative_path = relative_path.replace('%R', get_current_experiment_id())
    _, ext = os.path.splitext(relative_path)
    local_path = get_local_path(relative_path, make_local_dir=True)

    if ext == '.pkl':
        # Pickle data is binary: text mode can corrupt it on Windows and fails outright on Python 3.
        with open(local_path, 'wb') as f:
            pickle.dump(obj, f)
    elif ext == '.pdf':
        obj.savefig(local_path)
    else:
        # ext already includes the leading dot.
        raise Exception("No method exists yet to save '%s' files.  Add it!" % (ext, ))

    # Report only once the save has succeeded, using the id of the saved object itself.
    print('Saved object <%s at %s> to file: "%s"' % (obj.__class__.__name__, hex(id(obj)), local_path))

    if remove_file_after:
        os.remove(local_path)

    return local_path
示例#5
0
def save_current_figure():
    """
    Save the current matplotlib figure as a pdf in the 'output' folder of the data directory, named from the
    current time (format_filename with the '%T' placeholder).

    :return: The path the figure was saved to.
    """
    print("Attempting to save figure")
    fig = plt.gcf()
    file_name = format_filename(file_string='%T', current_time=datetime.now())
    save_path = get_local_path(
        'output/{file_name}.pdf'.format(file_name=file_name))
    save_figure(fig, path=save_path)
    # Only report success after save_figure has returned.
    print("Current figure saved to {}".format(save_path))
    return save_path
def test_save_and_show_figure():
    # Draw a random 10x10 image, save it as a pdf under the data directory, then open the saved file.
    figure = plt.figure()
    plt.imshow(np.random.randn(10, 10))
    plt.title('Test Figure')
    pdf_path = get_local_path('tests/test_fig.pdf')
    save_figure(figure, path=pdf_path)
    show_saved_figure(pdf_path)
def test_save_and_show_figure_3():
    # Save under a name containing several dots while passing an explicit ext, then show whatever
    # path save_figure reports back.
    figure = plt.figure()
    plt.imshow(np.random.randn(10, 10))
    plt.title('Test Figure')
    requested_path = get_local_path('tests/test_fig.with.strangely.formatted.ending')
    saved_path = save_figure(figure, path=requested_path, ext='pdf')
    show_saved_figure(saved_path)
示例#8
0
def record_experiment(identifier='%T-%N', name = 'unnamed', info = '', print_to_console = True, show_figs = None,
            save_figs = True, saved_figure_ext = '.pdf', use_temp_dir = False):
    """
    Record an experiment while the caller's code runs at the single ``yield`` (intended to be driven as a context
    manager).  Console output is logged to <experiment directory>/output.txt, figure display follows ``show_figs``,
    and figures are optionally saved into the experiment directory.

    :param identifier: The string that uniquely identifies this experiment record.  It is expanded with
        format_filename (together with ``name`` and the current time), so placeholders such as %T can be used.
        Presumably %N stands for ``name`` - confirm against format_filename.
    :param name: Base-name of the experiment
    :param info: Not used by this function.
    :param print_to_console: If True, print statements still go to console - if False, they're just rerouted to file.
    :param show_figs: Show figures when the experiment produces them.  Can be:
        'hang': Show and hang
        'draw': Show but keep on going
        False: Don't show figures
        None (default): 'draw' in test mode, 'hang' otherwise.
    :param save_figs: If True, figures are saved (via SaveFiguresOnShow) into the experiment directory.
    :param saved_figure_ext: File extension used for saved figures.
    :param use_temp_dir: If True, record into a fresh temporary directory that is deleted at interpreter exit,
        instead of experiments/<identifier> in the data directory.
    :yield: The ExperimentRecord for this run.
    """
    # Note: matplotlib imports are internal in order to avoid trouble for people who may import this module without having
    # a working matplotlib (which can occasionally be tricky to install).
    global _CURRENT_EXPERIMENT_RECORD

    identifier = format_filename(file_string = identifier, base_name=name, current_time = datetime.now())

    if show_figs is None:
        show_figs = 'draw' if is_test_mode() else 'hang'

    assert show_figs in ('hang', 'draw', False)

    if use_temp_dir:
        experiment_directory = tempfile.mkdtemp()
        atexit.register(lambda: shutil.rmtree(experiment_directory))
    else:
        experiment_directory = get_local_path('experiments/{identifier}'.format(identifier=identifier))

    make_dir(experiment_directory)
    make_file_dir(experiment_directory)
    log_file_name = os.path.join(experiment_directory, 'output.txt')
    log_capture_context = PrintAndStoreLogger(log_file_path = log_file_name, print_to_console = print_to_console)
    log_capture_context.__enter__()
    from artemis.plotting.manage_plotting import WhatToDoOnShow
    blocking_show_context = WhatToDoOnShow(show_figs)
    blocking_show_context.__enter__()
    if save_figs:
        from artemis.plotting.saving_plots import SaveFiguresOnShow
        figure_save_context = SaveFiguresOnShow(path = os.path.join(experiment_directory, 'fig-%T-%L'+saved_figure_ext))
        figure_save_context.__enter__()

    _register_current_experiment(name, identifier)

    _CURRENT_EXPERIMENT_RECORD = ExperimentRecord(experiment_directory)
    _CURRENT_EXPERIMENT_RECORD.add_info('Name: %s' % (name, ))
    _CURRENT_EXPERIMENT_RECORD.add_info('Identifier: %s' % (identifier, ))
    _CURRENT_EXPERIMENT_RECORD.add_info('Directory: %s' % (_CURRENT_EXPERIMENT_RECORD.get_dir(), ))
    try:
        yield _CURRENT_EXPERIMENT_RECORD
    finally:
        # Tear down even when the experiment raised; otherwise stdout would stay captured and the experiment would
        # stay registered.  Contexts are exited in reverse order of entry, and the nested try/finally makes sure
        # one failing __exit__ does not skip the remaining teardown steps.
        _CURRENT_EXPERIMENT_RECORD = None
        try:
            if save_figs:
                figure_save_context.__exit__(None, None, None)
        finally:
            try:
                blocking_show_context.__exit__(None, None, None)
            finally:
                try:
                    log_capture_context.__exit__(None, None, None)
                finally:
                    _deregister_current_experiment()
示例#9
0
def test_simple_rsync():
    # Create two empty files in a scratch folder, rsync the folder to the remote host (module-level ip_address),
    # and always remove the scratch folder afterwards.
    from_path = get_local_path(relative_path="tmp/tests/", make_local_dir=True)
    try:
        for file_name in ("test1", "test2"):
            with open(os.path.join(from_path, file_name), "wb"):
                pass
        remote_path = "~/PycharmProjects/Distributed-VI/"

        assert simple_rsync(local_path=from_path, remote_path=remote_path, ip_address=ip_address, verbose=True)
    finally:
        # Clean up even when the rsync assertion fails, so later runs start from an empty folder.
        shutil.rmtree(from_path)
示例#10
0
def get_relative_link_from_relative_path(relative_path):
    """
    Map a path inside the data folder to the server link that serves it.

    On first use a symlink SERVER_RELATIVE_DATA_DIR -> data folder is created, so the data folder can be
    browsed from Jupyter.

    :param relative_path: Relative path (from within Data folder)
    :return: The link, joined with os.path.join as '/', then 'tree' (for a directory) or 'files' (anything else),
        then DATA_FOLDER_NAME, then relative_path.
    """
    data_dir = get_local_path()
    target = get_local_path(relative_path)
    route = 'files'
    if os.path.isdir(target):
        route = 'tree'

    # lexists: a link that exists (even a dangling one) is left alone.
    link_missing = not os.path.lexists(SERVER_RELATIVE_DATA_DIR)
    if link_missing:
        os.symlink(data_dir, SERVER_RELATIVE_DATA_DIR)
    return os.path.join('/', route, DATA_FOLDER_NAME, relative_path)
示例#11
0
def clear_experiment_records_with_name(experiment_name=None):
    """
    Delete the record folders (under 'experiments' in the data directory) of the experiments matching a name.

    :param experiment_name: Experiment name; converted to a matching pattern by
        _get_matching_template_from_experiment_name.  Presumably None matches every experiment - confirm there.
    """
    pattern = _get_matching_template_from_experiment_name(experiment_name)
    matching_ids = get_all_experiment_ids(pattern)
    for record_id in matching_ids:
        shutil.rmtree(os.path.join(get_local_path('experiments'), record_id))
示例#12
0
def test_unpack_tar_gz():
    # Start from a clean slate: remove any extracted folder or archive left over from a previous run.
    extracted_dir = get_local_path('tests/test_tar_zip')
    if os.path.exists(extracted_dir):
        shutil.rmtree(extracted_dir)
    archive_file = get_local_path('tests/test_tar_zip.tar.gz')
    if os.path.exists(archive_file):
        os.remove(archive_file)

    # Two passes: the first one fetches the archive, the second checks that caching works.
    for _attempt in range(2):
        local_file = get_file_in_archive(
            relative_path='tests/test_tar_zip',
            url='https://drive.google.com/uc?export=download&id=0B4IfiNtPKeSAbmp6VEVJdjdSSlE',
            subpath='testzip/test_file.txt',
        )
        with open(local_file) as handle:
            contents = handle.read()
        assert contents == 'blah blah blah'
示例#13
0
def get_all_experiment_ids(expr = None):
    """
    List the experiment identifiers, i.e. the names of the sub-directories of the 'experiments' data folder.

    :param expr: Optional regular expression; only identifiers that re.match it (a match anchored at the start)
        are kept.  None keeps them all.
    :return: A list of experiment identifiers, in os.listdir order.
    """
    experiments_root = get_local_path('experiments')
    found = []
    for entry in os.listdir(experiments_root):
        if not os.path.isdir(os.path.join(experiments_root, entry)):
            continue  # stray files are not experiment records
        if expr is None or re.match(expr, entry):
            found.append(entry)
    return found
示例#14
0
def write_port_to_file(port):
    """
    Store the plotting server's port, pickled, in the shared port.info file, and register remove_port_file to run
    at interpreter exit.

    If port.info already exists, a warning is printed to stderr before the file is overwritten.
    """
    atexit.register(remove_port_file)
    port_file_path = get_local_path("tmp/plot_server/port.info", make_local_dir=True)
    already_there = os.path.exists(port_file_path)
    if already_there:
        warning = ("port.info file already exists. This might either mean that you are running another plotting server in the background and want to start a second one.\nIn this case ignore "
                   "this message. Otherwise a previously run plotting server crashed without cleaning up afterwards. \nIn this case, please manually delete the file at {}"
                   .format(port_file_path))
        print(warning, file=sys.stderr)
    with open(port_file_path, 'wb') as port_file:
        pickle.dump(port, port_file)
示例#15
0
def clear_experiments():
    """
    Empty the 'experiments' data folder: files are unlinked, sub-directories are removed recursively.

    Best-effort: an error while deleting one entry is printed and the remaining entries are still processed.
    """
    # Credit: http://stackoverflow.com/questions/185936/delete-folder-contents-in-python
    experiments_root = get_local_path('experiments')
    for entry in os.listdir(experiments_root):
        entry_path = os.path.join(experiments_root, entry)
        if os.path.isfile(entry_path):
            remover = os.unlink
        elif os.path.isdir(entry_path):
            remover = shutil.rmtree
        else:
            continue
        try:
            remover(entry_path)
        except Exception as error:
            print(error)
示例#16
0
def setup_web_plotting(update_period=1.):
    """
    Route figures through the web plotting backend.

    A uniquely named temporary directory is created (and registered for clean_up at exit), a plotting server is
    started on it, and both the draw and show callbacks are set to TimedFigureSaver instances that write
    'artemis_figure.png' in that directory, throttled by ``update_period`` (presumably seconds - confirm).
    """
    unique_dir = "tmp/web_backend/%s/" % (str(uuid.uuid4()), )
    plot_directory = get_local_path(relative_path=unique_dir, make_local_dir=True)  # Temporary directory
    atexit.register(clean_up, plot_dir=plot_directory)
    _start_plotting_server(plot_directory=plot_directory, update_period=update_period)

    figure_path = os.path.join(plot_directory, 'artemis_figure.png')
    # Draw and show each get their own saver instance, both targeting the same file.
    draw_saver = TimedFigureSaver(figure_path, update_period=update_period)
    set_draw_callback(draw_saver)
    show_saver = TimedFigureSaver(figure_path, update_period=update_period)
    set_show_callback(show_saver)
示例#17
0
def test_rsync():
    # Create two empty files in a scratch folder and rsync them recursively ("-r") into ~/temp/ of the remote
    # user configured for ip_address in .artemisrc; the scratch folder is always removed afterwards.
    options = ["-r"]
    username = get_config_value(".artemisrc", section=ip_address, option="username")

    from_path = get_local_path(relative_path="tmp/tests/", make_local_dir=True)
    try:
        for file_name in ("test1", "test2"):
            with open(os.path.join(from_path, file_name), "wb"):
                pass

        to_path = "%s@%s:/home/%s/temp/" % (username, ip_address, username)
        assert rsync(options, from_path, to_path)
    finally:
        # Clean up even when the rsync assertion fails, so later runs start from an empty folder.
        shutil.rmtree(from_path)
def test_proper_persistent_print_file_logging():
    # Text printed inside the CaptureStdOut block can be read back and is also written to the log file;
    # text printed after the block ('hhh') must not be captured.
    expected = 'fff\nggg\n'
    log_file_path = get_local_path('tests/test_log.txt')
    with CaptureStdOut(log_file_path) as capture:
        print('fff')
        print('ggg')
    print('hhh')
    assert capture.read() == expected

    # You can verify that the log has also been written.
    with open(capture.get_log_file_path()) as log_file:
        logged = log_file.read()
    assert logged == expected
def test_new_log_file():
    # new_log_file is a shorthand for persistent print: everything printed until stop_capturing_print()
    # should end up in the returned log file, which is removed at the end.
    log_file_loc = new_log_file('dump/test_file')
    for line in ('eee', 'fff'):
        print(line)
    stop_capturing_print()

    local_log_loc = get_local_path(log_file_loc)
    with open(local_log_loc) as log_file:
        logged_text = log_file.read()

    assert logged_text == 'eee\nfff\n'
    os.remove(local_log_loc)
def test_persistent_print():
    # read_print() returns what has been printed since capture_print() started; a fresh capture starts empty.
    test_log_path = capture_print()
    try:
        print('aaa')
        print('bbb')
        assert read_print() == 'aaa\nbbb\n'
    finally:
        stop_capturing_print()

    second_log_path = capture_print()
    try:
        assert read_print() == ''
        print('ccc')
        print('ddd')
        assert read_print() == 'ccc\nddd\n'
    finally:
        # The second capture must be stopped too; otherwise stdout stays redirected for every later test.
        stop_capturing_print()

    os.remove(get_local_path(test_log_path))
    # Also remove the second capture's log file (unless both captures happened to use the same path).
    if second_log_path != test_log_path and os.path.exists(get_local_path(second_log_path)):
        os.remove(get_local_path(second_log_path))
示例#21
0
def capture_print(log_file_path='logs/dump/%T-log.txt', print_to_console=True):
    """
    Start capturing everything written to stdout and stderr into a log file.

    :param log_file_path: Path of the log file.  If it does not start with a "/", it is relative to the data
        directory.  You can use placeholders such as %T, %R, ... in the path name (see format_filename).
    :param print_to_console: Also continue printing to console.
    :return: The absolute path to the log file.
    """
    resolved_path = get_local_path(log_file_path)
    logger = PrintAndStoreLogger(log_file_path=resolved_path, print_to_console=print_to_console)
    logger.__enter__()
    # Route both output streams through the same logger.
    sys.stdout = sys.stderr = logger
    return resolved_path
示例#22
0
def get_all_record_ids(experiment_ids=None, filters=None):
    """
    Collect the identifiers of experiment records (sub-directories of the 'experiments' data folder).

    :param experiment_ids: A list of experiment names, passed to filter_experiment_ids as ``names``.
    :param filters: A list of regular expressions; each one is applied in turn via filter_experiment_ids.
    :return: A sorted list of experiment record identifiers.
    """
    records_dir = get_local_path('experiments')
    candidates = [name for name in os.listdir(records_dir)
                  if os.path.isdir(os.path.join(records_dir, name))]
    selected = filter_experiment_ids(ids=candidates, names=experiment_ids)
    for pattern in (() if filters is None else filters):
        selected = filter_experiment_ids(ids=selected, expr=pattern)
    return sorted(selected)
示例#23
0
def clear_experiment_records(ids=None):
    """
    Delete the records of the given experiments from the 'experiments' data folder.

    Files are unlinked and directories removed recursively.  An error on one entry is printed, and the
    remaining entries are still processed.

    :param ids: A list of experiment ids, or None to remove every entry in the folder.
    """
    # Credit: http://stackoverflow.com/questions/185936/delete-folder-contents-in-python
    folder = get_local_path('experiments')
    targets = os.listdir(folder) if ids is None else ids

    def remove_entry(path):
        if os.path.isfile(path):
            os.unlink(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)

    for exp_id in targets:
        exp_path = os.path.join(folder, exp_id)
        try:
            remove_entry(exp_path)
        except Exception as err:
            print(err)
def test_persistent_ordered_dict():
    # Entries set inside a `with PersistentOrderedDict(...)` block are there (in insertion order) when the file is
    # reopened; entries set after the block has exited are not.
    file_path = get_local_path('tests/podtest.pkl')
    if os.path.exists(file_path):
        os.remove(file_path)

    first_entries = [('a', [1, 2, 3]), ('b', [4, 5, 6]), ('c', [7, 8])]

    with PersistentOrderedDict(file_path) as pod:
        assert pod.items() == []
        for key, value in first_entries:
            pod[key] = value
    pod['d'] = [9, 10]  # Should not be recorded

    with PersistentOrderedDict(file_path) as pod:
        assert pod.items() == first_entries
        pod['e'] = 11

    with PersistentOrderedDict(file_path) as pod:
        assert pod.items() == first_entries + [('e', 11)]
示例#25
0
def get_file_and_cache(url, data_transformation=None,
                       enable_cache_write=True, enable_cache_read=True):
    """
    Fetch the file at ``url`` through a local cache kept in the 'caches' data folder.

    The cache file name is the MD5 hex digest of the url followed by the url's extension.

    :param url: Url of the file.
    :param data_transformation: Optional function applied to the downloaded data before it is stored.
    :param enable_cache_read: Return an existing cached copy, if there is one, without downloading.
    :param enable_cache_write: Otherwise, download into the cache (via get_file).
    :return: The cached file's path, or - if neither applies - the path returned by get_temp_file.
    """
    _, ext = os.path.splitext(url)

    if enable_cache_read or enable_cache_write:
        digest = hashlib.md5(url).hexdigest()
        cache_name = digest + ext
        local_cache_path = os.path.join(get_local_path('caches'), cache_name)

    if enable_cache_read and os.path.exists(local_cache_path):
        return local_cache_path
    if enable_cache_write:
        return get_file(relative_name=os.path.join('caches', cache_name),
                        url=url,
                        data_transformation=data_transformation)
    return get_temp_file(url, data_transformation=data_transformation)
示例#26
0
def send_port_if_running_and_join():
    """
    Attach to an already-running plotting server, if there is one.

    If port.info exists, its pickled port is loaded and printed along with an explanation.  The call then blocks
    (sleeping) until interrupted with Ctrl-C, at which point the process exits via sys.exit().
    Otherwise an empty port.info file is created and the function returns.
    """
    port_file_path = get_local_path("tmp/plot_server/port.info",
                                    make_local_dir=True)
    if os.path.exists(port_file_path):
        # The port is stored with pickle (binary data), so read it back in binary mode.
        with open(port_file_path, 'rb') as f:
            port = pickle.load(f)
        print(port)
        print(
            "Your dbplot call is attached to an existing plotting server. \nAll stdout and stderr of this existing plotting server "
            "is forwarded to the process that first created this plotting server. \nIn the future we might try to hijack this and provide you "
            "with these data streams")
        print(
            "Use with care, this functionallity might have unexpected side issues"
        )
        try:
            while True:
                time.sleep(20)
        except KeyboardInterrupt:
            print(" Redirected Server killed")
            sys.exit()
    else:
        # NOTE(review): creates an empty placeholder file; presumably write_port_to_file fills it in later - confirm.
        with open(port_file_path, "w"):
            pass
示例#27
0
import logging
from artemis.fileman.local_dir import get_local_path, make_file_dir
from artemis.general.test_mode import is_test_mode
from functools import partial
import numpy as np
import pickle
import os
# Configure the root logger with default settings at import time, and give this module's logger INFO level.
logging.basicConfig()
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)

__author__ = 'peter'

# Module-level switches for the on-disk memo cache; presumably read by memoize_to_disk (its body is not
# fully visible here - confirm).
MEMO_WRITE_ENABLED = True
MEMO_READ_ENABLED = True
# 'memoize_to_disk' folder of the data directory, resolved once at import; presumably where memoized results live.
MEMO_DIR = get_local_path('memoize_to_disk')


def memoize_to_disk(fcn,
                    local_cache=False,
                    disable_on_tests=True,
                    use_cpickle=False):
    """
    Save (memoize) computed results to disk, so that the same function, called with the
    same arguments, does not need to be recomputed.  This is useful if you have a long-running
    function that is often being given the same arguments.  Note: this does NOT check for the state
    of Global variables/time/whatever else the function may use, so you need to make sure your
    function is truly a function in that outputs only depend on inputs.  Otherwise, this will
    give you misleading results.

    Usage:
示例#28
0
def get_archive(relative_path,
                url,
                force_extract=False,
                archive_type=None,
                force_download=False):
    """
    Download a compressed archive and extract it into a folder.

    :param relative_path: Local name for the extracted folder.  (The archive file is stored next to it, with the
        archive extension appended to this name.)
    :param url: Url of the archive to download
    :param force_extract: Force the archive to re-extract (rather than just reusing the extracted folder)
    :param archive_type: '.tar.gz', '.zip', or None to infer it: from the url, else from the Content-Disposition
        header of the download, else (when nothing is downloaded) from the archive file already on disk.
    :param force_download: Delete the extracted folder (if present) and download the archive again.
    :return: The full path to the extracted folder on your system.
    :raises Exception: If the archive type cannot be inferred.
    """

    local_folder_path = get_local_path(relative_path)

    assert archive_type in ('.tar.gz', '.zip', None)

    if force_download and os.path.exists(local_folder_path):
        # Guarded: rmtree raises if the folder has never been created.
        shutil.rmtree(local_folder_path)

    if archive_type is None:
        if url.endswith('.tar.gz'):
            archive_type = '.tar.gz'
        elif url.endswith('.zip'):
            archive_type = '.zip'

    if not os.path.exists(local_folder_path) or force_download:  # If the folder does not exist, download and extract.
        # (We also check force download here to avoid a race condition)
        response = urllib2.urlopen(url)
        try:
            if archive_type is None:
                archive_type = _infer_archive_type_from_headers(response.info())
            print('Downloading archive from url: "%s"...' % (url, ))
            data = response.read()
            print('...Done.')
        finally:
            response.close()

        local_zip_path = local_folder_path + archive_type
        make_file_dir(local_zip_path)
        # Binary mode: the archive is raw bytes, which text mode would corrupt on Windows.
        with open(local_zip_path, 'wb') as f:
            f.write(data)

        force_extract = True
    elif force_extract and archive_type is None:
        # Nothing was downloaded in this call, so look for the archive kept from an earlier download.
        archive_type = next((t for t in ('.tar.gz', '.zip') if os.path.exists(local_folder_path + t)), None)
        if archive_type is None:
            raise Exception(
                'Cannot re-extract "%s": no archive "%s.tar.gz" or "%s.zip" was found.  Specify archive_type or use force_download.'
                % (local_folder_path, local_folder_path, local_folder_path))

    if force_extract:
        local_zip_path = local_folder_path + archive_type
        # NOTE: extractall trusts the member paths stored in the archive; only use archives from trusted sources.
        if archive_type == '.tar.gz':
            with tarfile.open(local_zip_path) as f:
                f.extractall(local_folder_path)
        else:  # '.zip' - the only other value allowed by the assert and the inference above
            with ZipFile(local_zip_path) as f:
                f.extractall(local_folder_path)

    return local_folder_path


def _infer_archive_type_from_headers(info):
    """Infer '.tar.gz' or '.zip' from the filename in a urllib2 response's Content-Disposition header."""
    try:
        # Header names are case-insensitive; raw header lines may carry a trailing CRLF.
        header = next(x for x in info.headers
                      if x.lower().startswith('content-disposition'))
        original_file_name = next(
            x for x in (part.strip() for part in header.split(';'))
            if x.startswith('filename')).split('=')[-1].strip().strip('"\'')
    except StopIteration:
        raise Exception(
            "Could not infer archive type from user argument, url-name, or file-header.  Please specify archive type as either '.zip' or '.tar.gz'."
        )
    if original_file_name.endswith('.tar.gz'):
        return '.tar.gz'
    if original_file_name.endswith('.zip'):
        return '.zip'
    return bad_value(original_file_name, 'Filename "%s" does not end with a familiar zip extension like .zip or .tar.gz' % (original_file_name, ))
示例#29
0
def demo_optimize_conv_scales(n_epochs=5,
                              comp_weight=1e-11,
                              learning_rate=0.1,
                              error_loss='KL',
                              use_softmax=True,
                              optimizer='sgd',
                              shuffle_training=False):
    """
    Run the scale optimization routine on a convnet.

    The scales of a VGG network are trained on frames from two training videos.  Every 0.1 epoch, the Round and
    SigmaDelta forward passes are evaluated on two separate test videos.  Their computational cost and their
    agreement with the unmodified VGG network are plotted, and the figure is saved as
    epoch-<epoch>.pdf in get_local_path('output/%T-convnet-spikes').

    :param n_epochs: Number of passes over the training frames.
    :param comp_weight: Weight of the computational-cost term when training the scales.
    :param learning_rate: Learning rate given to the optimizer.
    :param error_loss: Error loss used when training the scales (e.g. 'KL', which requires use_softmax).
    :param use_softmax: If True, use VGG up to the 'prob' layer; otherwise up to 'fc8'.
    :param optimizer: Optimizer name, resolved with get_named_optimizer.
    :param shuffle_training: Shuffle the training frames (with fixed seed 1234).
    :return: A dict with the scales, costs, outputs, guesses and per-frame top-1/top-5 correctness of the Round
        and SigmaDelta networks at the last test point.
    """
    if error_loss == 'KL' and not use_softmax:
        raise Exception(
            "It's very strange that you want to use a KL divergence on something other than a softmax error.  I assume you've made a mistake."
        )

    training_videos, training_vgg_inputs = get_vgg_video_splice(
        ['ILSVRC2015_train_00033010', 'ILSVRC2015_train_00336001'],
        shuffle=shuffle_training,
        shuffling_rng=1234)
    test_videos, test_vgg_inputs = get_vgg_video_splice(
        ['ILSVRC2015_train_00033009', 'ILSVRC2015_train_00033007'])

    set_dbplot_figure_size(12, 6)

    # Evenly spaced test frames to display.  Floor division keeps these integer indices: identical to Python 2's
    # int '/', and still valid where '/' is true division.
    n_frames_to_show = 10
    display_frames = np.arange(
        len(test_videos) // n_frames_to_show // 2, len(test_videos),
        len(test_videos) // n_frames_to_show)
    ax1 = dbplot(np.concatenate(test_videos[display_frames], axis=1),
                 "Test Videos",
                 title='',
                 plot_type='pic')
    plt.subplots_adjust(wspace=0, hspace=.05)
    # One tick centred under every second displayed frame (frames are 224 px wide).
    ax1.set_xticks(224 * np.arange(len(display_frames) // 2) * 2 + 224 // 2)
    ax1.tick_params(labelbottom=True)  # a boolean; the 'on' string form is deprecated in matplotlib

    layers = get_vgg_layer_specifiers(
        up_to_layer='prob' if use_softmax else 'fc8')

    # Setup the true VGGnet and get the outputs
    f_true = ConvNet.from_init(layers, input_shape=(3, 224, 224)).compile()
    true_test_out = flatten2(
        np.concatenate([
            f_true(frame_positions[None])
            for frame_positions in test_vgg_inputs
        ]))
    top5_true_guesses = argtopk(true_test_out, 5)
    true_guesses = np.argmax(true_test_out, axis=1)
    true_labels = [
        get_vgg_label_at(g, short=True)
        for g in true_guesses[display_frames[::2]]
    ]
    full_convnet_cost = np.array([
        get_full_convnet_computational_cost(layer_specs=layers,
                                            input_shape=(3, 224, 224))
    ] * len(test_videos))

    # Setup the approximate networks
    slrc_net = ScaleLearningRoundingConvnet.from_convnet_specs(
        layers,
        optimizer=get_named_optimizer(optimizer, learning_rate=learning_rate),
        corruption_type='rand',
        rng=1234)
    f_train_slrc = slrc_net.train_scales.partial(
        comp_weight=comp_weight, error_loss=error_loss).compile()
    f_get_scales = slrc_net.get_scales.compile()
    round_fp = RoundConvNetForwardPass(layers)
    sigmadelta_fp = SigmaDeltaConvNetForwardPass(layers,
                                                 input_shape=(3, 224, 224))

    p = ProgressIndicator(n_epochs * len(training_videos))

    output_dir = make_dir(get_local_path('output/%T-convnet-spikes'))

    for input_minibatch, minibatch_info in minibatch_iterate_info(
            training_vgg_inputs,
            n_epochs=n_epochs,
            minibatch_size=1,
            test_epochs=np.arange(0, n_epochs, 0.1)):

        if minibatch_info.test_now:
            with EZProfiler('test'):
                current_scales = f_get_scales()
                round_cost, round_out = round_fp.get_cost_and_output(
                    test_vgg_inputs, scales=current_scales)
                sd_cost, sd_out = sigmadelta_fp.get_cost_and_output(
                    test_vgg_inputs, scales=current_scales)
                round_guesses, round_top1_correct, round_top5_correct = get_and_report_scores(
                    round_cost,
                    round_out,
                    name='Round',
                    true_top_1=true_guesses,
                    true_top_k=top5_true_guesses)
                sd_guesses, sd_top1_correct, sd_top5_correct = get_and_report_scores(
                    sd_cost,
                    sd_out,
                    name='SigmaDelta',
                    true_top_1=true_guesses,
                    true_top_k=top5_true_guesses)

                round_labels = [
                    get_vgg_label_at(g, short=True)
                    for g in round_guesses[display_frames[::2]]
                ]

                ax1.set_xticklabels([
                    '{}\n{}'.format(tg, rg)
                    for tg, rg in izip_equal(true_labels, round_labels)
                ])

                ax = dbplot(
                    np.array([
                        round_cost / 1e9, sd_cost / 1e9,
                        full_convnet_cost / 1e9
                    ]).T,
                    'Computation',
                    plot_type='thick-line',
                    ylabel='GOps',
                    title='',
                    legend=['Round', r'$\Sigma\Delta$', 'Original'],
                )
                ax.set_xticklabels([])
                plt.grid()
                dbplot(
                    100 * np.array(
                        [cummean(sd_top1_correct),
                         cummean(sd_top5_correct)]).T,
                    "Score",
                    plot_type=lambda: LinePlot(
                        y_bounds=(0, 100),
                        plot_kwargs=[
                            dict(linewidth=3, color='k'),
                            dict(linewidth=3, color='k', linestyle=':')
                        ]),
                    title='',
                    legend=[
                        r'Round/$\Sigma\Delta$ Top-1',
                        r'Round/$\Sigma\Delta$ Top-5'
                    ],
                    ylabel='Cumulative\nPercent Accuracy',
                    xlabel='Frame #',
                    layout='v',
                )
                plt.grid()
            plt.savefig(
                os.path.join(output_dir,
                             'epoch-%.3g.pdf' % (minibatch_info.epoch, )))
        f_train_slrc(input_minibatch)
        p()
        print("Epoch {:3.2f}: Scales: {}".format(
            minibatch_info.epoch, ['%.3g' % float(s) for s in f_get_scales()]))

    # test_epochs starts at 0, so the test branch has run and these names are bound.
    results = dict(current_scales=current_scales,
                   round_cost=round_cost,
                   round_out=round_out,
                   sd_cost=sd_cost,
                   sd_out=sd_out,
                   round_guesses=round_guesses,
                   round_top1_correct=round_top1_correct,
                   round_top5_correct=round_top5_correct,
                   sd_guesses=sd_guesses,
                   sd_top1_correct=sd_top1_correct,
                   sd_top5_correct=sd_top5_correct)

    dbplot_hang()
    return results
示例#30
0
def remove_port_file():
    """Delete the plotting server's port.info file if it exists (write_port_to_file registers this with atexit)."""
    print("Removing port file")
    port_file_path = get_local_path("tmp/plot_server/port.info", make_local_dir=True)
    if not os.path.exists(port_file_path):
        return
    os.remove(port_file_path)