示例#1
0
def train(opt):
    """Validate options, load the fields and launch training.

    The vocab/fields are taken from the checkpoint named by ``opt.train_from``
    when it contains a ``'vocab'`` entry, and otherwise from
    ``opt.data + '.vocab.pt'``. When ``opt.world_size > 1`` one ``run``
    process is spawned per GPU rank, all fed by a single ``batch_producer``
    process through per-rank queues; otherwise ``single_main`` runs in this
    process on GPU 0 (exactly one GPU rank) or on CPU (device id -1).

    Args:
        opt: parsed training options; validated and updated in place.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Only report the vocab source once it is known: a checkpoint may
        # not carry a vocab, in which case the preprocessed file is used.
        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.' %
                        opt.train_from)
            vocab = checkpoint['vocab']
        else:
            logger.info('No vocab in checkpoint %s, loading %s.' %
                        (opt.train_from, opt.data + '.vocab.pt'))
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Several data ids -> one "train_<id>" shard base per id; otherwise a
    # single shard base, suffixed with the id only when one is given.
    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        # Shared by the producer and all consumers; each consumer releases it
        # per batch taken from its queue (see main's batch_queue path).
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        # The producer loops over the data indefinitely; stop it once every
        # trainer process has finished.
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def train(opt):
    """Validate training options, load the fields and launch training.

    Fields come from ``checkpoint['vocab']`` when resuming from
    ``opt.train_from`` and from ``opt.data + '.vocab.pt'`` otherwise, then
    ``patch_fields(opt, fields)`` is applied to them. With
    ``opt.world_size > 1`` one ``run`` process per GPU rank is spawned, all
    fed by a ``batch_producer`` process; otherwise ``single_main`` runs on
    GPU 0 (one GPU rank) or on CPU (device id -1).

    Args:
        opt: parsed training options; validated in place.
    """
    ArgumentParser.validate_train_opts(opt)

    set_random_seed(opt.seed, False)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # Older preprocessing stored a raw vocab instead of fields; convert it.
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # NOTE(review): presumably fills in fields missing from older data or
    # models - confirm against patch_fields.
    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        # These two names were previously used below without ever being
        # defined, raising NameError on every multi-process run.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:
        single_main(opt, 0)
    else:
        single_main(opt, -1)
示例#3
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Build and train a model on a single device.

    The caller is responsible for having validated and updated ``opt``.

    Args:
        opt: training options.
        device_id: GPU rank index for this process, or -1 to run on CPU.
        batch_queue: optional queue of batches filled by a producer process;
            when ``None`` the training iterator is built here from disk.
        semaphore: released once per batch pulled from ``batch_queue``;
            mandatory whenever ``batch_queue`` is supplied.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Fresh run: options and vocab come from opt / the preprocessed data.
    # Resumed run: both are restored from the checkpoint instead.
    checkpoint = None
    model_opt = opt
    if not opt.train_from:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Legacy files store a raw vocab rather than fields; convert those.
    fields = (load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
              if old_style_vocab(vocab) else vocab)

    # Log the vocab size of each src/tgt field, or of each of its sub-fields
    # when the side field is iterable (e.g. word features).
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for field_name, sub_field in named_fields:
            if sub_field.use_vocab:
                logger.info(' * %s vocab size = %d'
                            % (field_name, len(sub_field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    for message in ('encoder: %d' % enc,
                    'decoder: %d' % dec,
                    'non-trainable parameters (tgt_out_emb): %d' % nontrainable,
                    '* number of parameters: %d' % n_params):
        logger.info(message)
    _check_save_model_path(opt)

    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is not None:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _queued_batches():
            # Hand each queued batch to the trainer and free one producer slot.
            while True:
                next_batch = batch_queue.get()
                semaphore.release()
                yield next_batch

        train_iter = _queued_batches()
    elif len(opt.data_ids) > 1:
        shard_bases = ["train_" + data_id for data_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(shard_bases, fields, opt)
    else:
        first_id = opt.data_ids[0]
        shard_base = "train" if first_id is None else "train_" + first_id
        train_iter = build_dataset_iter(shard_base, fields, opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if opt.gpu_ranks:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    # single_pass overrides any positive step budget (0 = no step limit).
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    writer = trainer.report_manager.tensorboard_writer
    if writer is not None:
        writer.close()
示例#4
0
def train(opt):
    """Validate options, auto-resume from the newest checkpoint, and train.

    ``opt.train_from`` must not be given by the user. It is set to the
    ``opt.save_model*.pt`` file with the highest trailing step number, if
    there is one. Empty ``opt.gpu_ranks`` default to ``range(opt.world_size)``.
    Training then runs multi-process (``world_size > 1``), on one GPU, or on
    CPU.

    Args:
        opt: parsed training options; validated and updated in place.

    Raises:
        Exception: if ``opt.train_from`` was set explicitly, or if the
            ``save_model`` directory is non-empty but holds no numbered
            checkpoint for this ``save_model`` prefix.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if opt.train_from != '':
        raise Exception(
            'train_from will be set automatically to the latest model, you should not set it manually'
        )

    # set gpu ranks automatically if not specified
    if len(opt.gpu_ranks) == 0:
        opt.gpu_ranks = list(range(opt.world_size))

    # Map each numbered checkpoint path to its step. The matched path is kept
    # as-is instead of being rebuilt from the step, so any infix the saver
    # puts before the number (e.g. "_step_") is preserved. Files without a
    # numeric suffix (e.g. "<prefix>_best.pt") are ignored, not fatal.
    ckpt_steps = {}
    for path in glob.glob(glob.escape(opt.save_model) + '*.pt'):
        try:
            ckpt_steps[path] = int(
                os.path.basename(path).split('_')[-1].split('.')[0])
        except ValueError:
            continue

    # A save_model without a directory part lives in the current directory;
    # a directory that does not exist yet simply means a fresh run.
    save_dir = os.path.dirname(opt.save_model) or os.curdir
    if (os.path.isdir(save_dir) and len(os.listdir(save_dir)) > 0
            and not ckpt_steps):
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if ckpt_steps:
        opt.train_from = max(ckpt_steps, key=ckpt_steps.get)
        logger.info('Resuming from latest checkpoint %s' % opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
示例#5
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a student model on one device, optionally with a teacher model.

    When ``opt.teacher_model_path`` is set, a teacher is loaded from that
    checkpoint. Its vocab is then used for the trainer and for the target
    start-of-sequence id. With ``opt.word_sampling``, an ``Emulator`` built
    on the teacher is also passed to the trainer.

    Args:
        opt: training options, already validated and updated by the caller.
        device_id: GPU rank index for this process, or -1 for CPU.
        batch_queue: optional queue of batches from a producer process; when
            ``None`` the training iterator is built here.
        semaphore: released per batch taken from ``batch_queue``; required
            whenever ``batch_queue`` is given.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Teacher state stays None when no teacher is configured.
    teacher_model_opt = None
    teacher_model_ckpt = None
    teacher_vocab = None
    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Assigned on both vocab branches: previously it was only set for
    # new-style vocabs, so an old-style vocab raised NameError below.
    teacher_fields = teacher_vocab

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    teacher_model = None
    if opt.teacher_model_path:
        teacher_model = build_model(teacher_model_opt, teacher_model_opt,
                                    teacher_fields, teacher_model_ckpt)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = custom_model_saver.build_model_saver(model_opt, opt, model,
                                                       fields_opt, optim)

    # With a teacher, the trainer and the SOS id follow the teacher's vocab.
    trainer_fields = teacher_fields if teacher_model is not None else fields
    tgt_field = dict(trainer_fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model,
                           teacher_fields,
                           device_id,
                           max_length=50,
                           random_sampling_topk=5)
    else:
        sampler = None

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            trainer_fields,
                            optim,
                            model_saver,
                            teacher_model=teacher_model,
                            emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = ["train_" + train_id for train_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
示例#6
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Build and train a model on one device (GPU rank or CPU).

    When ``opt.log_file`` is set, ``opt.config`` is copied into
    ``opt.exp_dir``, and the full option set is always logged. Fields are
    sourced in one of three ways:

    * from the checkpoint when resuming (``opt.train_from``);
    * from ``opt.vocab`` when given;
    * otherwise left ``None`` and only rebuilt for news data by
      ``reload_news_fields``.

    Args:
        opt: training options, already validated and updated by the caller.
        device_id: GPU rank index for this process, or -1 for CPU.
        batch_queue: optional queue of batches from a producer process; when
            ``None`` the training iterator is built here.
        semaphore: released per batch taken from ``batch_queue``; required
            whenever ``batch_queue`` is given.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        else:
            # No prebuilt vocab (e.g. encoder_type == 'pretrained'). The
            # previous separate 'pretrained' branch did exactly the same.
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    # Guarded on `fields` (as the train-side copy is): with no vocab yet,
    # `'sep_indices' in None` would raise TypeError.
    if fields and opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = ["train_" + train_id for train_id in opt.data_ids]
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
示例#7
0
def _latest_checkpoint(root, prefix=None):
    """Return ``(path, step)`` for the highest-step ``*.pt`` file under ``root``.

    ``root`` is walked recursively. A file's step is the integer between the
    last ``'_'`` and the ``'.pt'`` suffix of its name. Names that do not parse
    that way (e.g. a ``*.vocab.pt`` file) are skipped instead of raising
    ValueError. When ``prefix`` is given, only file names starting with it
    are considered. Returns ``(None, 0)`` if no file with a step > 0 is found.
    """
    latest_ckpt, latest_step = None, 0
    for subdir, _, filenames in os.walk(root):
        for filename in sorted(filenames):
            if not filename.endswith('.pt'):
                continue
            if prefix is not None and not filename.startswith(prefix):
                continue
            try:
                step = int(filename[filename.rfind('_') + 1:
                                    filename.rfind('.pt')])
            except ValueError:
                continue
            if step > latest_step:
                latest_ckpt = os.path.join(subdir, filename)
                latest_step = step
    return latest_ckpt, latest_step


def train(opt):
    """Validate options, resume from the newest checkpoint if any, and train.

    Checkpoints are searched for first under ``opt.exp_dir`` and then, if
    none are found there, next to ``opt.save_model`` (matching its file-name
    prefix). A hit sets ``opt.train_from`` and ``opt.reset_optim = 'none'``.
    Fields come from the checkpoint, from ``opt.vocab``, or are ``None``
    (news data rebuilds them via ``reload_news_fields``).

    Args:
        opt: parsed training options; validated and updated in place.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    #   and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    os.makedirs(opt.tensorboard_log_dir, exist_ok=True)

    # Scan previous checkpoint to resume training; if none is saved in the
    # exp folder, fall back to the directory of opt.save_model.
    latest_ckpt, _ = _latest_checkpoint(opt.exp_dir)
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = os.path.basename(opt.save_model)
        latest_ckpt, _ = _latest_checkpoint(save_model_dir,
                                            prefix=model_prefix)
    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" %
                    latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(vocab,
                                   opt.model_type,
                                   dynamic_dict=opt.copy_attn)
    else:
        # No prebuilt vocab (including encoder_type == 'pretrained').
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False,
                                    dtype=torch.long,
                                    postprocessing=make_tgt,
                                    sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base,
                                            fields,
                                            opt,
                                            multi=True)
        else:
            train_shards = ["train_" + train_id for train_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
示例#8
0
def main(opt,
         device_id,
         batch_queue=None,
         semaphore=None,
         train_iter=None,
         passed_fields=None):
    """Run training in the current process on ``device_id``.

    NOTE: ``opt`` must already have been validated and updated by the
    caller; only the model options of a resumed checkpoint are validated
    here.

    Args:
        opt: parsed training options.
        device_id: device index handed to ``configure_process`` and
            ``build_trainer`` (presumably ``-1`` selects CPU — TODO confirm).
        batch_queue: optional queue of batches produced by another process;
            requires ``semaphore``.
        semaphore: released once for every batch taken from ``batch_queue``.
        train_iter: optional ready-made training iterator; when given it is
            used as-is and neither ``batch_queue`` nor the on-disk shards are
            consulted.
        passed_fields: optional ``{'main': ..., 'crosslingual': ...}``
            mapping; when given it replaces the fields derived from the vocab.
    """
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Fresh run by default; resuming swaps in the checkpoint's vocab and,
    # when requested, the model options stored in the checkpoint.
    checkpoint = None
    model_opt = opt
    if not opt.train_from:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # Deserialize every tensor onto CPU storage.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    # Explicitly passed fields take priority; otherwise derive them from the
    # vocab, converting the old vocab-only format when needed.
    aux_fields = None
    if passed_fields is None:
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type,
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
    else:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']

    # Log vocab sizes; a side is either a single field or an iterable of
    # (name, field) pairs (e.g. when features are present).
    for side in ('src', 'tgt'):
        side_field = fields[side]
        try:
            named_fields = iter(side_field)
        except TypeError:
            named_fields = [(side, side_field)]
        for field_name, field in named_fields:
            if field.use_vocab:
                logger.info(' * %s vocab size = %d'
                            % (field_name, len(field.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint,
                        aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # With ``almt_only`` the optimizer sees only the alignment mapping layer.
    if opt.almt_only:
        optimized = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
    else:
        optimized = model
    optim = Optimizer.from_opt(optimized, opt, checkpoint=checkpoint)

    model_saver = build_model_saver(model_opt, opt, model, fields, optim,
                                    aux_fields=aux_fields)
    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver, aux_fields=aux_fields)

    # Training batches: an explicit iterator wins; otherwise pull from the
    # producer queue when one is given, else read the shards from disk.
    if train_iter is None:
        if batch_queue is not None:
            assert semaphore is not None, \
                "Using batch_queue requires semaphore as well"

            def _queued_batches():
                while True:
                    batch = batch_queue.get()
                    # Presumably frees a slot for the producer process.
                    semaphore.release()
                    yield batch

            train_iter = _queued_batches()
        elif len(opt.data_ids) > 1:
            shard_bases = ["train_" + data_id for data_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(shard_bases, fields, opt)
        else:
            first_id = opt.data_ids[0]
            shard_base = "train" if first_id is None else "train_" + first_id
            train_iter = build_dataset_iter(shard_base, fields, opt)

    cl_valid_iter = None
    if not opt.crosslingual:
        valid_iters = [
            build_dataset_iter("valid", fields, opt, is_train=False)
        ]
    else:
        main_valid_iter = build_dataset_iter("valid",
                                             fields,
                                             opt,
                                             is_train=False,
                                             task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE The crosslingual dev set was prepared under the 'train'
            # prefix in `eat_prepare.sh`, so it is read back as 'train'.
            cl_valid_iter = build_dataset_iter(
                'train',
                fields,
                opt,
                is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # Validation data for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid',
                                            fields,
                                            opt,
                                            is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [main_valid_iter, aux_valid_iter]

    if len(opt.gpu_ranks) == 0:
        logger.info('Starting training on CPU, could be very slow')
    else:
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)

    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters,
                  valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
    def train_single(self, output_model_dir: Path, opt, device_id, batch_queue=None, semaphore=None):
        """Train one model in this process, then record metrics and the best step.

        Follows the flow of OpenNMT's single-process training (option check,
        vocab/fields, model, optimizer, saver, trainer, iterators), but builds the
        trainer through ``CustomTrainer``. After training it writes
        ``train-metrics.json`` (start/end time, duration, joint history) and
        ``best-step.json`` (the step with the lowest ``val_xent``; among ties the
        last matching history entry wins) into ``output_model_dir``.

        :param output_model_dir: directory receiving the metric dumps.
        :param opt: parsed OpenNMT training options.
        :param device_id: device index passed to ``configure_process`` and the trainer.
        :param batch_queue: optional queue of batches from a producer process; requires ``semaphore``.
        :param semaphore: released once per batch taken from ``batch_queue``.
        """
        from roosterize.ml.onmt.CustomTrainer import CustomTrainer
        from onmt.inputters.inputter import build_dataset_iter, load_old_vocab, old_style_vocab, build_dataset_iter_multiple
        from onmt.model_builder import build_model
        from onmt.train_single import configure_process, _tally_parameters, _check_save_model_path
        from onmt.models import build_model_saver
        from onmt.utils.optimizers import Optimizer
        from onmt.utils.parse import ArgumentParser

        configure_process(opt, device_id)
        assert len(opt.accum_count) == len(opt.accum_steps), 'Number of accum_count values must match number of accum_steps'

        # Defaults for a fresh run; resuming replaces them from the checkpoint.
        checkpoint = None
        model_opt = opt
        if not opt.train_from:
            vocab = torch.load(opt.data + '.vocab.pt')
        else:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage)
            # Model hyper-parameters come from the checkpoint, not from ``opt``.
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
            vocab = checkpoint['vocab']
        # end if

        # Old-style vocab files must first be converted into fields.
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn) if old_style_vocab(vocab) else vocab

        # Log vocab sizes; each side is either one field or an iterable of (name, field) pairs.
        for side in ('src', 'tgt'):
            side_field = fields[side]
            try:
                named_fields = iter(side_field)
            except TypeError:
                named_fields = [(side, side_field)]
            # end try
            for field_name, field in named_fields:
                if not field.use_vocab:
                    continue
                # end if
                self.logger.info(' * %s vocab size = %d' % (field_name, len(field.vocab)))
            # end for
        # end for

        model = build_model(model_opt, opt, fields, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        self.logger.info('encoder: %d' % enc)
        self.logger.info('decoder: %d' % dec)
        self.logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(opt)

        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
        model_saver = build_model_saver(model_opt, opt, model, fields, optim)
        trainer = CustomTrainer.build_trainer(opt, device_id, model, fields, optim, model_saver=model_saver)

        # Training batches come from the producer queue when given, else from disk shards.
        if batch_queue is not None:
            assert semaphore is not None, "Using batch_queue requires semaphore as well"

            def _batches_from_queue():
                while True:
                    queued_batch = batch_queue.get()
                    semaphore.release()
                    yield queued_batch
                # end while
            # end def

            train_iter = _batches_from_queue()
        elif len(opt.data_ids) > 1:
            shard_bases = ["train_" + data_id for data_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(shard_bases, fields, opt)
        else:
            first_id = opt.data_ids[0]
            shard_base = "train" if first_id is None else "train_" + first_id
            train_iter = build_dataset_iter(shard_base, fields, opt)
        # end if

        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

        if len(opt.gpu_ranks) == 0:
            self.logger.info('Starting training on CPU, could be very slow')
        else:
            self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
        # end if

        train_steps = opt.train_steps
        if opt.single_pass and train_steps > 0:
            self.logger.warning("Option single_pass is enabled, ignoring train_steps.")
            train_steps = 0
        # end if

        trainer.train(
            train_iter,
            train_steps,
            save_checkpoint_steps=opt.save_checkpoint_steps,
            valid_iter=valid_iter,
            valid_steps=opt.valid_steps,
        )
        time_begin = trainer.report_manager.start_time
        time_end = time.time()

        if opt.tensorboard:
            trainer.report_manager.tensorboard_writer.close()
        # end if

        # Dump train metrics (key order is kept in the JSON output).
        train_history = trainer.report_manager.get_joint_history()
        IOUtils.dump(
            output_model_dir/"train-metrics.json",
            {
                "time_begin": time_begin,
                "time_end": time_end,
                "time": time_end - time_begin,
                "train_history": train_history,
            },
            IOUtils.Format.jsonNoSort,
        )

        # Best step = lowest val_xent (cross entropy); the last matching entry wins ties.
        lowest_xent = min(entry["val_xent"] for entry in train_history)
        steps_at_lowest = [entry["step"] for entry in train_history if entry["val_xent"] == lowest_xent]
        IOUtils.dump(output_model_dir/"best-step.json", steps_at_lowest[-1], IOUtils.Format.json)
        return
    def train_impl(
            self,
            train_processed_data_dir: Path,
            val_processed_data_dir: Path,
            output_model_dir: Path,
    ) -> None:
        """Configure OpenNMT training options from ``self.config`` and train.

        Options are parsed with ``-data`` / ``-save_model`` pointing inside
        ``output_model_dir`` and are then overridden from ``self.config``.
        Dispatch mirrors OpenNMT's ``train.main``:
        - multi-process training when ``opt.world_size > 1``;
        - otherwise one ``train_single`` run on device 0, or on CPU (``-1``)
          when no GPU ranks are configured.

        NOTE(review): ``train_processed_data_dir`` and ``val_processed_data_dir``
        are not read here. Data is taken from ``{output_model_dir}/processed-data``.
        """
        from train import _get_parser as train_get_parser
        from train import ErrorHandler, batch_producer
        from onmt.inputters.inputter import old_style_vocab, load_old_vocab, build_dataset_iter, build_dataset_iter_multiple
        import onmt.utils.distributed
        from onmt.utils.parse import ArgumentParser

        with IOUtils.cd(self.open_nmt_path):
            parser = train_get_parser()
            # Pass argv as a list (not a single string) so that paths containing
            # whitespace are not split into separate arguments.
            opt = parser.parse_args([
                "-data", f"{output_model_dir}/processed-data",
                "-save_model", f"{output_model_dir}/models/ckpt",
            ])
            opt.gpu_ranks = [0]
            opt.early_stopping = self.config.early_stopping_threshold
            opt.report_every = 200
            opt.valid_steps = 200
            opt.save_checkpoint_steps = 200
            opt.keep_checkpoint_max = self.config.ckpt_keep_max

            opt.optim = "adam"
            opt.learning_rate = self.config.learning_rate
            opt.max_grad_norm = self.config.max_grad_norm
            opt.batch_size = self.config.batch_size

            opt.encoder_type = self.config.encoder
            opt.decoder_type = self.config.decoder
            opt.dropout = [self.config.dropout]
            opt.src_word_vec_size = self.config.dim_embed
            opt.tgt_word_vec_size = self.config.dim_embed
            opt.layers = self.config.rnn_num_layers
            opt.enc_rnn_size = self.config.dim_encoder_hidden
            opt.dec_rnn_size = self.config.dim_decoder_hidden
            if self.config.use_attn:
                opt.global_attention = "general"
            else:
                opt.global_attention = "none"
            # end if
            if self.config.use_copy:
                opt.copy_attn = True
                opt.copy_attn_type = "general"
            # end if

            # train.main, one gpu case
            ArgumentParser.validate_train_opts(opt)
            ArgumentParser.update_model_opts(opt)
            ArgumentParser.validate_model_opts(opt)

            # Load checkpoint if we resume from a previous training.
            if opt.train_from:
                self.logger.info('Loading checkpoint from %s' % opt.train_from)
                checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage)
                self.logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
                vocab = checkpoint['vocab']
            else:
                vocab = torch.load(opt.data + '.vocab.pt')
            # end if

            # check for code where vocab is saved instead of fields
            # (in the future this will be done in a smarter way)
            if old_style_vocab(vocab):
                fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
            else:
                fields = vocab
            # end if

            if len(opt.data_ids) > 1:
                train_shards = ["train_" + train_id for train_id in opt.data_ids]
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = build_dataset_iter(shard_base, fields, opt)
            # end if

            nb_gpu = len(opt.gpu_ranks)

            if opt.world_size > 1:
                def run(opt, device_id, error_queue, batch_queue, semaphore):
                    """Worker: join the distributed group, then train on ``device_id``."""
                    try:
                        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
                        if gpu_rank != opt.gpu_ranks[device_id]:
                            raise AssertionError("An error occurred in Distributed initialization")
                        # train_single's first positional parameter is output_model_dir.
                        self.train_single(output_model_dir, opt, device_id, batch_queue, semaphore)
                    except KeyboardInterrupt:
                        pass  # killed by parent, do nothing
                    except Exception:
                        # propagate exception to parent process, keeping original traceback
                        import traceback
                        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
                    # end try
                # end def

                # NOTE(review): with the 'spawn' start method the Process target is
                # pickled. A nested function such as ``run`` is normally not
                # picklable, so this path likely needs ``run`` moved to module level
                # (or turned into a picklable bound method). Confirm before relying on it.
                queues = []
                mp = torch.multiprocessing.get_context('spawn')
                semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
                # Create a thread to listen for errors in the child processes.
                error_queue = mp.SimpleQueue()
                error_handler = ErrorHandler(error_queue)
                # Train with multiprocessing.
                procs = []
                for device_id in range(nb_gpu):
                    q = mp.Queue(opt.queue_size)
                    queues += [q]
                    procs.append(mp.Process(target=run, args=(opt, device_id, error_queue, q, semaphore), daemon=True))
                    procs[device_id].start()
                    self.logger.info(" Starting process pid: %d  " % procs[device_id].pid)
                    error_handler.add_child(procs[device_id].pid)
                # end for
                producer = mp.Process(target=batch_producer, args=(train_iter, queues, semaphore, opt,), daemon=True)
                producer.start()
                error_handler.add_child(producer.pid)

                for p in procs:
                    p.join()
                # end for
                producer.terminate()

            elif nb_gpu == 1:  # case 1 GPU only
                self.train_single(output_model_dir, opt, 0)
            else:  # case only CPU
                self.train_single(output_model_dir, opt, -1)
            # end if
        # end with
        return
示例#11
0
def main(opt):
    """Validate ``opt``, load vocab/fields and launch training.

    When ``opt.crosslingual`` is set, an auxiliary vocab is loaded as well.
    A crosslingual training iterator is then built from a list of spec
    tuples. These appear to be (corpus type, fields, data attribute, task
    class, name, eat format) — TODO confirm against
    ``build_crosslingual_dataset_iter``.

    Training runs across several processes when ``opt.world_size > 1``.
    Otherwise it runs in ``single_main`` on device 0 (one GPU) or ``-1``
    (no GPU); only in crosslingual mode is the prebuilt iterator and both
    field sets handed over.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    def get_fields(vocab):
        """Return the fields for ``vocab``, converting an old-style vocab."""
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            return load_old_vocab(vocab,
                                  opt.model_type,
                                  dynamic_dict=opt.copy_attn)
        return vocab

    fields = get_fields(vocab)
    aux_fields = get_fields(aux_vocab) if opt.crosslingual else None

    if opt.crosslingual:
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[1]),
                # The option is ``eat_formats`` (plural), as for the entries above.
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', fields, 'data', EatLMMonoTask, 'lm',
                 opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data', EatLMCrosslingualTask,
                 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE Only pass train_iter in my crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {
            'main': fields,
            'crosslingual': aux_fields
        } if opt.crosslingual else None
        single_main(opt,
                    device_id,
                    train_iter=train_iter,
                    passed_fields=passed_fields)
示例#12
0
def main(opt):
    """Validate ``opt``, load vocab/fields and launch training.

    Handling for ``opt.model_type == "keyphrase"``:
    - ``sep_indices`` is dropped when ``opt.tgt_type`` is ``"one2one"`` or
      ``"multiple"``, and added otherwise when missing;
    - ``src_ex_vocab`` is added when missing.

    Training runs across several processes when ``opt.world_size > 1``.
    Otherwise it runs in ``single_main`` on device 0 (one GPU) or ``-1``
    (no GPU).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            # These target types carry no separators. Drop the field if the
            # vocab has it; tolerate its absence instead of raising KeyError.
            fields.pop('sep_indices', None)
        elif 'sep_indices' not in fields:
            sep_indices = Field(
                use_vocab=False, dtype=torch.long,
                postprocessing=make_tgt, sequential=False)
            fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:   # case only CPU
        single_main(opt, -1)