예제 #1
0
 def test_method(self):
     """Run the test named by the enclosing ``methodname`` on patched options.

     A deep copy of ``self.opt`` is made so the shared options are never
     mutated; every ``(param, setting)`` pair from the enclosing
     ``param_setting`` (when it is truthy) is applied to the copy, which is
     then refreshed with ``ArgumentParser.update_model_opts`` and passed to
     the target method.
     """
     patched_opt = copy.deepcopy(self.opt)
     for attr_name, attr_value in (param_setting or ()):
         setattr(patched_opt, attr_name, attr_value)
     ArgumentParser.update_model_opts(patched_opt)
     target = getattr(self, methodname)
     target(patched_opt)
예제 #2
0
def load_test_model(opt, model_path=None):
    """Load a checkpoint and rebuild its model for inference.

    Args:
        opt: translation options. ``opt.models[0]`` is used when
            ``model_path`` is None; ``opt.gpu`` and ``opt.fp32`` are read too.
        model_path: optional explicit checkpoint path.

    Returns:
        ``(fields, model, model_opt)``. ``fields`` is the checkpoint's
        ``'vocab'`` entry; ``model`` and its generator are in eval mode.
    """
    path = opt.models[0] if model_path is None else model_path
    # The identity map_location deserialises every tensor onto CPU.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    for prepare in (ArgumentParser.update_model_opts,
                    ArgumentParser.validate_model_opts):
        prepare(model_opt)
    fields = checkpoint['vocab']

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
예제 #3
0
def load_test_model(opt, model_path=None):
    """Restore a model from a checkpoint and switch it to evaluation mode.

    Old-format vocabularies stored in the checkpoint are converted to fields
    with ``inputters.load_old_vocab`` (using ``opt.data_type``).

    Args:
        opt: translation options; ``opt.models[0]`` is the default path.
        model_path: optional checkpoint path overriding ``opt.models[0]``.

    Returns:
        ``(fields, model, model_opt)``.
    """
    if model_path is None:
        model_path = opt.models[0]
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)

    saved_vocab = checkpoint['vocab']
    if not inputters.old_style_vocab(saved_vocab):
        fields = saved_vocab
    else:
        fields = inputters.load_old_vocab(saved_vocab, opt.data_type)

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint)
    if opt.fp32:
        model.float()
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
예제 #4
0
def load_test_model(opt, model_path=None):
    """Load a checkpoint for inference, optionally casting or quantizing it.

    ``opt.fp32`` casts the model to float32. Otherwise ``opt.int8`` applies
    ``torch.quantization.quantize_dynamic`` in place, which this function
    only permits when running on CPU.

    Args:
        opt: translation options (``models``, ``gpu``, ``fp32``, ``int8``).
        model_path: optional checkpoint path overriding ``opt.models[0]``.

    Returns:
        ``(fields, model, model_opt)`` with the model in eval mode.

    Raises:
        ValueError: if int8 quantization is requested with ``opt.gpu >= 0``.
    """
    checkpoint_path = model_path if model_path is not None else opt.models[0]
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    fields = checkpoint['vocab']

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    wants_int8 = not opt.fp32 and opt.int8
    if opt.fp32:
        model.float()
    elif wants_int8:
        on_gpu = opt.gpu >= 0
        if on_gpu:
            raise ValueError(
                "Dynamic 8-bit quantization is not supported on GPU")
        torch.quantization.quantize_dynamic(model, inplace=True)
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
예제 #5
0
def load_test_model(opt, model_path=None):
    """Restore a model from a checkpoint and put it in inference mode.

    Legacy checkpoints whose ``'vocab'`` is in the old format are converted
    to fields with ``inputters.load_old_vocab``, passing the checkpoint's
    ``copy_attn`` setting as ``dynamic_dict``.

    Args:
        opt: translation options (``models``, ``data_type``, ``gpu``, ``fp32``).
        model_path: optional checkpoint path overriding ``opt.models[0]``.

    Returns:
        ``(fields, model, model_opt)``.
    """
    if model_path is None:
        model_path = opt.models[0]
    # The identity map_location deserialises every tensor onto CPU.
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)

    fields = checkpoint['vocab']
    if inputters.old_style_vocab(fields):
        fields = inputters.load_old_vocab(
            fields, opt.data_type, dynamic_dict=model_opt.copy_attn)

    model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
예제 #6
0
def main(opt):
    """Validate training options and launch training.

    When ``opt.world_size > 1``, one spawned process running ``run`` is
    started per entry of ``opt.gpu_ranks`` and this function waits for all
    of them. Otherwise ``single_main`` runs in this process on device 0
    (exactly one GPU rank) or on CPU (device -1).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    gpu_count = len(opt.gpu_ranks)

    if opt.world_size <= 1:
        # Single process: one GPU if exactly one rank is given, else CPU.
        single_main(opt, 0 if gpu_count == 1 else -1)
        return

    # Several GPUs requested: one worker process per GPU rank.
    ctx = torch.multiprocessing.get_context('spawn')
    # Create a thread to listen for errors in the child processes.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)
    workers = []
    for device_id in range(gpu_count):
        worker = ctx.Process(target=run,
                             args=(opt, device_id, error_queue),
                             daemon=True)
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d  " % worker.pid)
        error_handler.add_child(worker.pid)
    for worker in workers:
        worker.join()
예제 #7
0
def _ckpt_step(filename):
    """Return the step number encoded in a checkpoint filename, or None.

    The step is the text between the last ``_`` and the trailing ``.pt``
    (``model_step_1000.pt`` -> 1000). Names that do not end in ``.pt`` or
    whose step part is not an integer (e.g. ``vocab.pt``) return None instead
    of raising ``ValueError``.
    """
    if not filename.endswith('.pt'):
        return None
    step_str = filename[:-len('.pt')].rpartition('_')[2]
    try:
        return int(step_str)
    except ValueError:
        return None


def _find_latest_checkpoint(root, prefix=''):
    """Return the path of the highest-step ``.pt`` file under ``root``.

    Only files whose name starts with ``prefix`` and whose step can be parsed
    by ``_ckpt_step`` are considered, and a step must be greater than 0 to be
    selected. Returns None when no file qualifies.
    """
    latest_step, latest_ckpt = 0, None
    for subdir, _dirs, filenames in os.walk(root):
        for filename in sorted(filenames):
            if not filename.startswith(prefix):
                continue
            step = _ckpt_step(filename)
            if step is not None and step > latest_step:
                latest_ckpt = os.path.join(subdir, filename)
                latest_step = step
    return latest_ckpt


def train(opt):
    """Validate options, resume from the newest checkpoint if any, and train.

    Checkpoints are searched first under ``opt.exp_dir`` and then (if none
    is found) in the directory of ``opt.save_model``, restricted to files
    starting with its basename. The newest one becomes ``opt.train_from``
    and ``opt.reset_optim`` is set to ``'none'``. Training then runs in
    ``opt.world_size`` spawned processes fed by one batch producer, or
    in-process on a single GPU / CPU.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    #   and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    if not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    # Scan previous checkpoints to resume training. Files whose step cannot
    # be parsed are skipped instead of aborting training.
    latest_ckpt = _find_latest_checkpoint(opt.exp_dir)
    # if not saved in the exp folder, check opt.save_model
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = opt.save_model[opt.save_model.rfind(os.path.sep) + 1:]
        latest_ckpt = _find_latest_checkpoint(save_model_dir,
                                              prefix=model_prefix)
    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" %
                    latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(vocab,
                                   opt.model_type,
                                   dynamic_dict=opt.copy_attn)
    else:
        # No vocab to load up front (this also covers
        # encoder_type == 'pretrained'); news fields are rebuilt below.
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False,
                                    dtype=torch.long,
                                    postprocessing=make_tgt,
                                    sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base,
                                            fields,
                                            opt,
                                            multi=True)
        else:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
예제 #8
0
def _load_pretrained_part(model, state_dict, submodule_name):
    """Load ``state_dict`` non-strictly into ``model.<submodule_name>``.

    ``load_state_dict(..., strict=False)`` silently ignores keys that do not
    match. A state dict whose keys are relative to a submodule (presumably
    e.g. ``"0.weight"`` for a separately-saved generator) therefore loads
    nothing when applied to the whole model. If ``model`` has the named
    submodule and no key already carries the ``<submodule_name>.`` prefix,
    the dict is loaded into that submodule. Otherwise it is loaded into the
    whole model, which was the previous behaviour.
    """
    submodule = getattr(model, submodule_name, None)
    prefix = submodule_name + '.'
    if submodule is not None and not any(
            key.startswith(prefix) for key in state_dict):
        submodule.load_state_dict(state_dict, strict=False)
    else:
        model.load_state_dict(state_dict, strict=False)


def main(opt, device_id):
    """Build model, optimizer, saver and trainer for one device, then train.

    When ``opt.train_from`` is set, model options and the src/tgt
    vocabularies come from that checkpoint, grafted onto the fields loaded
    from ``opt.data + '.vocab.pt'``. ``opt.pretrain_from`` optionally loads
    pretrained weights non-strictly on top of the freshly built model.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)

        # Keep the dataset's fields but take the src/tgt vocabularies from
        # the checkpoint.
        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        # The generator (and optional domain classifier) are stored under
        # their own keys; load them into the matching submodule if possible
        # so strict=False does not silently drop every parameter.
        _load_pretrained_part(model, check['generator'], 'generator')
        if 'dom_classifier' in check:
            _load_pretrained_part(model, check['dom_classifier'],
                                  'dom_classifier')

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    # A translator is only built when the encoder-side domain classifier is
    # disabled.
    translator = None
    if opt.domain_cls_enc == False:  # noqa: E712 (keep original equality semantics)
        translator = train_build_translator(opt,
                                            model,
                                            model_opt,
                                            fields,
                                            report_score=True)

    trainer = build_trainer(translator,
                            opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #9
0
def main(opt, device_id):
    """Train on several datasets described by a JSON manifest.

    ``opt.data`` is a JSON file whose ``"vocab"`` entry is a vocab path
    (used unless resuming from ``opt.train_from``). Every other key that
    starts with ``"valid"`` names a validation dataset; only the last such
    key is kept. Every remaining key names a training dataset, paired with
    the key split on ``-``. ``valid_iter`` stays None when the manifest
    lists no validation data.
    """
    import json

    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    with open(opt.data) as json_file:
        data = json.load(json_file)
    vocab_path = data["vocab"]

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(vocab_path)

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    train_iters = []
    # Previously unbound (UnboundLocalError) when no "valid*" key existed.
    valid_iter = None
    for key, value in data.items():
        if key == "vocab":
            continue
        if key.startswith("valid"):
            valid_iter = (key.split("valid-")[1].split("-"),
                          build_dataset_iter(value,
                                             fields,
                                             opt,
                                             is_train=False))
        else:
            train_iters.append(
                (key.split("-"), build_dataset_iter(value, fields, opt)))

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)

    model.critic = critic()
    model.critic.to(model.device)

    if model.decoder2 is not None:
        model.critic2 = critic()
        model.critic2.to(model.device)
    else:
        model.critic2 = None

    model.critic3 = None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps,
                  smooth=opt.smooth)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #10
0
def main(opt):
    """Validate options, build training iterators and launch training.

    In crosslingual mode (``opt.crosslingual`` in ``{'old', 'lm'}``) an
    auxiliary vocab/fields set is loaded and a multi-task crosslingual
    iterator is built from ``opt.eat_formats``. Otherwise one or more
    ``train*`` shards are used. Training runs in ``opt.world_size`` spawned
    processes fed by a batch producer, or in-process on one GPU / CPU.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    def get_fields(vocab):
        if old_style_vocab(vocab):
            return load_old_vocab(vocab,
                                  opt.model_type,
                                  dynamic_dict=opt.copy_attn)
        else:
            return vocab

    fields = get_fields(vocab)
    aux_fields = None
    if opt.crosslingual:
        aux_fields = get_fields(aux_vocab)

    if opt.crosslingual:
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[1]),
                # Was ``opt.eat_format[2]`` (typo): AttributeError at runtime.
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', fields, 'data', EatLMMonoTask, 'lm',
                 opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data', EatLMCrosslingualTask,
                 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE Only pass train_iter in my crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {
            'main': fields,
            'crosslingual': aux_fields
        } if opt.crosslingual else None
        single_main(opt,
                    device_id,
                    train_iter=train_iter,
                    passed_fields=passed_fields)
예제 #11
0
def train(opt):
    """Validate options, prepare data/transforms and launch training.

    ``_init_train`` supplies the checkpoint, fields and transform classes,
    which are bound into ``single_main`` via ``functools.partial``. With
    ``opt.world_size > 1``, each GPU rank gets a consumer process plus its
    own batch-producer process (feeding a per-rank queue with a strided
    iterator). Producers are terminated once every consumer has finished.
    Otherwise training runs in this process on ``opt.gpu_ranks[0]`` (one
    GPU) or on CPU (device -1).
    """
    init_logger(opt.log_file)
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Datasets and transformations (both dicts).
    checkpoint, fields, transforms_cls = _init_train(opt)
    train_process = partial(single_main,
                            fields=fields,
                            transforms_cls=transforms_cls,
                            checkpoint=checkpoint)

    gpu_count = len(opt.gpu_ranks)

    if opt.world_size <= 1:
        # TODO make possible for custom GPU id. Also replace assert at
        # utils/parse.py line 275
        device = opt.gpu_ranks[0] if gpu_count == 1 else -1
        train_process(opt, device_id=device)
        return

    ctx = torch.multiprocessing.get_context('spawn')
    semaphore = ctx.Semaphore(opt.world_size * opt.queue_size)
    # Create a thread to listen for errors in the child processes.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # One consumer (training) process per GPU rank, each with its own queue.
    queues = []
    consumer_procs = []
    for device_id in range(gpu_count):
        rank_queue = ctx.Queue(opt.queue_size)
        queues.append(rank_queue)
        consumer_proc = ctx.Process(target=consumer,
                                    args=(train_process, opt, device_id,
                                          error_queue, rank_queue, semaphore),
                                    daemon=True)
        consumer_procs.append(consumer_proc)
        consumer_proc.start()
        logger.info(" Starting process pid: %d  " % consumer_proc.pid)
        error_handler.add_child(consumer_proc.pid)

    # This does not work if we merge with the first loop, not sure why
    producers = []
    for device_id, rank_queue in enumerate(queues):
        # Each producer reads every gpu_count-th example starting at its rank.
        train_iter = _build_train_iter(opt,
                                       fields,
                                       transforms_cls,
                                       stride=gpu_count,
                                       offset=device_id)
        producer = ctx.Process(target=batch_producer,
                               args=(train_iter, rank_queue, semaphore, opt),
                               daemon=True)
        producers.append(producer)
        producer.start()
        logger.info(" Starting producer process pid: {}  ".format(
            producer.pid))
        error_handler.add_child(producer.pid)

    for consumer_proc in consumer_procs:
        consumer_proc.join()
    # Once training is done, we can terminate the producers
    for producer in producers:
        producer.terminate()
예제 #12
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a model on one device, optionally alongside a teacher model.

    When ``opt.teacher_model_path`` is set, a teacher checkpoint is loaded,
    its vocab becomes ``teacher_fields``, and the trainer is built with the
    teacher's fields and model (plus an ``Emulator`` sampler when
    ``opt.word_sampling`` is set). Batches come either from dataset shards
    or, when ``batch_queue`` is given, from that queue (``semaphore`` is
    released once per batch taken).
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Teacher state stays None unless a teacher checkpoint is given.
    teacher_model_opt = None
    teacher_model_ckpt = None
    teacher_vocab = None
    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Assigned outside the if/else above: previously an old-style student
    # vocab left ``teacher_fields`` unbound and raised NameError below.
    teacher_fields = teacher_vocab

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = custom_model_saver.build_model_saver(model_opt, opt, model,
                                                       fields_opt, optim)

    # With a teacher, its fields drive both the start-of-sequence id and the
    # trainer; otherwise the student's fields do.
    trainer_fields = teacher_fields if teacher_model is not None else fields
    tgt_field = dict(trainer_fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model,
                           teacher_fields,
                           device_id,
                           max_length=50,
                           random_sampling_topk=5)
    else:
        sampler = None

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            trainer_fields,
                            optim,
                            model_saver,
                            teacher_model=teacher_model,
                            emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            # Release one semaphore slot per batch consumed from the queue.
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
예제 #13
0
def main(opt, device_id):
    """Train a single model on ``device_id``, resuming if ``opt.train_from``.

    NOTE: It's important that ``opt`` has been validated and updated at this
    point.

    Args:
        opt: parsed training options.
        device_id: device index forwarded to ``configure_process`` and
            ``build_trainer``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Only the checkpoint dict has been read at this point; it is handed
        # to build_model and Optimizer.from_opt below. Report through the
        # configured logger (not print) so the message reaches opt.log_file.
        logger.info('Checkpoint loaded from %s' % opt.train_from)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        logger.info('Converting old-style vocab to fields.')
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    # NOTE: the model is not wrapped in
    # torch.nn.parallel.DistributedDataParallel here. An earlier wrapping was
    # removed because DDP rejects multi-device modules whose parameters are
    # not located on CUDA devices.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    # Per the original author's note, build_dataset_iter() loads the data
    # using the dataset_paths stored in opt.
    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #14
0
    def train_single(self,
                     output_model_dir: Path,
                     opt,
                     device_id,
                     batch_queue=None,
                     semaphore=None):
        """Train one multi-source model, then record its metrics and best step.

        The model, saver, trainer and data iterators are built with the
        multi-source variants for ``self.config.get_src_types()``. After
        training, ``train-metrics.json`` and ``best-step.json`` are written
        into ``output_model_dir``.

        Args:
            output_model_dir: directory that receives the metric files.
            opt: parsed training options (expected to be validated/updated).
            device_id: device index forwarded to ``configure_process`` and the
                trainer.
            batch_queue: optional queue of ready batches; when given,
                ``semaphore`` must be given too and is released after every
                batch taken from the queue.
            semaphore: see ``batch_queue``.

        Raises:
            ValueError: if no entry of the training history carries a
                validation loss (``val_xent``), so no best step can be chosen.
        """
        from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
        from roosterize.ml.onmt.MultiSourceModelBuilder import MultiSourceModelBuilder
        from roosterize.ml.onmt.MultiSourceModelSaver import MultiSourceModelSaver
        from roosterize.ml.onmt.MultiSourceTrainer import MultiSourceTrainer
        from onmt.inputters.inputter import load_old_vocab, old_style_vocab
        from onmt.train_single import configure_process, _tally_parameters, _check_save_model_path
        from onmt.utils.optimizers import Optimizer
        from onmt.utils.parse import ArgumentParser

        configure_process(opt, device_id)
        assert len(opt.accum_count) == len(
            opt.accum_steps
        ), 'Number of accum_count values must match number of accum_steps'
        # Load checkpoint if we resume from a previous training.
        if opt.train_from:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            checkpoint = torch.load(opt.train_from,
                                    map_location=lambda storage, loc: storage)
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            self.logger.info('Loading vocab from checkpoint at %s.' %
                             opt.train_from)
            vocab = checkpoint['vocab']
        else:
            checkpoint = None
            model_opt = opt
            vocab = torch.load(opt.data + '.vocab.pt')
        # end if

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type,
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
        # end if

        src_types = self.config.get_src_types()

        # Report src and tgt vocab sizes, including for features
        data_keys = [f"src.{src_type}" for src_type in src_types] + ["tgt"]
        for side in data_keys:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            # end try
            for sn, sf in f_iter:
                if sf.use_vocab:
                    self.logger.info(' * %s vocab size = %d' %
                                     (sn, len(sf.vocab)))
            # end for
        # end for

        # Build model
        model = MultiSourceModelBuilder.build_model(
            src_types, model_opt, opt, fields, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        self.logger.info('encoder: %d' % enc)
        self.logger.info('decoder: %d' % dec)
        self.logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(opt)

        # Build optimizer.
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

        # Build model saver
        model_saver = MultiSourceModelSaver.build_model_saver(
            src_types, model_opt, opt, model, fields, optim)

        trainer = MultiSourceTrainer.build_trainer(src_types,
                                                   opt,
                                                   device_id,
                                                   model,
                                                   fields,
                                                   optim,
                                                   model_saver=model_saver)

        if batch_queue is None:
            if len(opt.data_ids) > 1:
                train_shards = ["train_" + train_id
                                for train_id in opt.data_ids]
                train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                    src_types, train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = MultiSourceInputter.build_dataset_iter(
                    src_types, shard_base, fields, opt)
            # end if
        else:
            assert semaphore is not None, "Using batch_queue requires semaphore as well"

            def _train_iter():
                # Hand out queued batches forever, releasing the semaphore so
                # the producer can enqueue the next one.
                while True:
                    batch = batch_queue.get()
                    semaphore.release()
                    yield batch
                # end while

            # end def

            train_iter = _train_iter()
        # end if

        valid_iter = MultiSourceInputter.build_dataset_iter(
            src_types, "valid", fields, opt, is_train=False)

        if len(opt.gpu_ranks):
            self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
        else:
            self.logger.info('Starting training on CPU, could be very slow')
        # end if
        train_steps = opt.train_steps
        if opt.single_pass and train_steps > 0:
            self.logger.warning(
                "Option single_pass is enabled, ignoring train_steps.")
            train_steps = 0
        # end if

        trainer.train(train_iter,
                      train_steps,
                      save_checkpoint_steps=opt.save_checkpoint_steps,
                      valid_iter=valid_iter,
                      valid_steps=opt.valid_steps)
        time_begin = trainer.report_manager.start_time
        time_end = time.time()

        if opt.tensorboard:
            trainer.report_manager.tensorboard_writer.close()
        # end if

        # Dump train metrics
        train_history = trainer.report_manager.get_joint_history()
        train_metrics = {
            "time_begin": time_begin,
            "time_end": time_end,
            "time": time_end - time_begin,
            "train_history": train_history,
        }
        IOUtils.dump(output_model_dir / "train-metrics.json", train_metrics,
                     IOUtils.Format.jsonNoSort)

        # Get the best step, i.e. the one with the lowest val_xent (cross
        # entropy). Entries without a validation loss are ignored.
        # NOTE(review): assumes such entries simply lack the "val_xent" key;
        # confirm against get_joint_history.
        validated = [th for th in train_history if "val_xent" in th]
        if not validated:
            raise ValueError(
                "Training history contains no validation loss (val_xent); "
                "cannot select a best step. Check that validation ran "
                "(valid_steps vs. train_steps).")
        # end if
        best_loss = min(th["val_xent"] for th in validated)
        best_step = [
            th["step"] for th in validated if th["val_xent"] == best_loss
        ][-1]  # Take the last if multiple
        IOUtils.dump(output_model_dir / "best-step.json", best_step,
                     IOUtils.Format.json)
예제 #15
0
def main(opt, device_id):
    """Train a multi-task model over several source-target language pairs.

    For each entry of ``opt.src_tgt`` (a ``"<src>-<tgt>"`` string) and the
    ``opt.data`` prefix at the same index, this builds an encoder, a
    decoder/generator, and train/valid iterators. ``build_model`` then
    assembles everything, and a single trainer trains it.

    NOTE: It's important that ``opt`` has been validated and updated at this
    point.

    Args:
        opt: parsed training options.
        device_id: device index forwarded to ``configure_process`` and
            ``build_trainer``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        # NOTE(review): this value is overwritten by the per-pair vocab that
        # is loaded from opt.data inside the loop below.
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt

    train_iters = OrderedDict()
    valid_iters = OrderedDict()

    encoders = OrderedDict()
    decoders = OrderedDict()

    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders and decoders
    firstTime = True
    weightToShare = None

    # we share the word embedding space when source lang and target lang are the same
    mapLang2Emb = {}

    # Options that may be given per language pair; update_to_local_attr picks
    # the value for the pair at the current index.
    per_pair_attrs = ('model_type', 'audio_enc_pooling', 'enc_layers',
                      'dec_layers', 'rnn_type', 'encoder_type', 'batch_size',
                      'batch_type', 'normalization')

    for index, src_tgt_lang in enumerate(opt.src_tgt):
        data_path = opt.data[index]
        local_enc_dec_opts = AttrDict(dict(vars(model_opt)))
        for attr in per_pair_attrs:
            setattr(local_enc_dec_opts, attr,
                    update_to_local_attr(getattr(model_opt, attr), index))

        # model_type is a per-pair option (localized above), so every
        # "is this a text pair?" decision must use the localized value.
        # Comparing the raw option with "text" never matches when it holds
        # one entry per pair.
        # NOTE(review): assumes update_to_local_attr returns the scalar for
        # this pair; confirm.
        local_model_type = local_enc_dec_opts.model_type
        is_text = local_model_type == "text"

        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    local_model_type,
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build the encoder/decoder for this pair.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)

        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)

        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all the encoders and decoders - preprocess with share_vocab required.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        # Reuse one embedding matrix per language. Source-side sharing only
        # applies to text inputs.
        if is_text:
            if src_lang in mapLang2Emb:
                encoder.embeddings.word_lut.weight = mapLang2Emb[src_lang]
            else:
                mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb[tgt_lang]
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        # Only text sources have a source vocabulary.
        if is_text:
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang,
                     tgt_lang)] = build_dataset_iter_fct('valid',
                                                         fields,
                                                         data_path,
                                                         local_enc_dec_opts,
                                                         is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model. `fields` is the last pair's fields.
    model = build_model(model_opt, opt, fields, encoders, decoders, generators,
                        src_vocabs, tgt_vocabs, checkpoint)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            generators,
                            tgt_vocabs,
                            model_saver=model_saver)

    # TODO: sharded training data (opt.data_ids) is not supported yet; the
    # per-pair iterators built above are used as-is.

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #16
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train the RL model on sharded corpora, or run inference if ``opt.infer``.

    The vocab is always read from ``opt.data + '.vocab.pt'``. When
    ``opt.train_from`` is set, model options and state come from that
    checkpoint.

    NOTE: It's important that ``opt`` has been validated and updated at this
    point. Process configuration, logger initialisation and the
    accum_count/accum_steps check are not performed here. ``device_id``,
    ``batch_queue`` and ``semaphore`` are currently unused.

    Args:
        opt: parsed options.
        device_id: unused.
        batch_queue: unused.
        semaphore: unused.
    """
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        # NOTE: validate_model_opts is not called on checkpoint options here.
    else:
        checkpoint = None
        model_opt = opt

    # The vocab always comes from the preprocessed data, even when resuming
    # (the old log line claimed it came from the checkpoint).
    vocab_path = opt.data + '.vocab.pt'
    logger.info('Loading vocab from %s.' % vocab_path)
    vocab = torch.load(vocab_path)

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    rl_model = build_model(model_opt, opt, fields, checkpoint)

    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(rl_model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, rl_model, optim)

    build_rltor = build_rltor_enc  # if not opt.rl_step else build_rltor_dec
    rltor = build_rltor(opt, rl_model, optim, model_saver, report_score=False)

    def _split_optional(path):
        """Shard ``path``, or yield ``None`` forever when it is unset."""
        if path is None:
            return repeat(None)
        return split_corpus(path, opt.shard_size)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = _split_optional(opt.tgt)
    if opt.infer:
        tag_src_shards = _split_optional(opt.tag_src)
        shard_pairs = zip(src_shards, tag_src_shards)
        for i, (src_shard, tag_src_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            rltor.infer(src_shard,
                        tag_src_shard,
                        batch_size=opt.batch_size,
                        batch_type=opt.batch_type)
    else:
        valid_src_shards = split_corpus(opt.valid_src, opt.shard_size)
        # Guard on valid_tgt itself (previously checked opt.tgt, which made
        # split_corpus receive None when only valid_tgt was unset).
        valid_tgt_shards = _split_optional(opt.valid_tgt)

        tag_src_shards = _split_optional(opt.tag_src)
        valid_tag_src_shards = _split_optional(opt.valid_tag_src)
        tag_tgt_shards = _split_optional(opt.tag_tgt)
        valid_tag_tgt_shards = _split_optional(opt.valid_tag_tgt)

        shard_pairs = zip(src_shards, tgt_shards, tag_src_shards,
                          tag_tgt_shards, valid_src_shards, valid_tgt_shards,
                          valid_tag_src_shards, valid_tag_tgt_shards)

        for i, (train_src_shard, train_tgt_shard, train_tag_src_shard,
                train_tag_tgt_shard, valid_src_shard, valid_tgt_shard,
                valid_tag_src_shard,
                valid_tag_tgt_shard) in enumerate(shard_pairs):
            logger.info("Learning shard %d." % i)
            rltor.train(train_src_shard,
                        train_tgt_shard,
                        train_tag_src_shard,
                        train_tag_tgt_shard,
                        valid_src_shard,
                        valid_tgt_shard,
                        valid_tag_src_shard,
                        valid_tag_tgt_shard,
                        src_dir=opt.src_dir,
                        batch_size=opt.batch_size,
                        batch_type=opt.batch_type)
예제 #17
0
def main(opt, device_id):
    """Run single-device training, resuming from ``opt.train_from`` if set.

    ``opt`` must already have been validated and updated by the argument
    parser before this is called.

    Args:
        opt: parsed training options.
        device_id: device index passed to ``configure_process`` and
            ``build_trainer``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Either start fresh from the preprocessed vocab, or restore options,
    # vocab and state from a previous run's checkpoint.
    if not opt.train_from:
        checkpoint, model_opt = None, opt
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        ckpt_path = opt.train_from
        logger.info('Loading checkpoint from %s' % ckpt_path)
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % ckpt_path)
        vocab = checkpoint['vocab']

    # Older preprocessing saved a raw vocab instead of fields; convert it.
    fields = (load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
              if old_style_vocab(vocab) else vocab)

    # Log the vocab size of every src/tgt field, feature fields included.
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for field_name, field in named_fields:
            if field.use_vocab:
                logger.info(' * %s vocab size = %d'
                            % (field_name, len(field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    for message in ('encoder: %d' % enc,
                    'decoder: %d' % dec,
                    '* number of parameters: %d' % n_params):
        logger.info(message)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks) == 0:
        logger.info('Starting training on CPU, could be very slow')
    else:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)

    # single_pass overrides any positive step budget.
    ignore_step_budget = opt.single_pass and opt.train_steps > 0
    if ignore_step_budget:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
    train_steps = 0 if ignore_step_budget else opt.train_steps

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #18
0
    def __init__(self, model_dir):
        """Load the OpenNMT model from ``model_dir`` and prepare decoding.

        Builds the model from the first checkpoint listed in the config's
        ``models`` option and a ``Translator`` for decoding. When the config
        enables online learning, it also builds an SGD optimizer and a
        trainer.

        Args:
            model_dir: directory holding the extended model.

        Raises:
            ValueError: if ``model_dir`` is not an existing directory.
        """
        from types import SimpleNamespace

        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            msg = f"{model_dir} doesn't exist"
            raise ValueError(msg)

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Get the model options
        model_path = self._opts.models[0]
        checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Extract vocabulary (older checkpoints store a raw vocab that must
        # be converted to fields)
        vocab = checkpoint["vocab"]
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab, "text", dynamic_dict=False
            )
        else:
            self._fields = vocab

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Build OpenNMT model
        self._opennmt_model = build_base_model(
            self._model_opts,
            self._fields,
            use_gpu(self._opts),
            checkpoint,
            self._opts.gpu,
        )

        # Translator: optional decoding options fall back to defaults when
        # absent (getattr instead of bare except, which also swallowed
        # KeyboardInterrupt/SystemExit).
        min_length = getattr(self._opts, "min_length", 0)
        max_length = getattr(self._opts, "max_length", 100)
        beam_size = getattr(self._opts, "beam_size", 5)
        replace_unk = getattr(self._opts, "replace_unk", 0)

        self._translator = Translator(
            self._opennmt_model,
            self._fields,
            TextDataReader(),
            TextDataReader(),
            gpu=self._opts.gpu,
            min_length=min_length,
            max_length=max_length,
            beam_size=beam_size,
            replace_unk=replace_unk,
            copy_attn=self._model_opts.copy_attn,
            global_scorer=GNMTGlobalScorer(0.0, -0.0, "none", "none"),
            seed=self.SEED,
        )

        online_learning = self._config.online_learning
        if online_learning:
            # Optim: plain SGD with no decay, built from a minimal option set.
            optimizer_opt = SimpleNamespace(
                optim="sgd",
                learning_rate=self._opts.learning_rate,
                train_from="",
                adam_beta1=0,
                adam_beta2=0,
                model_dtype="fp32",
                decay_method="none",
                start_decay_steps=100000,
                learning_rate_decay=1.0,
                decay_steps=100000,
                max_grad_norm=5,
            )
            self._optim = Optimizer.from_opt(
                self._opennmt_model, optimizer_opt, checkpoint=None
            )

            # Trainer options; gpu == -1 means CPU (no GPU ranks).
            trainer_opt = SimpleNamespace(
                lambda_coverage=0.0,
                copy_attn=False,
                label_smoothing=0.0,
                truncated_decoder=0,
                model_dtype="fp32",
                max_generator_batches=32,
                normalization="sents",
                accum_count=[1],
                accum_steps=[0],
                world_size=1,
                average_decay=0,
                average_every=1,
                dropout=0,
                dropout_steps=(0,),
                gpu_verbose_level=0,
                early_stopping=0,
                early_stopping_criteria=(None,),
                tensorboard=False,
                report_every=50,
                gpu_ranks=[self._opts.gpu] if self._opts.gpu != -1 else [],
            )

            self._trainer = build_trainer(
                trainer_opt,
                self._opts.gpu,
                self._opennmt_model,
                self._fields,
                self._optim,
            )
        else:
            self._trainer = None
예제 #19
0
def save_hidden_states(opt, args):
    """Run a trained model over the train/valid/test splits and dump states.

    For every split, the model's hidden states at each non-padding target
    position are stacked and saved to ``cache.<split>.npy``. For each example,
    in batch order, ``<example index>,<non-pad target length>`` is written to
    ``cache_idxs.<split>.csv``. Both files go to
    ``args['reporter']['results_path']``.

    Args:
        opt: OpenNMT option namespace; updated and validated in place.
        args: nested config dict. Keys read here: ``['onmt']['translate']
            ['model']``, ``['data']['test_path']``, ``['device']``,
            ``['lm']['use_eos']`` and ``['reporter']['results_path']``.
    """
    OnmtArgumentParser.update_model_opts(opt)
    OnmtArgumentParser.validate_model_opts(opt)

    # load model
    model_path = os.path.join(opt.save_model,
                              args['onmt']['translate']['model'])
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)
    model_opt = OnmtArgumentParser.ckpt_model_opts(checkpoint["opt"])
    OnmtArgumentParser.update_model_opts(model_opt)
    OnmtArgumentParser.validate_model_opts(model_opt)
    vocab = checkpoint['vocab']
    model = build_model(model_opt, opt, vocab, checkpoint)

    # Loop-invariant lookups, hoisted out of the per-split loop.
    tgt_field = vocab["tgt"].base_field
    tgt_pad_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
    save_path = args['reporter']['results_path']

    for split in ('train', 'valid', 'test'):
        cache = []  # one hidden-state row per non-pad target position
        cache_idxs = []  # (index into the split's data, sentence length)
        if split == 'test':
            test_file = args['data']['test_path']
            # Each line holds "<src>\t<tgt>"; close the handle right away.
            with open(test_file, "r", encoding="utf-8") as test_fh:
                test_data = test_fh.readlines()
            test_src, test_tgt = zip(*[line.split("\t") for line in test_data])

            src_reader = inputters.str2reader["text"].from_opt(opt)
            tgt_reader = inputters.str2reader["text"].from_opt(opt)
            src_data = {"reader": src_reader, "data": test_src, "dir": None}
            tgt_data = {"reader": tgt_reader, "data": test_tgt, "dir": None}
            _readers, _data, _dir = inputters.Dataset.config([
                ('src', src_data), ('tgt', tgt_data)
            ])

            data = inputters.Dataset(vocab,
                                     readers=_readers,
                                     data=_data,
                                     dirs=_dir,
                                     sort_key=inputters.str2sortkey["text"],
                                     filter_pred=None)

            batch_iter = inputters.OrderedIterator(dataset=data,
                                                   device=args['device'],
                                                   batch_size=64,
                                                   batch_size_fn=None,
                                                   train=False,
                                                   sort=False,
                                                   sort_within_batch=True,
                                                   shuffle=False)
        else:
            train_dataset_paths = get_dataset_paths(opt,
                                                    split,
                                                    eos=args['lm']['use_eos'])

            batch_iter = DatasetLazyIter(
                train_dataset_paths,
                vocab,  # vocab
                64,  # batch size
                None,  # "batch_fn"
                1,  # "batch_size_multiple"
                args['device'],  # device
                True,  # is train
                8192,  # pool factor
                repeat=False,
                num_batches_multiple=1,
                yield_raw_example=False)

        for batch_i, batch in tqdm(enumerate(batch_iter), desc=f"[{split}]"):
            # NOTE(review): ``>`` lets batches 0..10000 through (10001 in
            # total); confirm whether exactly 10000 was intended.
            if batch_i > 10000:
                break
            src, src_lengths = batch.src
            tgt = batch.tgt
            # run through model
            with torch.no_grad():
                hidden_states, _ = model(src,
                                         tgt,
                                         src_lengths,
                                         bptt=False,
                                         with_align=False)

            # Mask is built from tgt[1:] and squeezed on dim 2, then used to
            # select rows of hidden_states and summed per batch column.
            pad_masks = (tgt[1:] != tgt_pad_idx).squeeze(2)
            cache.extend(hidden_states[pad_masks].cpu().numpy())
            cache_idxs.extend([(
                batch.indices[i].item(),
                pad_masks[:, i].sum().item(),
            ) for i in range(pad_masks.size(1))])

        cache_array = np.vstack(cache)
        # save the cache and the cache indices
        print(save_path)
        print(cache_array.shape)
        np.save(os.path.join(save_path, f"cache.{split}.npy"), cache_array)
        with open(os.path.join(save_path, f"cache_idxs.{split}.csv"),
                  "w") as csvfile:
            csvfile.write("\n".join(
                [f"{idx},{length}" for idx, length in cache_idxs]))
def main(opt, device_id):
    """Build a model (optionally from pretrained weights) and train it.

    NOTE: It's important that ``opt`` has been validated and updated at this
    point.

    Weights may come from several sources, merged into one ``checkpoint`` dict
    that is handed to ``build_model``:
      * ``opt.train_from``: resume a training checkpoint; model options are
        read back from the checkpoint.
      * ``opt.load_uncond_from``: load vocab/weights from a checkpoint but
        keep the current ``opt`` as the model options.
      * ``opt.gpt2_params_path``: TensorFlow GPT-2 checkpoint; stored as a
        list of ``(name, array)`` pairs under ``checkpoint['gpt2_params']``.
      * ``opt.encoder_from``: pretrained encoder; its ``'model'`` entry is
        stored under ``checkpoint['enc_model']``.

    Raises:
        ValueError: if ``opt.encoder_from`` is given and its source vocab
            differs from the model's.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from or opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s', load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)

        logger.info('Loading vocab from checkpoint at %s.', load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        names = []
        arrays = []
        for name, _shape in tf.train.list_variables(opt.gpt2_params_path):
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        # Materialise the pairs: a bare ``zip`` is a one-shot iterator and
        # would be silently empty for any consumer that reads it twice.
        gpt2_params = list(zip(names, arrays))
        if checkpoint is None:
            checkpoint = {'gpt2_params': gpt2_params}
        else:
            checkpoint['gpt2_params'] = gpt2_params

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s',
                    opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                    map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical if using pretrained encoder'
            )
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    # A model_type of 'none' has no source side.
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d', sn, len(sf.vocab))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(model,
                                                           only_trainable=True)
    logger.info('encoder: %d (%d)', enc, enc_t)
    logger.info('decoder: %d (%d)', dec, dec_t)
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)', lm_dec, lm_dec_t)

    logger.info('* number of parameters: %d (%d)', n_params, n_params_t)
    _check_save_model_path(opt)

    # When not resuming, do not pass the GPT-2 weight dict on to the
    # optimizer; it is only meant for model construction.
    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s', opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #21
0
def main(opt):
    """Check and normalise ``opt`` in place, then launch training.

    Runs the train-option validation, the model-option update and the
    model-option validation from ``ArgumentParser`` (in that order) before
    delegating to ``runTrain``.
    """
    prepare_steps = (
        ArgumentParser.validate_train_opts,
        ArgumentParser.update_model_opts,
        ArgumentParser.validate_model_opts,
    )
    for step in prepare_steps:
        step(opt)

    runTrain(opt)
예제 #22
0
    def train_impl(
        self,
        train_processed_data_dir: Path,
        val_processed_data_dir: Path,
        output_model_dir: Path,
    ) -> NoReturn:
        """Preprocess the data, configure OpenNMT options and train a model.

        Calls ``self.preprocess`` and then, with the working directory set to
        ``self.open_nmt_path``, mirrors OpenNMT's ``train.main``:
          * builds an option namespace from ``self.config``;
          * loads the vocab/fields;
          * builds the training iterator;
          * trains via ``self.train_single``, either in-process (one GPU or
            CPU) or in one spawned process per GPU when
            ``opt.world_size > 1``.

        Args:
            train_processed_data_dir: forwarded to ``self.preprocess``.
            val_processed_data_dir: forwarded to ``self.preprocess``.
            output_model_dir: preprocessed data is read from
                ``<output_model_dir>/processed-data`` and checkpoints are
                saved with prefix ``<output_model_dir>/models/ckpt``.

        NOTE(review): the method returns normally, so ``-> None`` would be
        the accurate annotation; ``NoReturn`` is kept unchanged here.
        """
        self.preprocess(train_processed_data_dir, val_processed_data_dir,
                        output_model_dir)

        from train import _get_parser as train_get_parser
        from train import ErrorHandler, batch_producer
        from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
        from onmt.inputters.inputter import old_style_vocab, load_old_vocab
        import onmt.utils.distributed
        from onmt.utils.parse import ArgumentParser

        with IOUtils.cd(self.open_nmt_path):
            parser = train_get_parser()
            # Pass an argv-style list: a single joined string only works if
            # the parser whitespace-splits it, which breaks paths with spaces.
            opt = parser.parse_args([
                "-data", f"{output_model_dir}/processed-data",
                "-save_model", f"{output_model_dir}/models/ckpt",
            ])
            opt.gpu_ranks = [0]
            opt.early_stopping = self.config.early_stopping_threshold
            opt.report_every = 200
            opt.valid_steps = 200
            opt.save_checkpoint_steps = 200
            opt.keep_checkpoint_max = self.config.ckpt_keep_max

            opt.optim = "adam"
            opt.learning_rate = self.config.learning_rate
            opt.max_grad_norm = self.config.max_grad_norm
            opt.batch_size = self.config.batch_size

            opt.encoder_type = self.config.encoder
            opt.decoder_type = self.config.decoder
            opt.dropout = [self.config.dropout]
            opt.src_word_vec_size = self.config.dim_embed
            opt.tgt_word_vec_size = self.config.dim_embed
            opt.layers = self.config.rnn_num_layers
            opt.enc_rnn_size = self.config.dim_encoder_hidden
            opt.dec_rnn_size = self.config.dim_decoder_hidden
            opt.num_srcs = len(self.config.get_src_types())
            if self.config.use_attn:
                opt.global_attention = "general"
            else:
                opt.global_attention = "none"
            # end if
            if self.config.use_copy:
                opt.copy_attn = True
                opt.copy_attn_type = "general"
            # end if

            # train.main
            ArgumentParser.validate_train_opts(opt)
            ArgumentParser.update_model_opts(opt)
            ArgumentParser.validate_model_opts(opt)

            # Load checkpoint if we resume from a previous training.
            if opt.train_from:
                self.logger.info('Loading checkpoint from %s', opt.train_from)
                checkpoint = torch.load(
                    opt.train_from, map_location=lambda storage, loc: storage)
                self.logger.info('Loading vocab from checkpoint at %s.',
                                 opt.train_from)
                vocab = checkpoint['vocab']
            else:
                vocab = torch.load(opt.data + '.vocab.pt')
            # end if

            # check for code where vocab is saved instead of fields
            # (in the future this will be done in a smarter way)
            if old_style_vocab(vocab):
                fields = load_old_vocab(vocab,
                                        opt.model_type,
                                        dynamic_dict=opt.copy_attn)
            else:
                fields = vocab
            # end if

            if len(opt.data_ids) > 1:
                train_shards = ["train_" + train_id
                                for train_id in opt.data_ids]
                train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                    self.config.get_src_types(), train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = MultiSourceInputter.build_dataset_iter(
                    self.config.get_src_types(), shard_base, fields, opt)
            # end if

            nb_gpu = len(opt.gpu_ranks)

            if opt.world_size > 1:
                queues = []
                mp = torch.multiprocessing.get_context('spawn')
                semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
                # Create a thread to listen for errors in the child processes.
                error_queue = mp.SimpleQueue()
                error_handler = ErrorHandler(error_queue)
                # Train with multiprocessing.
                procs = []
                for device_id in range(nb_gpu):
                    q = mp.Queue(opt.queue_size)
                    queues += [q]

                    # NOTE(review): with the 'spawn' start method the process
                    # target must be picklable; a function defined inside
                    # this method is not, so this branch likely fails to
                    # start. Also, train_single is called here as
                    # (opt, device_id, queue, semaphore) but below as
                    # (output_model_dir, opt, device_id) -- confirm which
                    # signature is correct.
                    def run(opt, device_id, error_queue, batch_queue,
                            semaphore):
                        """ run process """
                        try:
                            gpu_rank = onmt.utils.distributed.multi_init(
                                opt, device_id)
                            if gpu_rank != opt.gpu_ranks[device_id]:
                                raise AssertionError(
                                    "An error occurred in Distributed initialization"
                                )
                            self.train_single(opt, device_id, batch_queue,
                                              semaphore)
                        except KeyboardInterrupt:
                            pass  # killed by parent, do nothing
                        except Exception:
                            # propagate exception to parent process, keeping original traceback
                            import traceback
                            error_queue.put((opt.gpu_ranks[device_id],
                                             traceback.format_exc()))
                        # end try

                    # end def

                    procs.append(
                        mp.Process(target=run,
                                   args=(opt, device_id, error_queue, q,
                                         semaphore),
                                   daemon=True))
                    procs[device_id].start()
                    self.logger.info(" Starting process pid: %d  ",
                                     procs[device_id].pid)
                    error_handler.add_child(procs[device_id].pid)
                # end for
                producer = mp.Process(target=batch_producer,
                                      args=(
                                          train_iter,
                                          queues,
                                          semaphore,
                                          opt,
                                      ),
                                      daemon=True)
                producer.start()
                error_handler.add_child(producer.pid)

                for p in procs:
                    p.join()
                producer.terminate()

            elif nb_gpu == 1:  # case 1 GPU only
                self.train_single(output_model_dir, opt, 0)
            else:  # case only CPU
                self.train_single(output_model_dir, opt, -1)
            # end if
        # end with
        return
예제 #23
0
def train(opt):
    """Validate options, load vocab/fields, build the data iterator and train.

    The vocab is taken from the ``opt.train_from`` checkpoint when that
    checkpoint contains one, otherwise from ``<opt.data>.vocab.pt``. Training
    runs in one spawned process per GPU when ``opt.world_size > 1``, otherwise
    in-process on GPU 0 or on CPU.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s', opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.',
                        opt.train_from)
            vocab = checkpoint['vocab']
        else:
            # Older/partial checkpoints may lack a vocab; fall back to the
            # preprocessed one instead of claiming it came from the checkpoint.
            logger.info('Checkpoint has no vocab; loading %s.vocab.pt',
                        opt.data)
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  ", procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
예제 #24
0
def train(opt):
    """Validate options, auto-resume from the newest checkpoint, then train.

    ``opt.train_from`` must not be set by the caller. It is filled in with the
    newest file matching ``<opt.save_model>*.pt`` whose name ends in
    ``_<number>.pt``. If ``opt.gpu_ranks`` is empty it defaults to
    ``range(opt.world_size)``. Training runs in one spawned process per GPU
    when ``opt.world_size > 1``, otherwise in-process on GPU 0 or on CPU.

    Raises:
        Exception: if ``opt.train_from`` was set manually; if the save
            directory is non-empty but holds no matching checkpoint; or if
            matching ``.pt`` files exist but none has a numeric suffix.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if opt.train_from:
        raise Exception(
            'train_from will be set automatically to the latest model, you should not set it manually'
        )

    # set gpu ranks automatically if not specified
    if not opt.gpu_ranks:
        opt.gpu_ranks = list(range(opt.world_size))

    # Set train_from to latest checkpoint if it exists
    file_list = glob.glob(opt.save_model + '*.pt')
    # dirname is '' when save_model has no directory part, and the directory
    # may not exist yet on a first run; neither case should crash listdir.
    save_dir = os.path.dirname(opt.save_model) or '.'
    if not file_list and os.path.isdir(save_dir) and os.listdir(save_dir):
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if file_list:
        numbered = []
        for path in file_list:
            # "<prefix>_<step>.pt" -> "<step>"; skip e.g. "<prefix>_best.pt".
            suffix = os.path.basename(path)[:-len('.pt')].rsplit('_', 1)[-1]
            if suffix.isdigit():
                numbered.append((int(suffix), path))
        if not numbered:
            raise Exception(
                'no checkpoint with a numeric "_<step>.pt" suffix found for '
                + opt.save_model)
        # Use the matched path itself instead of rebuilding it from the step
        # number, so names like "<prefix>_step_<N>.pt" resolve correctly.
        _, opt.train_from = max(numbered)
        logger.info('Resuming from latest checkpoint %s', opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s', opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.', opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  ", procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
예제 #25
0
def main(opt):
    """Validate options, load fields (with keyphrase fix-ups) and train.

    For ``opt.model_type == "keyphrase"``:
      * ``sep_indices`` is dropped for the "one2one"/"multiple" target types
        and added otherwise (if missing);
      * ``src_ex_vocab`` is added if missing.

    Training then runs in one spawned process per GPU when
    ``opt.world_size > 1``, otherwise in-process on GPU 0 or on CPU.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s', opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.', opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ("one2one", "multiple"):
            # pop() instead of del: vocabularies built without a
            # sep_indices field would otherwise raise KeyError here.
            fields.pop('sep_indices', None)
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  ", procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:   # case only CPU
        single_main(opt, -1)
예제 #26
0
def _build_word_to_lemma(vocab, lemma_align_path):
    """Return a tensor mapping every source-vocab index to a lemma index.

    ``lemma_align_path`` is read as bytes, one whitespace-separated
    ``<word> <lemma>`` pair per line (decoded as UTF-8); blank lines are
    skipped. Source words not listed map to the lemma index of ``'unk'``.
    """
    src_stoi = vocab['src'].base_field.vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi
    with open(lemma_align_path, 'rb') as align_file:
        lemma_aligns = align_file.readlines()
    w2l = {}
    for pair in lemma_aligns:
        pair = pair.strip().split()
        if not pair:
            continue
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    unk_lemma = lemma_stoi['unk']
    w2l[src_stoi['unk']] = unk_lemma
    # Fall back to the *lemma* unk index. The previous lookup
    # w2l[lemma_stoi['unk']] used a lemma index as a source-index key.
    n_src = len(vocab['src'].base_field.vocab.itos)
    return torch.tensor([w2l.get(index, unk_lemma) for index in range(n_src)])


def main(opt, device_id):
    """Train a model that also consumes a topic matrix and word->lemma map.

    NOTE: It's important that ``opt`` has been validated and updated at this
    point.

    Loads the vocab (from ``opt.train_from`` or ``<opt.data>.vocab.pt``),
    builds the source-word -> lemma index map from ``model_opt.lemma_align``,
    and loads ``opt.topic_matrix`` (cast to half precision when
    ``opt.model_dtype == 'fp16'``). Both are passed to ``trainer.train``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s', opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.', opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    word_to_lemma = _build_word_to_lemma(vocab, model_opt.lemma_align)

    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d', sn, len(sf.vocab))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d', enc)
    logger.info('decoder: %d', dec)
    logger.info('* number of parameters: %d', n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s', opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(topic_matrix,
                  word_to_lemma,
                  train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
예제 #27
0
def main(opt, device_id):
    """Train (or resume training) a model on the dataset at ``opt.data``.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    Args:
        opt: validated training options.
        device_id: device index forwarded to ``configure_process`` and
            ``build_trainer``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    else:
        checkpoint = None
        model_opt = opt
    # NOTE(review): the vocab is always loaded from ``opt.data`` (also when
    # resuming), despite the log message above -- confirm this is intended.
    vocab = torch.load(opt.data + '/vocab.pt')

    # Restore the best validation loss / test score from the checkpoint
    # unless the user asked to reset them.
    if opt.train_from and opt.reset_optim not in {'score', 'all'}:
        valid_loss = checkpoint['valid_loss']
        test_score = checkpoint['test_score']
    else:
        valid_loss = 100000
        test_score = -1

    fields = vocab
    # Report the size of every entry in the vocab.
    for idx in fields.keys():
        logger.info(' * %s feat size = %d' % (idx, len(fields[idx])))

    # Detect device
    gpu = use_gpu(opt)
    if gpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Build model: these model types go through ``build_model``; every
    # other model type goes through ``build_multi_model``.
    if opt.model_type in {'single', 'nmt', 'tts', 'asr', 'vocoder'}:
        model = build_model(model_opt, opt, fields, device, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
    else:
        model = build_multi_model(model_opt, opt, fields, device, checkpoint)

    #_check_save_model_path(opt)
    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    #if opt.adversarial:
    optim_ad, loss_ad = adversarial(opt.adversarial, model, model_opt, vocab)
    trainer = build_trainer(fields,
                            opt,
                            device_id,
                            model,
                            optim,
                            model_saver=model_saver,
                            optim_ad=optim_ad,
                            loss_ad=loss_ad)

    train_set, valid_set, test_set = load_dataset(opt.data)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_set,
                  opt.batch_size,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_set=valid_set,
                  valid_batch_size=min(opt.valid_batch_size, opt.batch_size),
                  valid_steps=opt.valid_steps,
                  test_set=test_set,
                  valid_loss=valid_loss,
                  test_score=test_score)

    # The tensorboard writer can be None when tensorboard logging is not
    # enabled (other entry points check for this); closing it
    # unconditionally would then raise AttributeError.
    tensorboard_writer = trainer.report_manager.tensorboard_writer
    if tensorboard_writer is not None:
        tensorboard_writer.close()
예제 #28
0
def main(opt, device_id):
    """Train a model, optionally on a BERT knowledge-distillation dataset.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    Args:
        opt: training options. When ``opt.local_rank != -1`` the process
            joins an NCCL process group and ``opt.local_rank`` replaces
            ``device_id``.
        device_id: device index; ``-1`` selects the CPU when not running
            with ``local_rank``.
    """
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
    elif device_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda", device_id)
    # Only local rank 0 (or a run with local_rank == -1) keeps logging on.
    if opt.local_rank > 0:
        logger.disabled = True
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        src_vocab = vocab['src'].fields[0][1].vocab.stoi
        tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi
        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db,
                                      opt.bert_dump,
                                      src_vocab,
                                      tgt_vocab,
                                      max_len=150,
                                      k=opt.kd_topk)
        BUCKET_SIZE = 8192
        # A single (non-distributed) token-bucket sampler is always used.
        # The previous `if True or ...` made the DistributedTokenBucketSampler
        # branch unreachable; per the original note, distributed sharding
        # "seems like it's handled in training loop" -- TODO confirm.
        train_sampler = TokenBucketSampler(train_dataset.keys,
                                           BUCKET_SIZE,
                                           opt.batch_size,
                                           batch_multiple=1)
        train_loader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
예제 #29
0
    def __init__(self, model_dir):
        """Load a checkpointed model and build a translator and trainer.

        Args:
            model_dir: directory handed to ``ExtendedModel``; its config's
                ``opts.models[0]`` names the checkpoint that is loaded.

        Raises:
            ValueError: if ``model_dir`` is not an existing directory.
        """

        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            msg = f"{model_dir} doesn't exist"
            raise ValueError(msg)

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Get the model options stored in the checkpoint, then update and
        # validate them.
        model_path = self._opts.models[0]
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Extract vocabulary; old-style vocabs are converted to fields.
        vocab = checkpoint['vocab']
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab,
                self._opts.data_type,
                dynamic_dict=self._model_opts.copy_attn)
        else:
            self._fields = vocab

        # Build model
        self._model = build_base_model(self._model_opts, self._fields,
                                       use_gpu(self._opts), checkpoint,
                                       self._opts.gpu)

        if self._opts.fp32:
            self._model.float()

        # Translator
        scorer = GNMTGlobalScorer.from_opt(self._opts)

        self.translator = OnmtxTranslator.from_opt(
            self._model,
            self._fields,
            self._opts,
            self._model_opts,
            global_scorer=scorer,
            out_file=None,
            report_score=False,
            logger=None,
        )

        # Create trainer (the optimizer state is restored from the same
        # checkpoint).
        self._optim = Optimizer.from_opt(self._model,
                                         self._opts,
                                         checkpoint=checkpoint)

        device_id = -1  # TODO Handle GPU
        self.trainer = build_trainer(self._opts, device_id, self._model,
                                     self._fields, self._optim)
예제 #30
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Entry point for a single training process.

    The caller must have validated and updated ``opt`` beforehand.

    Args:
        opt: training options.
        device_id: device index handed to ``configure_process`` and the
            trainer.
        batch_queue: optional queue of ready-made batches. When given,
            ``semaphore`` is required and is released after every batch
            taken from the queue.
        semaphore: companion of ``batch_queue``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    checkpoint = None
    model_opt = opt
    if not opt.train_from:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        # Resume: model options and vocab are taken from the checkpoint.
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Vocabularies saved by older code must be converted into fields.
    fields = (load_old_vocab(vocab, opt.model_type,
                             dynamic_dict=opt.copy_attn)
              if old_style_vocab(vocab) else vocab)

    # Log vocab sizes for src/tgt, including any feature sub-fields.
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for name, sub_field in named_fields:
            if not sub_field.use_vocab:
                continue
            logger.info(' * %s vocab size = %d'
                        % (name, len(sub_field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    total, enc_params, dec_params, frozen = _tally_parameters(model)
    logger.info('encoder: %d' % enc_params)
    logger.info('decoder: %d' % dec_params)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % frozen)
    logger.info('* number of parameters: %d' % total)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # Pick the source of training batches.
    if batch_queue is not None:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _queued_batches():
            while True:
                queued = batch_queue.get()
                semaphore.release()
                yield queued

        train_iter = _queued_batches()
    elif len(opt.data_ids) > 1:
        shard_names = ["train_" + data_id for data_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(shard_names, fields, opt)
    else:
        first_id = opt.data_ids[0]
        shard_base = "train" if first_id is None else "train_" + first_id
        train_iter = build_dataset_iter(shard_base, fields, opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    # With single_pass, a positive train_steps is ignored (reset to 0).
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    writer = trainer.report_manager.tensorboard_writer
    if writer is not None:
        writer.close()
예제 #31
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a model (keyphrase/news variant) described by ``opt``.

    NOTE: It's important that ``opt`` has been validated and updated
    at this point.

    Args:
        opt: training options.
        device_id: device index forwarded to ``configure_process`` and
            ``build_trainer``.
        batch_queue: optional queue of ready-made batches; requires
            ``semaphore``, which is released after each batch is taken.
        semaphore: companion of ``batch_queue``.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # save training settings
    # NOTE(review): the config copy is gated on ``opt.log_file`` rather than
    # on ``opt.config``/``opt.exp_dir`` -- confirm this is intended.
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        else:
            # No vocab file (including the 'pretrained' encoder case, which
            # previously had its own identical branch).
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    # Keyphrase models only keep ``sep_indices`` for target types other
    # than one2one/multiple, and always need ``src_ex_vocab``.
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            # NOTE(review): unlike the multi-dataset paths above, the
            # tokenizer is not passed here -- confirm this is intended.
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
예제 #32
0
def main(opt, device_id):
    """Run a full training session described by ``opt``.

    The caller is responsible for having validated and updated ``opt``.

    Args:
        opt: training options.
        device_id: device index for ``configure_process`` and the trainer.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    checkpoint = None
    model_opt = opt
    if not opt.train_from:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        # Resume: model options and vocab are taken from the checkpoint.
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Vocabularies saved by older code must be converted into fields.
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Log vocab sizes for src/tgt, including any feature sub-fields.
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named = iter(side_field)
        except TypeError:
            named = [(side, side_field)]
        for name, sub in named:
            if sub.use_vocab:
                logger.info(' * %s vocab size = %d' % (name, len(sub.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    total, enc_params, dec_params = _tally_parameters(model)
    logger.info('encoder: %d' % enc_params)
    logger.info('decoder: %d' % dec_params)
    logger.info('* number of parameters: %d' % total)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    # With single_pass, a positive train_steps is ignored (reset to 0).
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def validate(opt, device_id=0):
    """Evaluate the checkpoint at ``opt.train_from`` on the validation set.

    Loads the checkpoint, rebuilds the model, runs it over the ``"valid"``
    dataset without gradients and prints the aggregated statistics.

    Args:
        opt: options; ``opt.train_from`` must name the checkpoint to load.
        device_id: device index passed to ``configure_process``.

    Raises:
        ValueError: if ``opt.train_from`` is not set. Previously this case
            fell through to an UnboundLocalError on ``fields``.
    """
    if not opt.train_from:
        raise ValueError('validate() requires opt.train_from to point to a '
                         'checkpoint')
    configure_process(opt, device_id)

    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)
    model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    vocab = checkpoint['vocab']

    # Vocabularies saved by older code must be converted into fields.
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(model,
                                                    tgt_field,
                                                    opt,
                                                    train=False)

    model.eval()

    with torch.no_grad():
        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            # ``batch.src`` is either a (tokens, lengths) tuple or just the
            # tokens.
            if isinstance(batch.src, tuple):
                src, src_lengths = batch.src
            else:
                src, src_lengths = batch.src, None
            tgt = batch.tgt

            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)

            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

    print('n words:  %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())