Example #1
    def __init__(self, opt, model_id, tokenizer_opt=None, load=False,
                 timeout=-1, on_timeout="to_cpu", model_root="./"):
        self.model_root = model_root
        self.opt = self.parse_opt(opt)
        if self.opt.n_best > 1:
            raise ValueError("Values of n_best > 1 are not supported")

        self.model_id = model_id
        self.tokenizer_opt = tokenizer_opt
        self.timeout = timeout
        self.on_timeout = on_timeout

        self.unload_timer = None
        self.user_opt = opt
        self.tokenizer = None

        if len(self.opt.log_file) > 0:
            log_file = os.path.join(model_root, self.opt.log_file)
        else:
            log_file = None
        self.logger = init_logger(log_file=log_file,
                                  log_file_level=self.opt.log_file_level)

        self.loading_lock = threading.Event()
        self.loading_lock.set()
        self.running_lock = threading.Semaphore(value=1)

        set_random_seed(self.opt.seed, self.opt.cuda)

        if load:
            self.load()
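The constructor above pairs a `threading.Event` (set while the model is available, cleared during a load) with a `threading.Semaphore(1)` that serializes translation requests. A minimal standalone sketch of that pattern, with hypothetical `load`/`run` placeholders rather than the real translator, might look like this:

import threading
import time

class LazyModel:
    """Sketch of the loading-Event / running-Semaphore pattern (illustrative only)."""

    def __init__(self):
        self.loading_lock = threading.Event()             # set == model is ready
        self.loading_lock.set()
        self.running_lock = threading.Semaphore(value=1)  # one request at a time
        self.model = None

    def load(self):
        self.loading_lock.clear()        # make incoming requests wait
        try:
            time.sleep(0.1)              # stand-in for the real model load
            self.model = object()
        finally:
            self.loading_lock.set()      # release any waiters

    def run(self, inputs):
        self.loading_lock.wait()         # block until any in-progress load finishes
        with self.running_lock:          # serialize actual inference
            return "translated %r with %r" % (inputs, self.model)

if __name__ == "__main__":
    m = LazyModel()
    m.load()
    print(m.run("hello"))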
Example #2
def batch_producer(generator_to_serve, queues, semaphore, opt):
    init_logger(opt.log_file)
    set_random_seed(opt.seed, False)

    # generator_to_serve = iter(generator_to_serve)

    def pred(x):
        """
        Keep only the batches assigned to the gpu_ranks of the current node.
        """
        for rank in opt.gpu_ranks:
            if x[0] % opt.world_size == rank:
                return True

    generator_to_serve = filter(pred, enumerate(generator_to_serve))

    def next_batch(device_id):
        new_batch = next(generator_to_serve)
        semaphore.acquire()
        return new_batch[1]

    b = next_batch(0)

    for device_id, q in cycle(enumerate(queues)):
        b.dataset = None
        if isinstance(b.src, tuple):
            b.src = tuple([_.to(torch.device(device_id)) for _ in b.src])
        else:
            b.src = b.src.to(torch.device(device_id))
        b.tgt = b.tgt.to(torch.device(device_id))
        b.indices = b.indices.to(torch.device(device_id))
        b.alignment = b.alignment.to(torch.device(device_id)) \
            if hasattr(b, 'alignment') else None
        b.src_map = b.src_map.to(torch.device(device_id)) \
            if hasattr(b, 'src_map') else None
        b.align = b.align.to(torch.device(device_id)) \
            if hasattr(b, 'align') else None

        # hack to dodge unpicklable `dict_keys`
        b.fields = list(b.fields)
        q.put(b)
        b = next_batch(device_id)
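As a rough, framework-free sketch of what `batch_producer` does, round-robin batches onto per-device queues while a semaphore limits how far the producer can run ahead of the consumers, the following uses plain `threading`/`queue` stand-ins instead of the real training loop and GPU transfers:

import itertools
import queue
import threading

def batch_producer_sketch(batches, queues, semaphore):
    """Round-robin batches onto per-device queues, throttled by a semaphore."""
    it = iter(batches)
    for device_id, q in itertools.cycle(enumerate(queues)):
        try:
            batch = next(it)
        except StopIteration:
            return
        semaphore.acquire()              # wait until a consumer frees a slot
        q.put((device_id, batch))        # a real consumer would .to(device) here

if __name__ == "__main__":
    qs = [queue.Queue() for _ in range(2)]
    sem = threading.Semaphore(value=2)   # at most 2 batches in flight
    t = threading.Thread(target=batch_producer_sketch, args=(range(6), qs, sem))
    t.start()
    for i in range(6):                   # consume in the same round-robin order
        device_id, batch = qs[i % 2].get()
        print("device %d got batch %d" % (device_id, batch))
        sem.release()                    # free a slot for the producer
    t.join()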
Example #3
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  with_tree=opt.with_tree,
                                  tree_type=opt.tree_type,
                                  with_tree_as_graph=opt.as_graph,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)
    tree_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       tree_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, tree_reader, opt)
Example #4
def main():
    opt = parse_args()

    assert opt.max_shard_size == 0, \
        "-max_shard_size is deprecated. Please use \
        -shard_size (number of examples) instead."
    assert opt.shuffle == 0, \
        "-shuffle is not implemented. Please shuffle \
        your data before pre-processing."

    assert os.path.isfile(opt.train_src) and os.path.isfile(opt.train_tgt), \
        "Please check path of your train src and tgt files!"

    assert os.path.isfile(opt.valid_src) and os.path.isfile(opt.valid_tgt), \
        "Please check path of your valid src and tgt files!"

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #5
def main():
    opt = parse_args()

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    print(opt)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    # train_dataset_files = 'data/processed.train.pt'
    build_save_vocab(train_dataset_files, opt.data_type, fields, opt)
Example #6
def prepare_translators(langspecf):
    global translatorbest, translatorbigram, langspec
    with open(os.path.join(dir_path, 'opt_data'), 'rb') as f:
        opt = pickle.load(f)

    if not langspec or langspec != langspecf:
        opt.models = [os.path.join(dir_path, 'model', langspecf['model'])]
        opt.n_best = 1
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translatorbest = build_translator(opt, report_score=True)

        opt.models = [os.path.join(dir_path, 'model', langspecf['model'])]
        opt.n_best = 5
        opt.max_length = 2
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translatorbigram = build_translator(opt, report_score=True)

        langspec = langspecf
Example #7
def main(opt):
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True, logger=logger)
    translator.translate(src=opt.src,
                         src_title=opt.src_title,
                         tgt=opt.tgt,
                         src_dir=opt.src_dir,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
    # add by wchen
    evaluate_func(opt)
Example #8
def main():
    opt = parse_args()
    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, 0, 0)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(fields, opt)
Example #9
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    translator = build_translator(opt, report_score=True, logger=logger)

    def create_src_shards(path, opt, binary=True):
        if opt.data_type == 'imgvec':
            assert opt.shard_size <= 0
            return [path]
        else:
            if opt.data_type == 'none':
                return [None] * 99999
            else:
                return split_corpus(path, opt.shard_size, binary=binary)

    src_shards = create_src_shards(opt.src, opt)
    if opt.agenda:
        agenda_shards = create_src_shards(opt.agenda, opt, False)

    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)

    if not opt.agenda:
        shards = zip(src_shards, tgt_shards)
    else:
        shards = zip(src_shards, agenda_shards, tgt_shards)

    for i, flat_shard in enumerate(shards):
        if not opt.agenda:
            src_shard, tgt_shard = flat_shard
            agenda_shard = None
        else:
            src_shard, agenda_shard, tgt_shard = flat_shard
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             agenda=agenda_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             tag_shard=tag_shard)
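`split_corpus`, as used above, yields the corpus in fixed-size shards. A rough stdlib-only equivalent for a plain text file, assuming one example per line, could be:

from itertools import islice

def split_corpus_sketch(path, shard_size):
    """Yield lists of at most `shard_size` lines; one shard = whole file if <= 0."""
    with open(path, "rb") as f:
        if shard_size <= 0:
            yield f.readlines()
            return
        while True:
            shard = list(islice(f, shard_size))
            if not shard:
                break
            yield shard

# Hypothetical usage, mirroring the loop above:
# for i, (src_shard, tgt_shard) in enumerate(
#         zip(split_corpus_sketch("src.txt", 1000),
#             split_corpus_sketch("tgt.txt", 1000))):
#     ...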
Example #10
def multi_main(func, args):
    """ Spawns 1 process per GPU """
    init_logger()
    nb_gpu = args.world_size
    mp = torch.multiprocessing.get_context('spawn')

    # Create a thread to listen for errors in the child processes.
    error_queue = mp.SimpleQueue()
    error_handler = distributed.ErrorHandler(error_queue)

    # Train with multiprocessing.
    procs = []
    for i in range(nb_gpu):
        device_id = i
        proc_args = (func, args, device_id, error_queue)
        procs.append(mp.Process(target=run, args=proc_args, daemon=True))
        procs[i].start()
        logger.info(" Starting process pid: %d  " % procs[i].pid)
        error_handler.add_child(procs[i].pid)
    for p in procs:
        p.join()
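The `error_queue`/`ErrorHandler` pairing above is a common way to surface exceptions from spawned workers. A stripped-down, hypothetical illustration of that idea (not the actual `onmt.utils.distributed.ErrorHandler`; here the parent simply drains the queue after the children exit) might be:

import multiprocessing
import traceback

def worker(device_id, error_queue):
    try:
        # stand-in for run(func, args, device_id, error_queue)
        if device_id == 1:
            raise RuntimeError("boom on device %d" % device_id)
    except Exception:
        # ship the traceback back to the parent instead of failing silently
        error_queue.put((device_id, traceback.format_exc()))

if __name__ == "__main__":
    mp = multiprocessing.get_context("spawn")
    error_queue = mp.SimpleQueue()
    procs = [mp.Process(target=worker, args=(i, error_queue), daemon=True)
             for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    while not error_queue.empty():
        device_id, tb = error_queue.get()
        print("child %d failed:\n%s" % (device_id, tb))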
Example #11
def translate(opt): 
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    # shard_pairs = zip(src_shards, tgt_shards)
    # print("number of shards: ", len(src_shards), len(tgt_shards))

    # load emotions
    tgt_emotion_shards = [None]*100
    if opt.target_emotions_path != "":
        print("Loading target emotions...")
        tgt_emotions = read_emotion_file(opt.target_emotions_path)
        tgt_emotion_shards = split_emotions(tgt_emotions, opt.shard_size)
        # print("number of shards: ", len(tgt_emotion_shards))
    
    tgt_concept_embedding_shards = [None]*100
    if opt.target_concept_embedding != "":
        print("Loading target_concept_embedding...")
        tgt_concept_embedding = load_pickle(opt.target_concept_embedding)
        tgt_concept_embedding_shards = split_emotions(tgt_concept_embedding, opt.shard_size)
        # print("number of shards: ", len(tgt_concept_embedding_shards))
    
    tgt_concept_words_shards = [None]*100
    if opt.target_concept_words != "":
        print("Loading target_concept_words...")
        tgt_concept_words = load_pickle(opt.target_concept_words)
        # tgt_concept_words_shards = split_emotions(zip(tgt_concept_words), opt.shard_size)
        tgt_concept_words_shards = [tgt_concept_words]
        # print("number of shards: ", len(tgt_concept_words_shards))
    
    shard_pairs = zip(src_shards, tgt_shards, tgt_emotion_shards,
                      tgt_concept_embedding_shards, tgt_concept_words_shards)

    for i, (src_shard, tgt_shard, tgt_emotion_shard,
            tgt_concept_embedding_shard,
            tgt_concept_words_shard) in enumerate(shard_pairs):
        # print(len(src_shard), len(tgt_shard), len(tgt_emotion_shard))
        logger.info("Translating shard %d." % i)
        translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            batch_type=opt.batch_type,
            attn_debug=opt.attn_debug,
            tgt_emotion_shard=tgt_emotion_shard,
            rerank=opt.rerank,
            emotion_lexicon=opt.emotion_lexicon,
            tgt_concept_embedding_shard=tgt_concept_embedding_shard,
            tgt_concept_words_shard=tgt_concept_words_shard
            )
Example #12
def main(opt):
    ArgumentParser.validate_translate_opts(opt)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    if 'n_latent' not in vars(opt):
        vars(opt)['n_latent'] = vars(opt)['n_translate_latent']
    logger = init_logger(opt.log_file)

    if 'use_segments' not in vars(opt):
        vars(opt)['use_segments'] = opt.n_translate_segments != 0
        vars(opt)['max_segments'] = opt.n_translate_segments

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    n_latent = opt.n_latent

    if n_latent > 1:
        for latent_idx in range(n_latent):
            output_path = opt.output_dir + '/output_%d' % (latent_idx)
            out_file = codecs.open(output_path, 'w+', 'utf-8')
            translator.out_file = out_file

            for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
                logger.info("Translating shard %d." % i)
                translator.translate(src=src_shard,
                                     tgt=tgt_shard,
                                     src_dir=opt.src_dir,
                                     batch_size=opt.batch_size,
                                     attn_debug=opt.attn_debug,
                                     latent_idx=latent_idx)
            src_shards = split_corpus(opt.src, opt.shard_size)
            tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
                if opt.tgt is not None else repeat(None)
            shard_pairs = zip(src_shards, tgt_shards)
    else:
        output_path = opt.output_dir + '/output'
        out_file = codecs.open(output_path, 'w+', 'utf-8')
        translator.out_file = out_file

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 attn_debug=opt.attn_debug)
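Note why Example #12 rebuilds `src_shards`/`tgt_shards` inside the latent loop: `split_corpus` returns generators, and a `zip` over generators can only be consumed once. A two-line illustration:

shards = zip((x for x in "abc"), (y for y in "123"))
print(list(shards))  # [('a', '1'), ('b', '2'), ('c', '3')]
print(list(shards))  # [] -- exhausted, so the shards must be re-created per pass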
Example #13
    def __init__(
        self,
        opt,
        model_id,
        tokenizer_opt=None,
        load=False,
        timeout=-1,
        on_timeout="to_cpu",
        model_root="./",
    ):
        """
        Args:
            opt: (dict) options for the Translator
            model_id: (int) model id
            tokenizer_opt: (dict) options for the tokenizer or None
            load: (bool) whether to load the model during __init__
            timeout: (int) seconds before running `do_timeout`;
                     negative values mean no timeout
            on_timeout: (str) one of ["to_cpu", "unload"]; what to do on
                        timeout (see function `do_timeout`)
            model_root: (str) path to the model directory;
                        it must contain the model and tokenizer files

        """
        self.model_root = model_root
        self.opt = self.parse_opt(opt)
        if self.opt.n_best > 1:
            raise ValueError("Values of n_best > 1 are not supported")

        self.model_id = model_id
        self.tokenizer_opt = tokenizer_opt
        self.timeout = timeout
        self.on_timeout = on_timeout

        self.unload_timer = None
        self.user_opt = opt
        self.tokenizer = None

        if len(self.opt.log_file) > 0:
            log_file = os.path.join(model_root, self.opt.log_file)
        else:
            log_file = None
        self.logger = init_logger(
            log_file=log_file, log_file_level=self.opt.log_file_level
        )

        self.loading_lock = threading.Event()
        self.loading_lock.set()
        self.running_lock = threading.Semaphore(value=1)

        if load:
            self.load()
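The `timeout`/`on_timeout` options documented in Example #13 describe an inactivity timer. A minimal sketch of how such a timer could be wired up with `threading.Timer` follows; the names `do_timeout`, `to_cpu`, and `unload` mirror the docstring, but the bodies are illustrative placeholders, not the actual OpenNMT-py implementation:

import threading

class TimedModel:
    def __init__(self, timeout=-1, on_timeout="to_cpu"):
        self.timeout = timeout
        self.on_timeout = on_timeout
        self.unload_timer = None

    def reset_unload_timer(self):
        if self.timeout < 0:                 # negative timeout means "never"
            return
        if self.unload_timer is not None:
            self.unload_timer.cancel()       # restart the countdown on activity
        self.unload_timer = threading.Timer(self.timeout, self.do_timeout)
        self.unload_timer.start()

    def do_timeout(self):
        if self.on_timeout == "unload":
            print("unloading model")         # placeholder action
        elif self.on_timeout == "to_cpu":
            print("moving model to CPU")     # placeholder action

m = TimedModel(timeout=1, on_timeout="to_cpu")
m.reset_unload_timer()   # called after each request; fires 1s after the last call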
Example #14
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not(opt.overwrite):
        check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc,
        edges_vocab=opt.edges_vocab)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    grh_reader = inputters.str2reader["grh"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset(
        'train', fields, src_reader, tgt_reader, grh_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           grh_reader, opt)
Example #15
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        val_dataset_files = \
            build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
        # TOM: Create a vocab out of both training and validation tokens
        train_dataset_files += val_dataset_files

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #16
def main(args):
    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    args.log_file = '%s_%s_%s' % (args.log_file, args.mode, current_time)
    init_logger(args.log_file)
    logger.info(args)

    args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')]
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    device_id = 0 if device == "cuda" else -1

    if (args.world_size > 1):
        distributed.multi_main(train, args)
    elif (args.mode == 'train'):
        train(args, device_id)
    #elif (args.mode == 'validate'):
    elif (args.mode == 'test'):
        cp = args.test_from
        try:
            step = int(cp.split('.')[-2].split('_')[-1])
        except Exception:
            step = 0
        evaluate(args, device_id, cp, step, mode='test')
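The `step` recovery in Example #16 assumes checkpoint names that end in `_<step>.<ext>`; for instance, with a hypothetical filename:

cp = "models/bert_ext_step_148000.pt"          # hypothetical checkpoint path
step = int(cp.split('.')[-2].split('_')[-1])   # '...step_148000' -> '148000'
print(step)                                    # 148000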
Example #17
def main():
    opt = parse_args()
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    valid_dataset_files = build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files + valid_dataset_files, fields, opt)
Example #18
def main():
    opt = parse_args()
    init_logger(opt.log_file)
    logger.info("Extracting features...")


    src_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_src, 'src')
    tgt_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_tgt, 'tgt')
    ans_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_ans, "ans")
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of answer features: %d." % ans_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats, ans_nfeats)

    logger.info("fields src")
    logger.info(fields['src'].__dict__)
    logger.info(fields['tgt'].__dict__)
    logger.info(fields['src_map'].__dict__)
    logger.info(fields['ans'].__dict__)
    logger.info(fields['indices'].__dict__)
    logger.info(fields['alignment'].__dict__)


    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)
    logger.info(train_dataset_files)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)
Example #19
def main():
    opt = parse_args()

    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")
    if (opt.shuffle > 0):
        raise AssertionError("-shuffle is not implemented, please make sure \
                             you shuffle your data before pre-processing.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    fields0 = _get_fields(opt.data_type, opt.train_src, opt.train_tgt)
    fields1 = _get_fields(opt.data_type, opt.train_src1, opt.train_tgt1)

    #src_nfeats = inputters.get_num_features(
    #    opt.data_type, opt.train_src, 'src')
    #tgt_nfeats = inputters.get_num_features(
    #    opt.data_type, opt.train_tgt, 'tgt')
    #logger.info(" * number of source features: %d." % src_nfeats)
    #logger.info(" * number of target features: %d." % tgt_nfeats)

    #logger.info("Building `Fields` object...")
    #fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files0 = build_save_dataset('train0', fields0, opt)
    train_dataset_files1 = build_save_dataset('train1', fields1, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid0', fields0, opt)
    build_save_dataset('valid1', fields1, opt)

    logger.info("Building & saving vocabulary...")

    build_save_vocab(train_dataset_files0 + train_dataset_files1, fields0, opt)
Example #20
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need a vocab are disabled at this stage.
    The built vocab is saved in plain text format as follows and can be passed
    as `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    for feat_name, feat_counter in src_feats_counter.items():
        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for k, v in src_feats_counter.items():
        save_counter(v, opts.src_feats_vocab[k])
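The docstring above describes the plain-text vocab format (one `<tok>\t<count>` per line, most frequent first). A self-contained sketch of producing such a file from a `collections.Counter`, in the same spirit as `save_counter` (the filename here is only illustrative):

from collections import Counter

src_counter = Counter("the cat sat on the mat the end".split())

with open("example.vocab.src", "w", encoding="utf8") as fo:
    for tok, count in src_counter.most_common():
        fo.write(tok + "\t" + str(count) + "\n")

# example.vocab.src now starts with:
#   the   3
#   cat   1
# and a file like this can be passed to training via `-src_vocab`.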
Example #21
def main(model="transformer", dataset="toy-ende"):
    init_logger()
    is_cuda = cuda.is_available()
    set_random_seed(1111, is_cuda)

    data = preprocess.setup_dataset(dataset)
    vocab = preprocess.setup_vocab(data)

    if model == "transformer":
        Model, loss, opt = transformer.SimpleTransformer(vocab)
    elif model == "lstm":
        Model, loss, opt = lstm.BaseLSTMModel(vocab)

    train, validate = training.training_iterator(data, vocab)
    TrainingSession = training.training_session(Model, loss, opt)

    report = TrainingSession.train(
        train_iter=train, 
        valid_iter=validate, 
        **defaults.training)

    evaluate.evaluation(model, data, vocab)
    
    return 0
Example #22
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #23
def main(opt):
    _logger_path = "logs/{}-test.log".format(opt.models[0].split('/')[1])
    _output_path = "logs/{}-output.log".format(opt.models[0].split('/')[1])
    logger = init_logger(_logger_path)
    logger.info('input_weight: {}'.format(opt.input_weight))
    logger.info('last_weight: {}'.format(opt.last_weight))
    _out_file = codecs.open(_output_path, 'w+', 'utf-8')
    logger.info('Start testing.')
    translator = build_translator(opt,
                                  report_score=True,
                                  logger=logger,
                                  out_file=_out_file)
    translator.translate(src_path=opt.src,
                         tgt_path=opt.tgt,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
Example #24
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, logger=logger, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug)
Example #25
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True)
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug)
Example #26
def main(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    if opt.constraint_file:
        tag_shards = split_corpus(opt.constraint_file,
                                  opt.shard_size,
                                  iter_func=constraint_iter_func,
                                  binary=False)

    with open("opt.pkl", 'wb') as f1:
        pickle.dump(opt, f1)
    with open("opt.pkl", 'rb') as f1:
        opt1 = pickle.load(f1)
    translator = build_translator(opt, report_score=True)

    if opt.data_type == 'imgvec':
        assert opt.shard_size <= 0
        src_shards = [opt.src]
    else:
        if opt.data_type == 'none':
            src_shards = [None] * 99999
        else:
            src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)

        tag_shard = None
        if opt.constraint_file:
            tag_shard = next(tag_shards)

        all_scores, all_predictions = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug,
            tag_shard=tag_shard)
        with open("result_{}.pickle".format(i), 'wb') as f1:
            pickle.dump(all_predictions, f1)
Example #27
    def __init__(self,
                 opt,
                 model_id,
                 preprocess_opt=None,
                 tokenizer_opt=None,
                 postprocess_opt=None,
                 custom_opt=None,
                 load=False,
                 timeout=-1,
                 on_timeout="to_cpu",
                 model_root="./",
                 ct2_model=None):
        self.model_root = model_root
        self.opt = self.parse_opt(opt)
        self.custom_opt = custom_opt

        self.model_id = model_id
        self.preprocess_opt = preprocess_opt
        self.tokenizers_opt = tokenizer_opt
        self.postprocess_opt = postprocess_opt
        self.timeout = timeout
        self.on_timeout = on_timeout

        self.ct2_model = os.path.join(model_root, ct2_model) \
            if ct2_model is not None else None

        self.unload_timer = None
        self.user_opt = opt
        self.tokenizers = None

        if len(self.opt.log_file) > 0:
            log_file = os.path.join(model_root, self.opt.log_file)
        else:
            log_file = None
        self.logger = init_logger(log_file=log_file,
                                  log_file_level=self.opt.log_file_level)

        self.loading_lock = threading.Event()
        self.loading_lock.set()
        self.running_lock = threading.Semaphore(value=1)

        set_random_seed(self.opt.seed, self.opt.cuda)

        if load:
            self.load(preload=True)
            self.stop_unload_timer()
Example #28
def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need a vocab are disabled at this stage.
    The built vocab is saved in plain text format as follows and can be passed
    as `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert opts.n_sample == -1 or opts.n_sample > 1, \
        f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)
    fields = None

    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter = save_transformed_sample(opts,
                                                       transforms,
                                                       n_sample=opts.n_sample,
                                                       build_vocab=True)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
    if opts.share_vocab:
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")

    def save_counter(counter, save_path):
        with open(save_path, "w") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    save_counter(src_counter, opts.save_data + '.vocab.src')
    save_counter(tgt_counter, opts.save_data + '.vocab.tgt')
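When `share_vocab` is set (as in both Examples #20 and #28), the source and target counters are merged before saving; `Counter` addition sums the per-token counts, e.g.:

from collections import Counter

src_counter = Counter({"the": 5, "cat": 2})
tgt_counter = Counter({"the": 3, "chat": 4})

src_counter += tgt_counter        # in-place merge, counts are summed
tgt_counter = src_counter         # both sides now point at the shared vocab
print(src_counter.most_common())  # [('the', 8), ('chat', 4), ('cat', 2)]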
Example #29
def nmt_filter_dataset(opt):

    opt.src = os.path.join(dataset_root_path, src_file)
    opt.tgt = os.path.join(dataset_root_path, tgt_file)
    opt.shard_size = 1

    opt.log_file = logging_file_path
    opt.models = [model_file_path]
    opt.n_best = 1
    opt.beam_size = 1
    opt.report_bleu = False
    opt.report_rouge = False

    logger = init_logger(opt.log_file)
    translator = build_translator(opt, report_score=True)

    src_file_path = os.path.join(dataset_root_path, src_file)
    tgt_file_path = os.path.join(dataset_root_path, tgt_file)

    src_shards = split_corpus(src_file_path, opt.shard_size)
    tgt_shards = split_corpus(tgt_file_path, opt.shard_size)
    shard_pairs = zip(src_shards, tgt_shards)

    pred_scores = []

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):

        start_time = time.time()
        shard_pred_scores, shard_pred_sentences = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            src_dir=opt.src_dir,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        print("--- %s seconds ---" % (time.time() - start_time))

        pred_scores += [scores[0] for scores in shard_pred_scores]

    average_pred_score = torch.mean(torch.stack(pred_scores)).detach()

    return average_pred_score
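The final averaging in Example #29 assumes each kept top-1 score is a 0-dim tensor; stacking and taking the mean then works like this (toy values):

import torch

pred_scores = [torch.tensor(-1.2), torch.tensor(-0.8), torch.tensor(-2.0)]
average_pred_score = torch.mean(torch.stack(pred_scores)).detach()
print(average_pred_score)   # tensor(-1.3333)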
Example #30
    def __init__(self,
                 opt,
                 model_id,
                 preprocess_opt=None,
                 tokenizer_opt=None,
                 postprocess_opt=None,
                 load=False,
                 timeout=-1,
                 on_timeout="to_cpu",
                 model_root="./"):
        self.model_root = model_root
        self.opt = self.parse_opt(opt)
        if self.opt.n_best > 1:
            raise ValueError("Values of n_best > 1 are not supported")

        self.model_id = model_id
        self.preprocess_opt = preprocess_opt
        self.tokenizer_opt = tokenizer_opt
        self.postprocess_opt = postprocess_opt
        self.timeout = timeout
        self.on_timeout = on_timeout

        self.unload_timer = None
        self.user_opt = opt
        self.tokenizer = None

        if len(self.opt.log_file) > 0:
            log_file = os.path.join(model_root, self.opt.log_file)
        else:
            log_file = None
        self.logger = init_logger(log_file=log_file,
                                  log_file_level=self.opt.log_file_level)

        self.loading_lock = threading.Event()
        self.loading_lock.set()
        self.running_lock = threading.Semaphore(value=1)

        set_random_seed(self.opt.seed, self.opt.cuda)

        if load:
            self.load()
Example #31
def main(opt):
    ArgumentParser.validate_translate_opts(opt)

    logger = init_logger(opt.log_file)
    abs_path = os.path.dirname(opt.src)
    src_mode = opt.data_mode
    candidates_logprobs: List[List[Tuple[List[str], float]]] = list()

    if "patype0" in opt.src_types:
        translator = MultiSourceAPTypeAppendedTranslator.build_translator(
            opt.src_types, opt, report_score=True)
    else:
        translator = MultiSourceAPTranslator.build_translator(
            opt.src_types, opt, report_score=True)
    raw_data_keys = ["src.{}".format(src_type)
                     for src_type in opt.src_types] + (["tgt"])
    raw_data_paths: Dict[str, str] = {
        k: "{0}/{1}.{2}.txt".format(abs_path, k, src_mode)
        for k in raw_data_keys
    }
    raw_data_shards: Dict[str, list] = {
        k: list(split_corpus(p, opt.shard_size))
        for k, p in raw_data_paths.items()
    }

    for i in range(len(list(raw_data_shards.values())[0])):
        logger.info("Translating shard %d." % i)
        _, _, candidates_logprobs_shard = translator.translate(
            {k: v[i]
             for k, v in raw_data_shards.items()},
            True,
            src_dir=None,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        candidates_logprobs.extend(candidates_logprobs_shard)

    # Reformat candidates
    candidates_logprobs: List[List[Tuple[str, float]]] = [[
        ("".join(c), l) for c, l in cl
    ] for cl in candidates_logprobs]
    return candidates_logprobs
Example #32
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    translator = build_translator(opt, report_score=True, logger=logger)
    translator.out_file = codecs.open(opt.output, 'w+', 'utf-8')
    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             batch_type=opt.batch_type,
                             attn_debug=opt.attn_debug,
                             align_debug=opt.align_debug,
                             opt=opt)
Example #33
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #34
        pass
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir)


def rouge_results_to_str(results_dict):
    return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format(
        results_dict["rouge_1_f_score"] * 100,
        results_dict["rouge_2_f_score"] * 100,
        results_dict["rouge_3_f_score"] * 100,
        results_dict["rouge_l_f_score"] * 100,
        results_dict["rouge_su*_f_score"] * 100)


if __name__ == "__main__":
    init_logger('test_rouge.log')
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', type=str, default="candidate.txt",
                        help='candidate file')
    parser.add_argument('-r', type=str, default="reference.txt",
                        help='reference file')
    args = parser.parse_args()
    if args.c.upper() == "STDIN":
        candidates = sys.stdin
    else:
        candidates = codecs.open(args.c, encoding="utf-8")
    references = codecs.open(args.r, encoding="utf-8")

    results_dict = test_rouge(candidates, references)
    logger.info(rouge_results_to_str(results_dict))