def _save_parameters(args, suffix, epoch, force=False):
    """Write the current parameters to an ``.nnp`` archive, rate-limited per suffix.

    An archive is written when ``force`` is true, or when the target file
    does not exist yet and either more than 180 seconds or more than 10
    epochs have passed since the last save recorded for ``suffix``.

    The archive holds three members: ``nnp_version.txt``, the file named by
    ``_save_parameter_info['config']`` (stored under its basename), and
    ``parameter.protobuf``.

    Args:
        args: Object exposing ``outdir``, the directory results go to.
        suffix: Label used in the file name and as the key for the
            per-suffix bookkeeping in ``_save_parameter_info``.
        epoch: Current epoch; part of the default file name and of the
            rate-limit check.
        force: When true, save regardless of the rate limit and of whether
            the target file already exists.
    """
    global _save_parameter_info

    # First save for this suffix: start the bookkeeping from zero so that the
    # time-based condition below is satisfied on the first call.
    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {}
        _save_parameter_info[suffix]['epoch'] = 0
        _save_parameter_info[suffix]['time'] = 0

    current_time = time.time()
    timediff = current_time - _save_parameter_info[suffix]['time']
    epochdiff = epoch - _save_parameter_info[suffix]['epoch']

    # Archives from earlier epochs with the same suffix; collected before the
    # new file is written.
    globname = os.path.join(args.outdir, 'results_{}_*.nnp'.format(suffix))
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    # NOTE(review): callback.result_base is defined elsewhere; a non-None
    # return value is used verbatim as the base path.
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is not None:
        base = base_candidate
    elif suffix is None or suffix == 'best':
        base = os.path.join(args.outdir, 'results')

    filename = base + '.nnp'

    if force or (not os.path.exists(filename) and (timediff > 180.0 or epochdiff > 10)):

        # Remove existing nnp before saving new file.
        for exist in exists:
            os.unlink(exist)

        version_filename = base + '_version.txt'
        param_filename = base + '_param.protobuf'

        try:
            with open(version_filename, 'w') as version_file:
                version_file.write('{}\n'.format(nnp_version()))

            save_parameters(param_filename)

            with zipfile.ZipFile(filename, 'w') as nnp:
                nnp.write(version_filename, 'nnp_version.txt')
                nnp.write(_save_parameter_info['config'], os.path.basename(
                    _save_parameter_info['config']))
                nnp.write(param_filename, 'parameter.protobuf')
        finally:
            # The intermediate files only exist to be packed into the
            # archive. Remove them even when a step above raised, so a failed
            # save does not leave stray files in the output directory.
            for temp_path in (version_filename, param_filename):
                if os.path.exists(temp_path):
                    os.unlink(temp_path)

        # Only reached when the archive was written successfully.
        _save_parameter_info[suffix]['epoch'] = epoch
        _save_parameter_info[suffix]['time'] = current_time

        callback.save_train_snapshot()
def _save_parameters(args, suffix, epoch, train_config, force=False):
    """Write parameters (and optionally optimizer states) to an ``.nnp`` archive.

    An archive is written when ``force`` is true, or when the target file
    does not exist yet and either more than 180 seconds or more than 10
    epochs have passed since the last save recorded for ``suffix``.

    The archive holds ``nnp_version.txt``, the file named by
    ``_save_parameter_info['config']`` (stored under its basename) and
    ``parameter.h5``. When ``train_config.optimizers`` is truthy and
    ``epoch`` is a multiple of ``_OPTIMIZER_CHECKPOINT_INTERVAL``, the files
    returned by ``save_optimizer_states`` are added as well.

    Args:
        args: Object exposing ``outdir``, the directory results go to.
        suffix: Label used in the file name and as the key for the
            per-suffix bookkeeping in ``_save_parameter_info``.
        epoch: Current epoch; part of the default file name, of the
            rate-limit check and of the optimizer-checkpoint check.
        train_config: Training configuration; its ``optimizers`` attribute
            is read here and the object is passed to
            ``save_optimizer_states``.
        force: When true, save regardless of the rate limit and of whether
            the target file already exists.
    """
    global _save_parameter_info

    # First save for this suffix: start the bookkeeping from zero so that the
    # time-based condition below is satisfied on the first call.
    if suffix not in _save_parameter_info:
        _save_parameter_info[suffix] = {}
        _save_parameter_info[suffix]['epoch'] = 0
        _save_parameter_info[suffix]['time'] = 0

    current_time = time.time()
    timediff = current_time - _save_parameter_info[suffix]['time']
    epochdiff = epoch - _save_parameter_info[suffix]['epoch']

    # Archives from earlier epochs with the same suffix; collected before the
    # new file is written.
    globname = os.path.join(args.outdir, 'results_{}_*.nnp'.format(suffix))
    exists = glob.glob(globname)

    base = os.path.join(args.outdir, 'results_{}_{}'.format(suffix, epoch))
    # NOTE(review): callback.result_base is defined elsewhere; a non-None
    # return value is used verbatim as the base path.
    base_candidate = callback.result_base(base, suffix, args.outdir)
    if base_candidate is not None:
        base = base_candidate
    elif suffix is None or suffix == 'best':
        base = os.path.join(args.outdir, 'results')

    filename = base + '.nnp'

    if force or (not os.path.exists(filename) and (timediff > 180.0 or epochdiff > 10)):

        # Remove existing nnp before saving new file.
        for exist in exists:
            os.unlink(exist)

        version_filename = base + '_version.txt'
        param_filename = base + '_param.h5'

        # Every intermediate file that must be gone once the archive is built.
        temp_paths = [version_filename, param_filename]
        try:
            with open(version_filename, 'w') as version_file:
                version_file.write('{}\n'.format(nnp_version()))

            save_parameters(param_filename)

            # Stays empty unless this epoch is an optimizer checkpoint, which
            # makes the later loops no-ops without repeating the condition.
            opti_filenames = []
            if train_config.optimizers and epoch % _OPTIMIZER_CHECKPOINT_INTERVAL == 0:
                opti_filenames = list(
                    save_optimizer_states(base, '.h5', train_config))
                temp_paths.extend(opti_filenames)

            with zipfile.ZipFile(filename, 'w') as nnp:
                nnp.write(version_filename, 'nnp_version.txt')
                nnp.write(_save_parameter_info['config'], os.path.basename(
                    _save_parameter_info['config']))
                nnp.write(param_filename, 'parameter.h5')
                for opti_path in opti_filenames:
                    # Archive member name is the path with the leading
                    # ``base`` and the one character after it stripped.
                    nnp.write(opti_path, opti_path[len(base) + 1:])
        finally:
            # The intermediate files only exist to be packed into the
            # archive. Remove them even when a step above raised, so a failed
            # save does not leave stray files in the output directory.
            for temp_path in temp_paths:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)

        # Only reached when the archive was written successfully.
        _save_parameter_info[suffix]['epoch'] = epoch
        _save_parameter_info[suffix]['time'] = current_time

        callback.save_train_snapshot()