def _read_section(self): """This is the function to read sections from the config file. :return: sect_list, sect_key_list sect_list: A list of ConfigSection(). sect_key_list: A list of names in sect_list. """ sect_name = None sect_list = {} sect_key_list = [] single_section = {} single_section_key = [] with open(self.file_path, 'r') as f: lines = f.readlines() for line in lines: if line.startswith('[') and line.endswith(']\n'): if sect_name is None: pass else: sect_list[sect_name] = single_section, single_section_key single_section = {} single_section_key = [] sect_key_list.append(sect_name) sect_name = line[1:-2] continue if line.startswith('#'): single_section[line] = '#' single_section_key.append(line) continue if line.startswith('\n'): single_section_key.append('\n') continue if '=' not in line: log = create_logger(__name__, './config_saver.log') log.error("can NOT load config file [%s]" % self.file_path) raise RuntimeError("can NOT load config file {}".__format__( self.file_path)) key = line.split('=', maxsplit=1)[0].strip() value = line.split('=', maxsplit=1)[1].strip() + '\n' single_section[key] = value single_section_key.append(key) if sect_name is not None: sect_list[sect_name] = single_section, single_section_key sect_key_list.append(sect_name) return sect_list, sect_key_list
def save_config_file(self, section_name, section):
    """Persist a single section into the config file.

    When the file holds no entries for ``section_name`` the section is
    appended to the end of the file. Otherwise the file is rewritten, but
    only if ``section`` has a key the stored section lacks or a value that
    differs from the stored one.

    :param section_name: The name of the section to be saved.
    :param section: The section to save; its attributes (``__dict__``) are
        the keys and ``section[key]`` gives the matching value.
    :return:
    """
    stored = self._get_section(section_name)

    # Nothing stored under this name yet: append the section and stop.
    if not stored.__dict__:
        with open(self.file_path, 'a') as out:
            out.write('[' + section_name + ']\n')
            for key in section.__dict__:
                out.write(key + ' = ')
                value = section[key]
                # String values are written wrapped in double quotes.
                if isinstance(value, str):
                    out.write('\"' + str(value) + '\"\n\n')
                else:
                    out.write(str(value) + '\n\n')
        return

    # The section exists: look for the first new key or changed value.
    for key in section.__dict__:
        if key not in stored:
            break
        if stored[key] != section[key]:
            logger = create_logger(__name__, "./config_loader.log")
            logger.warning("section [%s] in config file [%s] has been changed" % (
                section_name, self.file_path
            ))
            break
    else:
        # Loop finished without a break: nothing to update.
        return

    sect_list, sect_key_list = self._read_section()
    if section_name not in sect_key_list:
        raise AttributeError()

    entries, order = sect_list[section_name]
    for key in section.__dict__:
        if key not in order:
            # Keep a blank line between the existing content and a new key.
            if order[-1] != '\n':
                order.append('\n')
            order.append(key)
        value = section[key]
        text = str(value)
        if isinstance(value, str):
            text = "\"" + text + "\""
        entries[key] = text + "\n"

    sect_list[section_name] = entries, order
    self._write_section(sect_list, sect_key_list)
import os import time from datetime import timedelta import torch from tensorboardX import SummaryWriter from fastNLP.core.action import RandomSampler from fastNLP.core.batch import Batch from fastNLP.core.loss import Loss from fastNLP.core.optimizer import Optimizer from fastNLP.core.tester import SeqLabelTester, ClassificationTester from fastNLP.saver.logger import create_logger from fastNLP.saver.model_saver import ModelSaver logger = create_logger(__name__, "./train_test.log") class BaseTrainer(object): """Operations of training a model, including data loading, gradient descent, and validation. """ def __init__(self, **kwargs): """ :param kwargs: dict of (key, value), or dict-like object. key is str. The base trainer requires the following keys: - epochs: int, the number of epochs in training - validate: bool, whether or not to validate on dev set - batch_size: int - pickle_path: str, the path to pickle files for pre-processing