def load(self):
    """Restore best/final statistics and the best-model id from the pickle file.

    Reads self._statistics_path (written by the matching save()) and
    populates the corresponding instance attributes.
    """
    print_debug('Loading statistics from ' + self._statistics_path)
    with open(self._statistics_path, 'rb') as stats_file:
        data = pickle.load(stats_file)
    self._best_statistics = data['best_statistics']
    self._final_statistics = data['final_statistics']
    self._best_model_id = data['best_model_id']
# Example #2
def save_checkpoint(model,
                    optimizer=None,
                    model_name='model',
                    validation_id=None):
    """
    Save a checkpoint (model weights and, optionally, optimizer state).

    :param model: model whose state_dict is saved; if wrapped in
                  torch.nn.DataParallel, the underlying module is unwrapped first
    :param optimizer: optional optimizer whose state_dict is saved alongside
    :param model_name: name used to build the checkpoint file path
    :param validation_id: optional validation identifier forwarded to output_path
    :return: None
    """
    path = output_path(_checkpoint_path.format(model_name),
                       validation_id=validation_id,
                       have_validation=True)

    print_debug('Saving checkpoint: ' + path)

    # isinstance (rather than `type(...) is`) also unwraps subclasses of DataParallel
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    checkpoint = {'model_state_dict': model.state_dict()}

    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, path)
# Example #3
def get_index(index_path):
    """
    Load a label index from disk, caching the most recently loaded one.

    :param index_path: path of the JSON index file
    :return: the index as a dict mapping int -> int, or None if the file is missing
    """
    global previous_index_path
    global previous_indexed_labels

    # serve from the module-level cache when the same index was loaded last time
    if 'previous_index_path' in globals(
    ) and index_path == previous_index_path:
        print_debug('Labels index in cache')
        return previous_indexed_labels

    # labels have not been indexed yet (or the file was removed)
    if not os.path.isfile(index_path):
        print_errors('index ' + index_path + ' does not exist...')
        return None

    print_debug('Loading labels index ' + index_path)
    with open(index_path) as index_file:
        raw_index = json.load(index_file)
    # JSON keys are strings: coerce both keys and values back to int
    indexed_labels = {int(k): int(v) for k, v in raw_index.items()}
    previous_index_path = index_path
    previous_indexed_labels = indexed_labels
    return indexed_labels
# Example #4
def save_loss(losses, ylabel='Loss'):
    """
    Plot every loss curve on the shared 'loss' figure and export each
    curve as a JSON log file.

    :param losses: dict mapping curve name -> (list of values, logging frequency)
    :param ylabel: label of the y axis
    :return: None (returns early when the smallest frequency is 0)
    """
    # the smallest logging frequency defines the x-axis unit
    min_freq = min(losses.values(), key=lambda x: x[1])[1]
    if min_freq == 0:
        return
    plt('loss').title('Losses curve')
    plt('loss').xlabel('x' + str(min_freq) + ' batches')
    plt('loss').ylabel(ylabel)

    for k in losses:
        values, freq = losses[k][0], losses[k][1]
        # hoist the loop-invariant stride; offset aligns the multiple losses
        step = freq // min_freq
        offset = step - 1
        plt('loss').plot(
            list(range(offset, len(values) * step + offset, step)),
            values,
            label=k)

        path = output_path('loss_{}.logs'.format(k))
        print_debug('Exporting loss at ' + path)
        # context manager guarantees the file is closed even if write fails
        with open(path, "w") as f:
            f.write(json.dumps(values))
    plt('loss').legend()
    save_fig_direct_call(figure_name='loss')
 def save(self):
     """Persist best/final statistics and the best-model id as a pickle file.

     Counterpart of load(); writes to self._statistics_path.
     """
     payload = {
         'best_statistics': self._best_statistics,
         'best_model_id': self._best_model_id,
         'final_statistics': self._final_statistics
     }
     print_debug('Saving statistics at ' + self._statistics_path)
     with open(self._statistics_path, 'wb') as stats_file:
         pickle.dump(payload, stats_file)
# Example #6
def _load_checkpoint(model_name, path=None):
    """Load a checkpoint for *model_name* into the module-level _checkpoint cache.

    :param model_name: key under which the checkpoint is cached
    :param path: explicit checkpoint path; when None it is derived from model_name
    """
    global _checkpoint

    if path is None:
        path = output_path(_checkpoint_path.format(model_name),
                           have_validation=True)

    if not os.path.isfile(path):
        # do_exit=True: print_errors terminates the process here
        print_errors('{} does not exist'.format(path), do_exit=True)
    print_debug('Loading checkpoint from ' + path)
    _checkpoint[model_name] = torch.load(path)
# Example #7
def load_loss(name):
    """Read a previously exported loss log.

    :param name: base name of the log file (without the '.logs' suffix)
    :return: the decoded loss list, or [] when the file does not exist
    """
    path = output_path(name + '.logs')
    print_debug('Loading loss at ' + path)
    if not os.path.exists(path):
        print_debug(path + ' does not exist...')
        return []
    with open(path) as logs_file:
        return json.load(logs_file)
# Example #8
def save_reversed_index(path, index, column=0):
    """
    Reverse the given index and save it on disk for future use.

    :param path: destination file path
    :param index: the index to reverse and serialize
    :param column: column passed through to reverse_indexing
    :return: None
    """
    print_debug('Saving index at ' + path)
    reversed_index = reverse_indexing(index, column)
    # context manager ensures the file handle is closed even on write error
    with open(path, "w") as f:
        f.write(json.dumps(reversed_index))
# Example #9
def configure_engine():
    """
    Parse command-line arguments and initialise all global engine state:
    verbosity, paths, hardware devices, experiment naming/loading and an
    atexit handler that reports the total run duration.

    Intended to be called once at program start; relies on sys.argv[0]
    to derive the setup name and project path.
    """
    import getpass
    import time

    from engine.hardware import set_devices
    from engine.parameters import special_parameters
    # from engine.tensorboard import initialize_tensorboard
    from engine.util.clean import clean
    # NOTE(review): print_info is imported three times on the next line — redundant but harmless
    from engine.logging.logs import print_h1, print_info, print_durations, print_info, print_errors, print_info
    from engine.util.console.time import get_start_datetime
    from engine.util.console.welcome import print_welcome_message, print_goodbye
    from engine.parameters.ds_argparse import get_argparse, check_general_config, process_other_options
    import sys

    import atexit

    import os

    from engine.parameters import hyper_parameters as hp

    args = get_argparse()

    # def warn(*a, **k):
    #     pass
    import warnings
    # warnings.warn = warn
    # re-enable DeprecationWarning (Python hides it by default outside __main__)
    warnings.filterwarnings('default', category=DeprecationWarning)

    process_other_options(args.more)
    # general setup
    set_verbose(args.verbose)
    special_parameters.plt_style = args.style
    special_parameters.homex = args.homex

    # derive setup name and project path from the invoked script's location
    special_parameters.setup_name = os.path.split(
        os.path.split(sys.argv[0])[0])[1]
    special_parameters.project_path = os.path.split(
        os.path.split(sys.argv[0])[0])[0]

    # ask_default is truthy when the general config needs interactive confirmation
    ask_default = check_general_config(args)

    special_parameters.configure(args)

    if special_parameters.machine == 'auto':
        detect_machine()

    special_parameters.interactive_cluster = check_interactive_cluster(
        special_parameters.machine)

    if ask_default:
        ask_general_config_default(args)

    special_parameters.root_path = os.path.abspath(os.curdir)

    configure_homex()

    # maintenance modes: clean/show outputs or list aliases, then exit early
    if args.clean or args.show:
        clean(special_parameters.homex, args.output_name, disp_only=args.show)
        exit()
    if args.list_aliases:
        list_aliases()
        exit()

    # hardware
    set_devices(args.gpu)
    special_parameters.nb_workers = args.nb_workers
    # special_parameters.nb_nodes = args.nb_nodes

    special_parameters.first_epoch = args.epoch - 1  # to be user friendly, we start at 1
    special_parameters.validation_id = args.validation_id

    config_name = hp.check_config(args)
    hp.check_parameters(args)

    # default experiment name: the script file name without its .py extension
    default_name = os.path.split(sys.argv[0])[-1].replace('.py', '')

    # '*' means "auto-name": use the script name, optionally suffixed by the config name
    if args.output_name == '*':
        if config_name is None:
            special_parameters.output_name = default_name
        else:
            c_name = config_name.split('/')[-1]
            special_parameters.output_name = default_name + '_' + c_name
    else:
        special_parameters.output_name = args.output_name

    def exit_handler():
        # flush/close tensorboard and report total duration at interpreter exit
        if special_parameters.tensorboard_writer is not None:
            special_parameters.tensorboard_writer.close()

        print_durations(time.time() - start_dt.timestamp(),
                        text='Total duration')

        if args.serious:
            print_h1("Goodbye")
        else:
            print_goodbye()

    print_h1('Hello ' + getpass.getuser() + '!')

    if not args.serious and is_warning():
        print_welcome_message()

    start_dt = get_start_datetime()

    print_info('Starting datetime: ' + start_dt.strftime('%Y-%m-%d %H:%M:%S'))

    # configuring experiment
    # load a previous model when resuming (epoch != 1), evaluating/exporting
    # without training, restarting, or when explicitly requested
    special_parameters.load_model = (args.epoch != 1
                                     or (args.eval and not args.train)
                                     or (args.export and not args.train)
                                     or args.restart or args.load_model)

    if special_parameters.load_model:
        _, special_parameters.experiment_name = last_experiment(
            special_parameters.output_name)
        # special_parameters.output_name = name
        if special_parameters.experiment_name is None:
            print_errors('No previous experiment named ' +
                         special_parameters.output_name,
                         do_exit=True)

        if args.restart:
            special_parameters.first_epoch = load_last_epoch()
            special_parameters.restart_experiment = True
            print_debug('Restarting experiment at last epoch: {}'.format(
                special_parameters.first_epoch))
    else:
        # fresh experiment: make the name unique with a start timestamp
        special_parameters.experiment_name = special_parameters.output_name + '_' + start_dt.strftime(
            '%Y%m%d%H%M%S')

    if not is_info():
        print('Output directory: ' + output_directory() + '\n')

    # tensorboard
    # if special_parameters.tensorboard:
    # initialize_tensorboard()

    export_config()
    # register last so the handler only runs after configuration succeeded
    atexit.register(exit_handler)