예제 #1
0
파일: train.py 프로젝트: ishihara-y/nnabla
def _save_parameters(args, suffix, epoch, force=False):
    """Save the current parameters as an ``.nnp`` archive in ``args.outdir``.

    Saves are throttled per ``suffix``. Unless ``force`` is true, an archive
    is written only when the target file does not exist yet and either more
    than 180 seconds or more than 10 epochs have passed since the last save
    recorded for that ``suffix``.

    The archive holds ``nnp_version.txt``, the file at
    ``_save_parameter_info['config']`` (stored under its basename) and
    ``parameter.protobuf``. Existing ``results_<suffix>_*.nnp`` files are
    removed before the new archive is written. After a successful save,
    ``callback.save_train_snapshot()`` is called.

    The intermediate version/parameter files are always removed, even when a
    step fails. An archive that was opened but could not be completed is
    deleted. Otherwise the truncated file would make later non-forced calls
    skip saving, because they require ``filename`` not to exist.

    Args:
        args: Parsed options; only ``args.outdir`` is read here.
        suffix: Throttling key and file-name tag. When
            ``callback.result_base`` returns ``None``, a ``suffix`` of
            ``None`` or ``'best'`` is saved as ``results.nnp``.
        epoch: Current epoch, used in the default file name and throttling.
        force: Save regardless of the throttling conditions.
    """
    # The module-level dict is only mutated here, never rebound.
    info = _save_parameter_info.setdefault(suffix, {'epoch': 0, 'time': 0})

    current_time = time.time()
    timediff = current_time - info['time']
    epochdiff = epoch - info['epoch']

    globname = os.path.join(args.outdir, 'results_{}_*.nnp'.format(suffix))
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is not None:
        base = base_candidate
    elif suffix is None or suffix == 'best':
        base = os.path.join(args.outdir, 'results')

    filename = base + '.nnp'

    if not (force or (not os.path.exists(filename) and
                      (timediff > 180.0 or epochdiff > 10))):
        return

    # Remove existing nnp before saving new file.
    for exist in exists:
        os.unlink(exist)

    version_filename = base + '_version.txt'
    param_filename = base + '_param.protobuf'
    try:
        with open(version_filename, 'w') as version_file:
            version_file.write('{}\n'.format(nnp_version()))

        save_parameters(param_filename)

        # Mode 'w' truncates any previous archive at ``filename``, so from
        # here on a failure would otherwise leave a broken file behind.
        nnp = zipfile.ZipFile(filename, 'w')
        try:
            with nnp:
                nnp.write(version_filename, 'nnp_version.txt')
                nnp.write(_save_parameter_info['config'],
                          os.path.basename(_save_parameter_info['config']))
                nnp.write(param_filename, 'parameter.protobuf')
        except BaseException:
            if os.path.exists(filename):
                os.unlink(filename)
            raise
    finally:
        # Intermediate files are only needed while building the archive.
        for tmp_filename in (version_filename, param_filename):
            if os.path.exists(tmp_filename):
                os.unlink(tmp_filename)

    info['epoch'] = epoch
    info['time'] = current_time

    callback.save_train_snapshot()
예제 #2
0
def _save_parameters(args, suffix, epoch, train_config, force=False):
    """Save parameters (and periodically optimizer states) as an ``.nnp`` file.

    Saves are throttled per ``suffix``. Unless ``force`` is true, an archive
    is written only when the target file does not exist yet and either more
    than 180 seconds or more than 10 epochs have passed since the last save
    recorded for that ``suffix``.

    The archive holds:

    - ``nnp_version.txt``.
    - The file at ``_save_parameter_info['config']``, stored under its
      basename.
    - ``parameter.h5``.
    - When ``train_config.optimizers`` is truthy and ``epoch`` is a multiple
      of ``_OPTIMIZER_CHECKPOINT_INTERVAL``, the files returned by
      ``save_optimizer_states``.

    Existing ``results_<suffix>_*.nnp`` files are removed before the new
    archive is written. After a successful save,
    ``callback.save_train_snapshot()`` is called.

    All intermediate files are removed, even when a step fails. An archive
    that was opened but could not be completed is deleted. Otherwise the
    truncated file would make later non-forced calls skip saving, because
    they require ``filename`` not to exist.

    Args:
        args: Parsed options; only ``args.outdir`` is read here.
        suffix: Throttling key and file-name tag. When
            ``callback.result_base`` returns ``None``, a ``suffix`` of
            ``None`` or ``'best'`` is saved as ``results.nnp``.
        epoch: Current epoch, used in the default file name, throttling and
            the optimizer checkpoint interval.
        train_config: Training configuration; ``train_config.optimizers`` is
            read here and the object is passed to ``save_optimizer_states``.
        force: Save regardless of the throttling conditions.
    """
    # The module-level dict is only mutated here, never rebound.
    info = _save_parameter_info.setdefault(suffix, {'epoch': 0, 'time': 0})

    current_time = time.time()
    timediff = current_time - info['time']
    epochdiff = epoch - info['epoch']

    globname = os.path.join(args.outdir, 'results_{}_*.nnp'.format(suffix))
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is not None:
        base = base_candidate
    elif suffix is None or suffix == 'best':
        base = os.path.join(args.outdir, 'results')

    filename = base + '.nnp'

    if not (force or (not os.path.exists(filename) and
                      (timediff > 180.0 or epochdiff > 10))):
        return

    # Remove existing nnp before saving new file.
    for exist in exists:
        os.unlink(exist)

    version_filename = base + '_version.txt'
    param_filename = base + '_param.h5'
    opti_filenames = []
    try:
        with open(version_filename, 'w') as version_file:
            version_file.write('{}\n'.format(nnp_version()))

        save_parameters(param_filename)

        if (train_config.optimizers and
                epoch % _OPTIMIZER_CHECKPOINT_INTERVAL == 0):
            # Materialize once: the names are used both for the archive and
            # for cleanup.
            opti_filenames = list(
                save_optimizer_states(base, '.h5', train_config))

        # Mode 'w' truncates any previous archive at ``filename``, so from
        # here on a failure would otherwise leave a broken file behind.
        nnp = zipfile.ZipFile(filename, 'w')
        try:
            with nnp:
                nnp.write(version_filename, 'nnp_version.txt')
                nnp.write(_save_parameter_info['config'],
                          os.path.basename(_save_parameter_info['config']))
                nnp.write(param_filename, 'parameter.h5')
                for opti_filename in opti_filenames:
                    # Archive name drops ``base`` plus one more character,
                    # presumably a separator added by save_optimizer_states
                    # (TODO confirm).
                    nnp.write(opti_filename, opti_filename[len(base) + 1:])
        except BaseException:
            if os.path.exists(filename):
                os.unlink(filename)
            raise
    finally:
        # Intermediate files are only needed while building the archive.
        for tmp_filename in [version_filename, param_filename] + opti_filenames:
            if os.path.exists(tmp_filename):
                os.unlink(tmp_filename)

    info['epoch'] = epoch
    info['time'] = current_time

    callback.save_train_snapshot()