示例#1
0
def init_logger(logger_name=__name__, output_path=None, level=logging.INFO):
    """
    Initializes a logger for console and file writing.

    The logger obtained from ``logging.getLogger(logger_name)`` gets its level
    set to `level` and a console handler. When `output_path` is given, a file
    handler is added as well:

    * if ``is_file_path(output_path)`` is true, the path is used as the log
      file and any existing file at that path is deleted first;
    * otherwise it is treated as a directory and a file named
      ``log_<Mon>_<dd>_<HH>_<MM>_<SS>.txt`` is created inside it.

    Repeated calls with the same `logger_name` return the same logger object
    and do NOT stack additional console handlers, so records are not printed
    several times. Each call with an `output_path` still adds a file handler.

    :param logger_name: name passed to ``logging.getLogger``.
    :param output_path: directory or file path where the logs should be saved.
                        By default it will not store file logs.
    :param level: logging level set on the logger.
    :return: the configured logger.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)

    formatter = logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")

    # adding console output -- only once per logger. getLogger() hands back
    # the same object for the same name, so without this check every call
    # would attach another console handler and duplicate each record.
    has_console = any(getattr(handler, "_init_logger_console", False)
                      for handler in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        # marker used by the check above on subsequent calls
        stream_handler._init_logger_console = True
        logger.addHandler(stream_handler)

    if output_path:
        if is_file_path(output_path):
            safe_mkfdir(output_path)
            # start a fresh log file instead of appending to an old one
            if os.path.exists(output_path):
                os.remove(output_path)
        else:
            safe_mkdir(output_path)
            # using the default name of the logger
            default_file_name = "log_" + strftime("%b_%d_%H_%M_%S") + '.txt'
            output_path = os.path.join(output_path, default_file_name)
        file_handler = logging.FileHandler(output_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
示例#2
0
 def save(self, file_path, encoding='utf-8'):
     """
     Saves hyper-params object as a json file.

     Only the attributes returned by ``self._get_simple_attrs()`` are
     serialized (with ``indent=2``). ``safe_mkfdir`` is called on the path
     first (presumably to create its parent directory).

     :param file_path: path of the output json file.
     :param encoding: text encoding of the output file.
     """
     safe_mkfdir(file_path)
     hparams_to_save = self._get_simple_attrs()
     # the context manager flushes and closes the file even if json.dump
     # raises; the previous version never closed the handle at all
     with codecs.open(file_path, encoding=encoding, mode='w') as f:
         json.dump(hparams_to_save, f, indent=2)
     logger.debug("Extracted the following hparams: '%s'.",
                  " ".join(hparams_to_save.keys()))
     logger.info("Saved hyper-parameters to '%s'.", file_path)
示例#3
0
def delete_attr_from_params(input_fp, output_fp, attr_names, device='cpu'):
    """
    Drops the given attributes from a stored params dictionary, saves it back.

    The ``MODEL_PARAMS`` entry of the object loaded from `input_fp` is edited;
    names in `attr_names` that it does not contain are silently ignored. Only
    ``{MODEL_PARAMS: params}`` is written to `output_fp`.

    :param input_fp: path of the file to load with ``T.load``.
    :param output_fp: path of the file to write with ``T.save``.
    :param attr_names: iterable of parameter names to remove.
    :param device: second positional argument of ``T.load`` (the map
                   location, assuming ``T`` is torch).
    """
    params = T.load(input_fp, device)[MODEL_PARAMS]

    for name in attr_names:
        # pop with a default is a no-op for names that are not present
        params.pop(name, None)

    # persist the trimmed params
    safe_mkfdir(output_fp)
    T.save({MODEL_PARAMS: params}, f=output_fp)
示例#4
0
def merge_csv_files(input_fps, output_fp, sep="\t"):
    """
    Concatenates csv files that share the same header into a single file.

    Files are appended in the order given and rows inside each file keep
    their order; nothing is shuffled, so entries of the same group are not
    scattered. Both reading and writing use utf-8 and ``QUOTE_NONE``.

    :param input_fps: iterable of input file paths.
    :param output_fp: path of the merged output file.
    :param sep: column separator used for reading and writing.
    """
    frames = [read_csv(path, sep=sep, quoting=QUOTE_NONE, encoding='utf-8')
              for path in input_fps]
    merged = concat(frames, axis=0, ignore_index=True, copy=True)

    safe_mkfdir(output_fp)
    merged.to_csv(output_fp, sep=sep, index=False, encoding='utf-8',
                  quoting=QUOTE_NONE)
def rename_attrs_in_params(input_fp,
                           output_fp,
                           old_attr_names,
                           new_attr_names,
                           device='cpu'):
    """
    Renames a model's parameters, saves them to an output file.

    The ``MODEL_PARAMS`` entry of the object loaded from `input_fp` has each
    ``old_attr_names[i]`` key renamed to ``new_attr_names[i]``; then
    ``{MODEL_PARAMS: params}`` (and nothing else) is written to `output_fp`.

    All renames are applied simultaneously: every old value is detached
    before any new name is assigned. This makes overlapping renames correct,
    e.g. swapping two names, chains such as ``a -> b, b -> c``, or the
    identity rename ``x -> x`` (which previously deleted ``x``).

    :param input_fp: path of the file to load with ``T.load``.
    :param output_fp: path of the file to write with ``T.save``.
    :param old_attr_names: sized sequence of existing parameter names.
    :param new_attr_names: sized sequence of target names, same length.
    :param device: second positional argument of ``T.load`` (the map
                   location, assuming ``T`` is torch).
    :raises ValueError: if the two name sequences differ in length.
    :raises KeyError: if an old name is not present in the params.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if len(old_attr_names) != len(new_attr_names):
        raise ValueError("`old_attr_names` and `new_attr_names` must have the "
                         "same length, got %d and %d."
                         % (len(old_attr_names), len(new_attr_names)))
    model_params = T.load(input_fp, device)[MODEL_PARAMS]
    # detach every old value first, then assign under the new names, so a
    # new name that is also an old name cannot clobber a not-yet-moved value
    values = [model_params.pop(old_name) for old_name in old_attr_names]
    for new_name, value in zip(new_attr_names, values):
        model_params[new_name] = value
    # dumping to the disk
    safe_mkfdir(output_fp)
    T.save({MODEL_PARAMS: model_params}, f=output_fp)
示例#6
0
    def write(self, file_path, sep=' ', encoding='utf-8'):
        """
        Writes the vocabulary to a plain text file where each line is of the
        form: {token}{sep}{count}. Default special symbols are not written.

        Entries are written in the order produced by iterating ``self``.

        :param file_path: path of the output text file.
        :param sep: separator placed between a token and its count.
        :param encoding: text encoding of the output file.
        :raises ValueError: if an entry cannot be formatted or encoded (e.g.
            a non-string token, or a character `encoding` cannot represent).
            Genuine I/O errors (e.g. a full disk) propagate unchanged instead
            of being reported as a bad token.
        """
        safe_mkfdir(file_path)
        with codecs.open(file_path, 'w', encoding=encoding) as f:
            for symbol in self:
                token = symbol.token
                count = str(symbol.count)
                try:
                    f.write(sep.join([token, count]))
                    f.write("\n")
                # TypeError: join() got a non-str token;
                # ValueError covers UnicodeEncodeError from the stream writer
                except (TypeError, ValueError) as err:
                    logger.fatal(
                        "Below entry produced a fatal error in write().")
                    logger.fatal(symbol.token)
                    raise ValueError("Could not process a token.") from err
        logger.info("Vocabulary is written to: '%s'.", file_path)