예제 #1
0
def main(opt):
    """Validate the options and launch training on the configured devices.

    With ``opt.world_size > 1`` one spawned process per entry of
    ``opt.gpu_ranks`` runs ``run(opt, device_id, error_queue)`` and this
    process waits for all of them.  Otherwise ``single_main`` runs here, on
    GPU 0 when exactly one rank is configured and on the CPU (``-1``) if not.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    gpu_count = len(opt.gpu_ranks)

    if opt.world_size <= 1:
        # Single process: GPU 0 for exactly one rank, otherwise CPU.
        single_main(opt, 0 if gpu_count == 1 else -1)
        return

    ctx = torch.multiprocessing.get_context('spawn')
    # Children push their failures onto this queue; the handler listens on it.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    workers = []
    for rank in range(gpu_count):
        worker = ctx.Process(target=run,
                             args=(opt, rank, error_queue),
                             daemon=True)
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d  " % worker.pid)
        error_handler.add_child(worker.pid)

    for worker in workers:
        worker.join()
예제 #2
0
def main(opt):
    """Validate and update ``opt`` in place, then delegate to ``runTrain``."""
    parser_cls = ArgumentParser
    parser_cls.validate_train_opts(opt)
    parser_cls.update_model_opts(opt)
    parser_cls.validate_model_opts(opt)
    runTrain(opt)
예제 #3
0
def main(opt):
    """Validate the options and train in-process on GPU 0 or the CPU.

    ``opt.gpu > -1`` selects device 0; any other value selects the CPU (-1).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    device_id = 0 if opt.gpu > -1 else -1
    single_main(opt, device_id)
예제 #4
0
def main(opt):
    """Rebuild the model stored in the ``opt.train_from`` checkpoint.

    The model options and the vocab are both read from the checkpoint, which
    is loaded once.  Old-style vocabs are converted to fields before use.
    When ``opt.use_segments`` is set, the target-vocab index of ``'.'`` is
    stored on ``opt.segment_token_idx``; otherwise it is ``None``.

    Returns:
        The model returned by ``build_model``.

    Raises:
        ValueError: if ``opt.train_from`` is empty.  There is then no
            checkpoint to read the model options from.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if not opt.train_from:
        raise ValueError('opt.train_from must name a checkpoint: the model '
                         'options and vocab are read from it.')

    # Load the checkpoint once; it supplies both the model options and vocab.
    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)
    model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)

    logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    vocab = checkpoint['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Read from the converted fields so old-style vocabs work as well.
    segment_token_idx = None
    if opt.use_segments:
        segment_token_idx = fields['tgt'].base_field.vocab.stoi['.']
    opt.segment_token_idx = segment_token_idx

    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            # A plain field is not iterable; wrap it as a single (name, field).
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Return the model rather than stopping in an interactive pdb session,
    # which would hang any non-interactive run.
    return build_model(model_opt, opt, fields, checkpoint)
예제 #5
0
def main(opt):
    """Validate the options, load the dataset at ``opt.data`` and train.

    When resuming (``opt.train_from`` set), the checkpoint is also loaded and
    its vocab read.  The dataset is loaded in both cases, so ``single_main``
    always receives it.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        # NOTE(review): the checkpoint vocab is read but not forwarded to
        # single_main; confirm whether it should be.
        vocab = checkpoint['vocab']

    # Previously this was only assigned in the non-resume branch, so resuming
    # raised NameError on the call below.
    data = torch.load(opt.data)

    single_main(opt, opt.gpu, data)
예제 #6
0
def train(opt):
    """Validate options, build the training iterator and launch training.

    The vocab is taken from the ``opt.train_from`` checkpoint when that
    checkpoint contains one, otherwise from ``opt.data + '.vocab.pt'``.

    With ``opt.world_size > 1``, one ``run`` process per GPU rank is spawned,
    each fed by its own queue, plus a single ``batch_producer`` process over
    all queues.  Otherwise ``single_main`` runs in-process on GPU 0 or on the
    CPU (-1).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Only log "from checkpoint" when the vocab really comes from it;
        # otherwise fall back to the preprocessed vocab file.
        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.' %
                        opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
def train(opt):
    """Validate training options, build the training iterator and train.

    Only the training options are validated here.  The vocab comes from the
    ``opt.train_from`` checkpoint or from ``opt.data + '.vocab.pt'``, and
    ``patch_fields`` is applied to the resulting fields.

    With ``opt.world_size > 1``, one ``run`` process per GPU rank and a
    single ``batch_producer`` process are spawned.  Otherwise ``single_main``
    runs in-process on GPU 0 or on the CPU (-1).
    """
    ArgumentParser.validate_train_opts(opt)

    set_random_seed(opt.seed, False)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # These were referenced below but never created, so the multi-process
        # path raised NameError.  Children report errors through the queue.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:
        single_main(opt, 0)
    else:
        single_main(opt, -1)
예제 #8
0
def train(opt):
    """Validate options and launch training with per-device batch producers.

    ``_init_train(opt)`` supplies the checkpoint, fields and transform
    classes, which are bound into ``single_main`` with ``partial``
    (``train_process``).

    With ``opt.world_size > 1``, for every GPU rank this spawns one
    ``consumer`` process and one ``batch_producer`` process that share that
    rank's queue.  Otherwise ``train_process`` runs in-process on
    ``opt.gpu_ranks[0]`` or on the CPU (-1).
    """
    init_logger(opt.log_file)
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    checkpoint, fields, transforms_cls = _init_train(
        opt)  # Datasets and transformations (Both dicts)
    # Bind everything except opt/device_id, so the same callable serves
    # the single-process path and each multi-process consumer.
    train_process = partial(single_main,
                            fields=fields,
                            transforms_cls=transforms_cls,
                            checkpoint=checkpoint)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:

        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        # One semaphore shared by all producers/consumers, sized
        # world_size * queue_size; presumably it caps batches in flight
        # across all queues -- confirm in batch_producer/consumer.
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            # queues[device_id] is this consumer's input; the producer with
            # the same device_id below writes to it.
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=consumer,
                           args=(train_process, opt, device_id, error_queue, q,
                                 semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producers = []
        # This does not work if we merge with the first loop, not sure why
        # (so every consumer is started before any producer is created).
        for device_id in range(nb_gpu):
            # Get the iterator to generate from
            # stride=nb_gpu / offset=device_id presumably give each producer
            # a disjoint slice of the data -- confirm in _build_train_iter.
            train_iter = _build_train_iter(opt,
                                           fields,
                                           transforms_cls,
                                           stride=nb_gpu,
                                           offset=device_id)
            producer = mp.Process(target=batch_producer,
                                  args=(
                                      train_iter,
                                      queues[device_id],
                                      semaphore,
                                      opt,
                                  ),
                                  daemon=True)
            producers.append(producer)
            producers[device_id].start()
            logger.info(" Starting producer process pid: {}  ".format(
                producers[device_id].pid))
            error_handler.add_child(producers[device_id].pid)

        # Wait for every consumer to finish before stopping the producers.
        for p in procs:
            p.join()
        # Once training is done, we can terminate the producers
        for p in producers:
            p.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        # TODO make possible for custom GPU id. Also replace assert at utils/parse.py line 275
        train_process(opt, device_id=opt.gpu_ranks[0])
    else:  # case only CPU
        train_process(opt, device_id=-1)
예제 #9
0
def train(opt):
    """Validate options, auto-resume from the newest checkpoint and train.

    ``opt.train_from`` must not be set by the caller.  It is filled with the
    ``<opt.save_model>*.pt`` file carrying the highest trailing step number,
    if one exists.  ``opt.gpu_ranks`` defaults to ``range(opt.world_size)``.

    With ``opt.world_size > 1``, one ``run`` process per GPU rank plus a
    single ``batch_producer`` process are spawned.  Otherwise ``single_main``
    runs in-process on GPU 0 or on the CPU (-1).

    Raises:
        Exception: if ``opt.train_from`` was set manually, or if the
            save_model directory has files but no usable checkpoint.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if opt.train_from != '':
        raise Exception(
            'train_from will be set automatically to the latest model, you should not set it manually'
        )

    # set gpu ranks automatically if not specified
    if not opt.gpu_ranks:
        opt.gpu_ranks = list(range(opt.world_size))

    # Set train_from to the latest checkpoint if it exists.  Use the matched
    # path itself: rebuilding it as save_model + '_' + N + '.pt' misses
    # '<save_model>_step_N.pt' style names.
    latest_ckpt = _latest_checkpoint_path(opt.save_model)
    # A bare file name has dirname '' -- listdir('') fails, so use '.'.
    save_dir = os.path.dirname(opt.save_model) or '.'
    if (latest_ckpt is None and os.path.isdir(save_dir)
            and os.listdir(save_dir)):
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if latest_ckpt is not None:
        opt.train_from = latest_ckpt
        logger.info('Resuming from latest checkpoint %s' % opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)


def _latest_checkpoint_path(save_model):
    """Return the ``<save_model>*.pt`` path with the highest step, or None.

    The step is the integer between the last ``_`` and the next ``.`` of the
    path (e.g. ``ckpt_step_1200.pt`` -> 1200).  Paths that do not end in such
    an integer are ignored instead of crashing ``int()``.
    """
    best_step, best_path = None, None
    for path in glob.glob(save_model + '*.pt'):
        try:
            step = int(path.split('_')[-1].split('.')[0])
        except ValueError:
            continue
        if best_step is None or step > best_step:
            best_step, best_path = step, path
    return best_path
예제 #10
0
    def train_impl(
        self,
        train_processed_data_dir: Path,
        val_processed_data_dir: Path,
        output_model_dir: Path,
    ) -> NoReturn:
        """Preprocess the data, configure OpenNMT options and train the model.

        Mirrors OpenNMT's ``train.main``.  After ``self.preprocess``, it:

        - builds ``opt`` by parsing ``-data``/``-save_model`` under
          ``output_model_dir`` (inside ``self.open_nmt_path``);
        - overrides hyper-parameters from ``self.config``;
        - loads the vocab, from ``opt.train_from`` when set, else from
          ``<data>.vocab.pt``;
        - builds a multi-source training iterator and trains via
          ``self.train_single``.

        NOTE(review): annotated ``-> NoReturn`` but the method ends with a
        plain ``return``; ``-> None`` looks accurate.  Confirm against any
        base-class signature before changing it.
        """
        self.preprocess(train_processed_data_dir, val_processed_data_dir,
                        output_model_dir)

        # Imported lazily; these modules resolve relative to the OpenNMT
        # checkout / project and are only needed for training.
        from train import _get_parser as train_get_parser
        from train import ErrorHandler, batch_producer
        from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
        from onmt.inputters.inputter import old_style_vocab, load_old_vocab
        import onmt.utils.distributed
        from onmt.utils.parse import ArgumentParser

        with IOUtils.cd(self.open_nmt_path):
            parser = train_get_parser()
            opt = parser.parse_args(
                f" -data {output_model_dir}/processed-data"
                f" -save_model {output_model_dir}/models/ckpt")
            # A single GPU rank is forced here, so nb_gpu below is 1 and the
            # multi-process branch can spawn at most one worker.
            opt.gpu_ranks = [0]
            opt.early_stopping = self.config.early_stopping_threshold
            opt.report_every = 200
            opt.valid_steps = 200
            opt.save_checkpoint_steps = 200
            opt.keep_checkpoint_max = self.config.ckpt_keep_max

            opt.optim = "adam"
            opt.learning_rate = self.config.learning_rate
            opt.max_grad_norm = self.config.max_grad_norm
            opt.batch_size = self.config.batch_size

            opt.encoder_type = self.config.encoder
            opt.decoder_type = self.config.decoder
            opt.dropout = [self.config.dropout]
            opt.src_word_vec_size = self.config.dim_embed
            opt.tgt_word_vec_size = self.config.dim_embed
            opt.layers = self.config.rnn_num_layers
            opt.enc_rnn_size = self.config.dim_encoder_hidden
            opt.dec_rnn_size = self.config.dim_decoder_hidden
            # Extra option read by the multi-source components.
            opt.__setattr__("num_srcs", len(self.config.get_src_types()))
            if self.config.use_attn:
                opt.global_attention = "general"
            else:
                opt.global_attention = "none"
            # end if
            if self.config.use_copy:
                opt.copy_attn = True
                opt.copy_attn_type = "general"
            # end if

            # train.main
            ArgumentParser.validate_train_opts(opt)
            ArgumentParser.update_model_opts(opt)
            ArgumentParser.validate_model_opts(opt)

            # Load checkpoint if we resume from a previous training.
            if opt.train_from:
                self.logger.info('Loading checkpoint from %s' % opt.train_from)
                checkpoint = torch.load(
                    opt.train_from, map_location=lambda storage, loc: storage)
                self.logger.info('Loading vocab from checkpoint at %s.' %
                                 opt.train_from)
                vocab = checkpoint['vocab']
            else:
                vocab = torch.load(opt.data + '.vocab.pt')
            # end if

            # check for code where vocab is saved instead of fields
            # (in the future this will be done in a smarter way)
            if old_style_vocab(vocab):
                fields = load_old_vocab(vocab,
                                        opt.model_type,
                                        dynamic_dict=opt.copy_attn)
            else:
                fields = vocab
            # end if

            if len(opt.data_ids) > 1:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                # end for
                train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                    self.config.get_src_types(), train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = MultiSourceInputter.build_dataset_iter(
                    self.config.get_src_types(), shard_base, fields, opt)
            # end if

            nb_gpu = len(opt.gpu_ranks)

            if opt.world_size > 1:
                queues = []
                mp = torch.multiprocessing.get_context('spawn')
                semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
                # Create a thread to listen for errors in the child processes.
                error_queue = mp.SimpleQueue()
                error_handler = ErrorHandler(error_queue)
                # Train with multiprocessing.
                procs = []
                for device_id in range(nb_gpu):
                    q = mp.Queue(opt.queue_size)
                    queues += [q]

                    # NOTE(review): `run` is a closure defined inside this
                    # loop.  With the 'spawn' start method the Process target
                    # must be picklable, and locally defined functions
                    # generally are not, so this path probably fails at
                    # start() -- verify before relying on multi-GPU training.
                    def run(opt, device_id, error_queue, batch_queue,
                            semaphore):
                        """Worker entry: init distributed, then train on one device."""
                        try:
                            gpu_rank = onmt.utils.distributed.multi_init(
                                opt, device_id)
                            if gpu_rank != opt.gpu_ranks[device_id]:
                                raise AssertionError(
                                    "An error occurred in Distributed initialization"
                                )
                            # NOTE(review): called here as
                            # train_single(opt, device_id, batch_queue, semaphore)
                            # but as train_single(output_model_dir, opt, device)
                            # in the single-device branches below -- the call
                            # shapes disagree; confirm the real signature.
                            self.train_single(opt, device_id, batch_queue,
                                              semaphore)
                        except KeyboardInterrupt:
                            pass  # killed by parent, do nothing
                        except Exception:
                            # propagate exception to parent process, keeping original traceback
                            import traceback
                            error_queue.put((opt.gpu_ranks[device_id],
                                             traceback.format_exc()))
                        # end try

                    # end def

                    procs.append(
                        mp.Process(target=run,
                                   args=(opt, device_id, error_queue, q,
                                         semaphore),
                                   daemon=True))
                    procs[device_id].start()
                    self.logger.info(" Starting process pid: %d  " %
                                     procs[device_id].pid)
                    error_handler.add_child(procs[device_id].pid)
                # end for
                producer = mp.Process(target=batch_producer,
                                      args=(
                                          train_iter,
                                          queues,
                                          semaphore,
                                          opt,
                                      ),
                                      daemon=True)
                producer.start()
                error_handler.add_child(producer.pid)

                for p in procs:
                    p.join()
                producer.terminate()

            elif nb_gpu == 1:  # case 1 GPU only
                self.train_single(output_model_dir, opt, 0)
            else:  # case only CPU
                self.train_single(output_model_dir, opt, -1)
            # end if
        # end with
        return
예제 #11
0
def train(opt):
    """Validate options, resume from the newest checkpoint if any, and train.

    Checkpoints are searched for first under ``opt.exp_dir``.  If none is
    found there, they are searched for under the directory of
    ``opt.save_model``, restricted to files starting with its basename.
    When one is found, ``opt.train_from`` is set to it and ``opt.reset_optim``
    to ``'none'``.

    Fields come from one of three sources:

    - the checkpoint, when resuming;
    - ``opt.vocab``, when it is set and not ``'none'``;
    - otherwise ``None``.

    They are then adjusted for the keyphrase/news data types.

    With ``opt.world_size > 1``, one ``run`` process per GPU rank plus a
    single ``batch_producer`` process are spawned.  Otherwise ``single_main``
    runs in-process on GPU 0 or on the CPU (-1).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    #   and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    if not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    # Scan previous checkpoint to resume training
    _, latest_ckpt = _scan_latest_checkpoint(opt.exp_dir)
    # if not saved in the exp folder, check opt.save_model
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = os.path.basename(opt.save_model)
        _, latest_ckpt = _scan_latest_checkpoint(save_model_dir, model_prefix)
    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" %
                    latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(vocab,
                                   opt.model_type,
                                   dynamic_dict=opt.copy_attn)
    else:
        # No vocab to load (this also covers encoder_type == 'pretrained').
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False,
                                    dtype=torch.long,
                                    postprocessing=make_tgt,
                                    sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)
    # elif opt.data_type == 'keyphrase':
    #     fields = reload_keyphrase_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base,
                                            fields,
                                            opt,
                                            multi=True)
        else:
            train_shards = ["train_" + train_id for train_id in opt.data_ids]
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)


def _scan_latest_checkpoint(root, prefix=''):
    """Walk ``root`` for ``<prefix>*_<step>.pt`` files; return ``(step, path)``.

    Returns ``(0, None)`` when no file with a positive integer step is found.
    Files are visited in sorted order within each directory, and only a
    strictly larger step replaces the current best.  Names whose text between
    the last ``_`` and ``.pt`` is not an integer are skipped instead of
    raising ``ValueError``.
    """
    latest_step, latest_ckpt = 0, None
    for subdir, _dirs, filenames in os.walk(root):
        for filename in sorted(filenames):
            if not filename.endswith('.pt') or not filename.startswith(prefix):
                continue
            try:
                step = int(filename[filename.rfind('_') +
                                    1:filename.rfind('.pt')])
            except ValueError:
                continue
            if step > latest_step:
                latest_step = step
                latest_ckpt = os.path.join(subdir, filename)
    return latest_step, latest_ckpt
예제 #12
0
def main(opt):
    """Validate options, build the training iterator and launch training.

    Besides the standard single-dataset path, this variant supports a
    "crosslingual" mode (``opt.crosslingual`` is ``'old'`` or ``'lm'``) in
    which an auxiliary vocabulary/dataset (``opt.aux_train_data``) is loaded
    next to the main one and a multi-task iterator is built with
    ``build_crosslingual_dataset_iter``.

    When ``opt.world_size > 1`` one ``run`` process is spawned per GPU rank,
    plus a ``batch_producer`` process feeding them through per-rank queues.
    Otherwise training runs in-process via ``single_main`` on GPU 0 (exactly
    one GPU rank) or on CPU (device id -1).

    Args:
        opt: parsed training options (argparse-style namespace).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training; only the
    # vocab(s) stored in it are used here.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    def get_fields(vocab):
        """Convert an old-style saved vocab into fields; pass fields through."""
        if old_style_vocab(vocab):
            return load_old_vocab(vocab,
                                  opt.model_type,
                                  dynamic_dict=opt.copy_attn)
        return vocab

    fields = get_fields(vocab)
    aux_fields = get_fields(aux_vocab) if opt.crosslingual else None

    if opt.crosslingual:
        # Each entry is (split, fields, opt attribute holding the data path,
        # task class, task name, eat format).
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[1]),
                # Was `opt.eat_format[2]` (missing "s"), which raised
                # AttributeError on the 'old' crosslingual path.
                ('train', aux_fields, 'aux_train_data',
                 Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base',
                 opt.eat_formats[0]),
                ('train', fields, 'data', EatLMMonoTask, 'lm',
                 opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask,
                 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data', EatLMCrosslingualTask,
                 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        # Bounds the total number of batches in flight across all queues.
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing: one trainer process per GPU rank.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues.append(q)
            proc = mp.Process(target=run,
                              args=(opt, device_id, error_queue, q, semaphore),
                              daemon=True)
            procs.append(proc)
            proc.start()
            logger.info(" Starting process pid: %d  " % proc.pid)
            error_handler.add_child(proc.pid)
        # A single producer fills the per-rank queues from train_iter.
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        # The producer may still be blocked on a full queue once all
        # trainers have finished, so stop it explicitly.
        producer.terminate()

    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE Only pass train_iter in my crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {
            'main': fields,
            'crosslingual': aux_fields
        } if opt.crosslingual else None
        single_main(opt,
                    device_id,
                    train_iter=train_iter,
                    passed_fields=passed_fields)
예제 #13
0
def main(opt):
    """Validate options, build the training iterator and launch training.

    For ``opt.model_type == "keyphrase"`` the loaded fields are patched:
    ``sep_indices`` is dropped for the ``one2one``/``multiple`` target types
    and added (if missing) for the others, and a ``src_ex_vocab`` raw field
    is added if missing.

    When ``opt.world_size > 1`` one ``run`` process is spawned per GPU rank,
    plus a ``batch_producer`` process feeding them through per-rank queues.
    Otherwise training runs in-process via ``single_main`` on GPU 0 (exactly
    one GPU rank) or on CPU (device id -1).

    Args:
        opt: parsed training options (argparse-style namespace).
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training; only the vocab
    # stored in it is used here.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            # pop() instead of del: the saved fields may not contain
            # 'sep_indices' at all, and del would raise KeyError.
            fields.pop('sep_indices', None)
        elif 'sep_indices' not in fields:
            fields["sep_indices"] = Field(
                use_vocab=False, dtype=torch.long,
                postprocessing=make_tgt, sequential=False)
        if 'src_ex_vocab' not in fields:
            fields["src_ex_vocab"] = RawField()

    if len(opt.data_ids) > 1:
        train_shards = ["train_" + train_id for train_id in opt.data_ids]
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)
    # (A leftover debug `print(os.environ['PATH'])` was removed here: it was
    # unrelated to training and raised KeyError when PATH was unset.)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        # Bounds the total number of batches in flight across all queues.
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing: one trainer process per GPU rank.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues.append(q)
            proc = mp.Process(target=run, args=(
                opt, device_id, error_queue, q, semaphore), daemon=True)
            procs.append(proc)
            proc.start()
            logger.info(" Starting process pid: %d  " % proc.pid)
            error_handler.add_child(proc.pid)
        # A single producer fills the per-rank queues from train_iter.
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        # The producer may still be blocked on a full queue once all
        # trainers have finished, so stop it explicitly.
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:   # case only CPU
        single_main(opt, -1)