示例#1
0
def base_init(new_args):
    """Initialize the module-wide configuration and global hooks.

    This function should be called before accessing any other function in
    this module. It stores ``new_args`` in the module-level ``args``
    variable, on which all the create_* factory functions rely as their
    configuration object. It also sets up global function pointers and
    variables for basic things like the indexing scheme, logging verbosity,
    log summation and predictor combination. Finally, it validates the
    arguments with ``ui.validate_args``.

    Args:
        new_args: Configuration object from the argument parser. The
            attributes ``verbosity``, ``indexing_scheme``, ``log_sum`` and
            ``combination_scheme`` are read here.
    """
    global args
    args = new_args
    # UTF-8 support. The standard streams are wrapped *before* logging is
    # configured so that the handler created by basicConfig() below binds to
    # the UTF-8 aware stderr.
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    # Set up logger. This must precede the first module-level logging call.
    # logging.warning() & co. implicitly run basicConfig() with the default
    # format when the root logger has no handlers, and that would turn this
    # call into a no-op.
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
    if sys.version_info >= (3, 0):
        # Emitted before the configured verbosity is applied below.
        logging.warning("SGNMT is tested with Python 2.7, but you are using "
                        "Python 3. Expect the unexpected or switch to 2.7.")
    verbosity_levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warn': logging.WARNING,
        'error': logging.ERROR,
    }
    # Unknown or missing verbosity values fall back to INFO.
    logging.getLogger().setLevel(
        verbosity_levels.get(args.verbosity, logging.INFO))
    # Set reserved word IDs
    if args.indexing_scheme == 'blocks':
        utils.switch_to_blocks_indexing()
    elif args.indexing_scheme == 'tf':
        utils.switch_to_tf_indexing()
    elif args.indexing_scheme == 't2t':
        utils.switch_to_t2t_indexing()
    # Log summation (how to compute log(exp(l1)+exp(l2)) for log values l1,l2)
    if args.log_sum == 'tropical':
        utils.log_sum = utils.log_sum_tropical_semiring
    # Predictor combination schemes (mutually exclusive string values).
    if args.combination_scheme == 'length_norm':
        core.breakdown2score_full = core.breakdown2score_length_norm
    elif args.combination_scheme == 'bayesian_loglin':
        core.breakdown2score_full = core.breakdown2score_bayesian_loglin
    elif args.combination_scheme == 'bayesian':
        core.breakdown2score_full = core.breakdown2score_bayesian
    ui.validate_args(args)
示例#2
0
def base_init(new_args):
    """Initialize the module-wide configuration and global hooks.

    This function should be called before accessing any other function in
    this module. It stores ``new_args`` in the module-level ``args``
    variable, on which all the create_* factory functions rely as their
    configuration object. It also sets up global function pointers and
    variables for basic things like the indexing scheme, logging verbosity
    and log summation, then validates the arguments with
    ``ui.validate_args``. If ``args.run_diagnostics`` is set, it runs
    ``ui.run_diagnostics()`` and exits the process via ``sys.exit()``.

    Args:
        new_args: Configuration object from the argument parser. The
            attributes ``verbosity``, ``indexing_scheme``, ``log_sum`` and
            ``run_diagnostics`` are read here.

    Raises:
        NotImplementedError: If ``args.indexing_scheme`` is neither
            ``'fairseq'`` nor ``'t2t'``.
    """
    global args
    args = new_args
    is_python2 = sys.version_info < (3, 0)
    # UTF-8 support. The standard streams are wrapped *before* logging is
    # configured so that the handler created by basicConfig() below binds to
    # the UTF-8 aware stderr.
    if is_python2:
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    # Set up logger. This must precede the first module-level logging call.
    # logging.warning() & co. implicitly run basicConfig() with the default
    # format when the root logger has no handlers, and that would turn this
    # call into a no-op.
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
    if is_python2:
        # Emitted before the configured verbosity is applied below.
        logging.warning("SGNMT is tested with Python 3, but you are using "
                        "Python 2. Expect the unexpected or switch to >3.5.")
    verbosity_levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warn': logging.WARNING,
        'error': logging.ERROR,
    }
    # Unknown or missing verbosity values fall back to INFO.
    logging.getLogger().setLevel(
        verbosity_levels.get(args.verbosity, logging.INFO))
    # Set reserved word IDs
    if args.indexing_scheme == 'fairseq':
        utils.switch_to_fairseq_indexing()
    elif args.indexing_scheme == 't2t':
        utils.switch_to_t2t_indexing()
    else:
        raise NotImplementedError("Indexing scheme not implemented")
    # Log summation (how to compute log(exp(l1)+exp(l2)) for log values l1,l2)
    if args.log_sum == 'tropical':
        utils.log_sum = utils.log_sum_tropical_semiring
    ui.validate_args(args)
    if args.run_diagnostics:
        ui.run_diagnostics()
        sys.exit()
示例#3
0
def base_init(new_args):
    """Initialize the module-wide configuration and global hooks.

    This function should be called before accessing any other function in
    this module. It stores ``new_args`` in the module-level ``args``
    variable, on which all the create_* factory functions rely as their
    configuration object. It also sets up global function pointers and
    variables for basic things like the indexing scheme, logging verbosity
    and log summation, then validates the arguments with
    ``ui.validate_args``.

    Args:
        new_args: Configuration object from the argument parser. The
            attributes ``verbosity``, ``indexing_scheme`` and ``log_sum``
            are read here.
    """
    global args
    args = new_args
    # UTF-8 support. The standard streams are wrapped *before* logging is
    # configured so that the handler created by basicConfig() below binds to
    # the UTF-8 aware stderr.
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    # Set up logger. This must precede the first module-level logging call.
    # logging.warning() & co. implicitly run basicConfig() with the default
    # format when the root logger has no handlers, and that would turn this
    # call into a no-op.
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
    if sys.version_info >= (3, 0):
        # Emitted before the configured verbosity is applied below.
        logging.warning("SGNMT is tested with Python 2.7, but you are using "
                        "Python 3. Expect the unexpected or switch to 2.7.")
    verbosity_levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warn': logging.WARNING,
        'error': logging.ERROR,
    }
    # Unknown or missing verbosity values fall back to INFO.
    logging.getLogger().setLevel(
        verbosity_levels.get(args.verbosity, logging.INFO))
    # Set reserved word IDs
    if args.indexing_scheme == 'blocks':
        utils.switch_to_blocks_indexing()
    elif args.indexing_scheme == 'tf':
        utils.switch_to_tf_indexing()
    elif args.indexing_scheme == 't2t':
        utils.switch_to_t2t_indexing()
    # Log summation (how to compute log(exp(l1)+exp(l2)) for log values l1,l2)
    if args.log_sum == 'tropical':
        utils.log_sum = utils.log_sum_tropical_semiring
    ui.validate_args(args)
示例#4
0
文件: train.py 项目: Jack44Wang/sgnmt
        logging.info("Source 7 length %d" % len(linModel.all_src[6]))

        with tf.Session() as session:
            session.run(init)
            for i in range(config.n_epochs):
                loss = linModel.train_on_batch(session, targets_batch)
                logging.info("loss: %d" % loss)
                linModel.config.eps = 0.5 * linModel.config.eps
                linModel.prepareSGNMT(args)  # reset hypos
                linModel.cur_hypos = linModel.all_hypos[5:5 + linModel.config.
                                                        batch_size]
                targets_batch = linModel.all_trg[5:5 +
                                                 linModel.config.batch_size]


def create_training_session(config):
    """Build the MonitoredTrainingSession used for training.

    Args:
        config: Training configuration; its ``output_path`` attribute is
            used as the checkpoint directory.

    Returns:
        The session returned by ``training.MonitoredTrainingSession``, which
        saves a checkpoint every 1200 seconds.
    """
    checkpoint_interval_secs = 1200
    session_kwargs = dict(
        checkpoint_dir=config.output_path,
        save_checkpoint_secs=checkpoint_interval_secs,
    )
    return training.MonitoredTrainingSession(**session_kwargs)


# Script entry point: load and validate the configuration, select the T2T
# reserved-word indexing, build the Config object and start training.
if __name__ == "__main__":
    # MAIN CODE STARTS HERE
    # Load configuration from command line arguments or configuration file
    args = get_args()
    validate_args(args)
    # Set reserved word IDs for the T2T indexing scheme. This is done
    # unconditionally; no indexing-scheme option is consulted here.
    utils.switch_to_t2t_indexing()
    # NOTE(review): `config` becomes a module-level global and is not passed
    # to do_multi_epoch_train below; presumably training code reads it as a
    # global -- verify before removing or renaming it.
    config = Config(args)

    # NOTE(review): the meaning of the second positional argument (False) is
    # not visible here -- check do_multi_epoch_train's signature.
    do_multi_epoch_train(args, False)