Esempio n. 1
0
def _save_parameters(tmp_nnp_file, train_config, nntxt_str):
    """Pack version info, network, parameters and optimizer states into an NNP zip.

    Intermediate files are written with the prefix
    ``<dirname(tmp_nnp_file)>/results``, added to the zip archive created at
    ``tmp_nnp_file`` and removed afterwards. They are also removed when an
    error occurs part-way through, so no temporary file is left behind.

    Args:
        tmp_nnp_file: Path of the ``.nnp`` (zip) archive to create; an
            existing file at this path is overwritten (zip mode ``'w'``).
        train_config: Training configuration forwarded to
            ``save_optimizer_states``.
        nntxt_str: Network definition text stored in the archive as
            ``network.nntxt``.
    """
    base = os.path.join(os.path.dirname(tmp_nnp_file), 'results')

    version_filename = base + '_version.txt'
    nntxt_filename = base + '_network.nntxt'
    param_filename = base + '_param.h5'

    # Every intermediate file this function may create. All of them are
    # removed in ``finally``. Previously the nntxt file was never deleted,
    # and nothing was deleted if an exception interrupted the save.
    temp_files = [version_filename, nntxt_filename, param_filename]
    try:
        with open(version_filename, 'w') as version_file:
            version_file.write('{}\n'.format(nnp_version()))

        # This is for testing, start=>
        with open(nntxt_filename, "w") as nntxt_file:
            nntxt_file.write(nntxt_str)
        # <= End.

        nn.parameter.save_parameters(param_filename)

        opti_filenames = save_optimizer_states(base, '.h5', train_config)
        temp_files.extend(opti_filenames)

        with zipfile.ZipFile(tmp_nnp_file, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(nntxt_filename, "network.nntxt")
            nnp.write(param_filename, 'parameter.h5')
            for opti_filename in opti_filenames:
                # Archive name: strip the ``base`` prefix plus the single
                # character that follows it.
                nnp.write(opti_filename, opti_filename[len(base) + 1:])
    finally:
        for path in temp_files:
            # A step may have failed before this file was created.
            if os.path.exists(path):
                os.unlink(path)
    logger.info("{} is saved.".format(tmp_nnp_file))
Esempio n. 2
0
def _save_parameters(args, suffix, epoch, train_config, force=False):
    """Write a throttled ``.nnp`` checkpoint for ``suffix`` into ``args.outdir``.

    A save happens when ``force`` is true, or when both of these hold:

    * the target archive does not exist yet;
    * more than 180 seconds or more than 10 epochs have passed since the last
      save recorded for this ``suffix``.

    The archive path is ``<base>.nnp``. By default ``base`` is
    ``<outdir>/results_<suffix>_<epoch>``, and ``callback.result_base`` may
    override it. When that call returns ``None`` and ``suffix`` is ``None``
    or ``'best'``, ``<outdir>/results`` is used instead.

    Before writing, every existing ``<outdir>/results_<suffix>_*.nnp`` is
    deleted. The archive receives:

    * ``nnp_version.txt``;
    * the file named by ``_save_parameter_info['config']``, stored under its
      basename;
    * ``parameter.h5``;
    * the optimizer state files, only when ``train_config.optimizers`` is
      truthy and ``epoch`` is a multiple of ``_OPTIMIZER_CHECKPOINT_INTERVAL``.

    Intermediate files are removed afterwards, even if an error occurs
    part-way through. After a successful save, the per-suffix bookkeeping is
    updated and ``callback.save_train_snapshot()`` is called.

    Args:
        args: Object providing ``outdir``; only that attribute is read here.
        suffix: Checkpoint name suffix; also the key of this checkpoint's
            bookkeeping entry in ``_save_parameter_info``.
        epoch: Current epoch number.
        train_config: Training configuration. Its ``optimizers`` attribute is
            read, and it is forwarded to ``save_optimizer_states``.
        force: Save regardless of throttling and of whether the archive
            already exists.
    """
    global _save_parameter_info

    # NOTE(review): ``_save_parameter_info`` also holds a ``'config'`` key, so
    # a suffix equal to ``'config'`` would collide -- confirm it never occurs.
    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {'epoch': 0, 'time': 0}
    last_save = _save_parameter_info[suffix]

    current_time = time.time()
    timediff = current_time - last_save['time']
    epochdiff = epoch - last_save['epoch']

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is None:
        if suffix is None or suffix == 'best':
            base = os.path.join(args.outdir, 'results')
    else:
        base = base_candidate

    filename = base + '.nnp'

    should_save = force or (not os.path.exists(filename) and
                            (timediff > 180.0 or epochdiff > 10))
    if not should_save:
        return

    # Remove existing nnp before saving new file. The glob is evaluated only
    # when a save actually happens.
    globname = os.path.join(args.outdir, 'results_{}_*.nnp'.format(suffix))
    for exist in glob.glob(globname):
        os.unlink(exist)

    version_filename = base + '_version.txt'
    param_filename = base + '_param.h5'

    # Intermediate files to delete once the archive is written, including
    # on failure, so a crashed save does not leave them behind.
    temp_files = [version_filename, param_filename]
    try:
        with open(version_filename, 'w') as version_file:
            version_file.write('{}\n'.format(nnp_version()))

        save_parameters(param_filename)

        opti_filenames = []
        need_save_opti = (train_config.optimizers and
                          epoch % _OPTIMIZER_CHECKPOINT_INTERVAL == 0)
        if need_save_opti:
            opti_filenames = save_optimizer_states(base, '.h5', train_config)
            temp_files.extend(opti_filenames)

        config_filename = _save_parameter_info['config']
        with zipfile.ZipFile(filename, 'w') as nnp:
            nnp.write(version_filename, 'nnp_version.txt')
            nnp.write(config_filename, os.path.basename(config_filename))
            nnp.write(param_filename, 'parameter.h5')
            for opti_filename in opti_filenames:
                # Archive name: strip the ``base`` prefix plus the single
                # character that follows it.
                nnp.write(opti_filename, opti_filename[len(base) + 1:])
    finally:
        for path in temp_files:
            # A step may have failed before this file was created.
            if os.path.exists(path):
                os.unlink(path)

    last_save['epoch'] = epoch
    last_save['time'] = current_time

    callback.save_train_snapshot()