Ejemplo n.º 1
0
def main(opt):
    """Entry point: validate options, then launch training on CPU, a
    single GPU, or multiple GPUs via spawned worker processes.

    Args:
        opt: parsed training options namespace.
    """
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    n_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        # Multi-GPU: spawn one worker per device and watch the error
        # queue so a crash in any child is surfaced.
        ctx = torch.multiprocessing.get_context('spawn')
        error_queue = ctx.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        workers = []
        for device_id in range(n_gpu):
            proc = ctx.Process(target=run,
                               args=(opt, device_id, error_queue),
                               daemon=True)
            proc.start()
            logger.info(" Starting process pid: %d  " % proc.pid)
            error_handler.add_child(proc.pid)
            workers.append(proc)
        for worker in workers:
            worker.join()
    elif n_gpu == 1:
        # Exactly one GPU requested: run in-process on device 0.
        single_main(opt, 0)
    else:
        # No GPUs: run in-process on CPU (device id -1).
        single_main(opt, -1)
Ejemplo n.º 2
0
    def _save(self, step):
        """Write a checkpoint for ``step`` to ``<base_path>_step_<step>.pt``.

        Unwraps any ``nn.DataParallel`` wrappers, stores the generator
        weights separately from the rest of the model, and bundles the
        vocab (fields), model options and optimizer state.

        Args:
            step (int): current training step, used in the file name.

        Returns:
            tuple: (checkpoint dict, checkpoint file path).
        """
        real_model = (self.model.module if isinstance(
            self.model, nn.DataParallel) else self.model)
        if hasattr(self.model_opt, 'joint') and self.model_opt.joint:
            real_generator1 = (real_model.generator1.module if isinstance(
                real_model.generator1, nn.DataParallel) else
                               real_model.generator1)

            generator1_state_dict = real_generator1.state_dict()

            # Fixed: the isinstance check must inspect generator2; the
            # original tested generator1 here (copy-paste slip), so a
            # DataParallel-wrapped generator2 was never unwrapped.
            real_generator2 = (real_model.generator2.module if isinstance(
                real_model.generator2, nn.DataParallel) else
                               real_model.generator2)

            generator2_state_dict = real_generator2.state_dict()

            model_state_dict = real_model.state_dict()
            # Fixed: use 'and' so keys belonging to either generator are
            # dropped. With the original 'or' the condition was always
            # True, so generator weights were duplicated inside 'model'.
            model_state_dict = {
                k: v
                for k, v in model_state_dict.items()
                if 'generator1' not in k and 'generator2' not in k
            }
            checkpoint = {
                'model': model_state_dict,
                'generator1': generator1_state_dict,
                'generator2': generator2_state_dict,
                'vocab': self.fields,
                'opt': self.model_opt,
                'optim': self.optim,
            }
        else:
            real_generator = (real_model.generator.module if isinstance(
                real_model.generator, nn.DataParallel) else
                              real_model.generator)

            model_state_dict = real_model.state_dict()
            model_state_dict = {
                k: v
                for k, v in model_state_dict.items() if 'generator' not in k
            }
            generator_state_dict = real_generator.state_dict()
            checkpoint = {
                'model': model_state_dict,
                'generator': generator_state_dict,
                'vocab': self.fields,
                'opt': self.model_opt,
                'optim': self.optim,
            }

        logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path
Ejemplo n.º 3
0
 def validate_train_opts(cls, opt):
     """Validate training options, rejecting deprecated flags and
     incompatible combinations.

     Args:
         opt: parsed options namespace.

     Raises:
         AssertionError: if ``-epochs`` or ``-gpuid`` (both deprecated)
             is set, or truncated BPTT is combined with gradient
             accumulation (``-accum > 1``).
     """
     if opt.epochs:
         raise AssertionError(
             "-epochs is deprecated please use -train_steps.")
     if opt.truncated_decoder > 0 and opt.accum_count > 1:
         raise AssertionError("BPTT is not compatible with -accum > 1")
     if opt.gpuid:
         raise AssertionError("gpuid is deprecated \
               see world_size and gpu_ranks")
     # Warn (don't fail) when a CUDA device exists but none is requested.
     if torch.cuda.is_available() and not opt.gpu_ranks:
         logger.info("WARNING: You have a CUDA device, \
                     should run with -gpu_ranks")
Ejemplo n.º 4
0
def load_vocabulary(vocab_path, tag):
    """Load a vocabulary from the given path (one entry per line; only
    the first whitespace-separated column of each line is kept).

    :param vocab_path: path to load vocabulary from
    :param tag: tag for vocabulary (only used for logging)
    :return: list of vocabulary tokens
    :raises RuntimeError: if no file exists at ``vocab_path``
    """
    logger.info("Loading {} vocabulary from {}".format(tag, vocab_path))

    if not os.path.exists(vocab_path):
        raise RuntimeError("{} vocabulary not found at {}".format(
            tag, vocab_path))
    # No 'else' needed: the branch above always raises.
    with codecs.open(vocab_path, 'r', 'utf-8') as f:
        # Skip blank lines; keep only the token (first column).
        return [line.strip().split()[0] for line in f if line.strip()]
Ejemplo n.º 5
0
    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics to stdout.

        Args:
           step (int): current step number
           num_steps (int): total number of steps (0 disables "x/y" form)
           learning_rate (float): current learning rate
           start (float): start time of the step (epoch seconds)
        """
        elapsed = self.elapsed_time()
        if num_steps > 0:
            step_fmt = "%2d/%5d" % (step, num_steps)
        else:
            step_fmt = "%2d" % step
        # Guard against division by zero with a tiny epsilon on elapsed time.
        logger.info(
            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
             "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") %
            (step_fmt, self.accuracy(), self.ppl(), self.xent(),
             learning_rate, self.n_src_words / (elapsed + 1e-5),
             self.n_words / (elapsed + 1e-5), time.time() - start))
        sys.stdout.flush()
Ejemplo n.º 6
0
def build_model(model_opt, opt, fields, checkpoint):
    """Construct the model(s) for training.

    In joint mode a single joint model is built; in pipeline mode two
    base models are built, the second reusing the first's word-embedding
    lookup table.

    Returns:
        A single model (joint) or a ``(model, model2)`` tuple (pipeline).
    """
    logger.info('Building model...')

    if not opt.joint:
        # Pipeline: stage-2 model shares stage-1's word embeddings.
        first, emb_word_lut = build_base_model(model_opt, fields,
                                               use_gpu(opt), checkpoint)
        logger.info(first)
        second = build_base_model2(model_opt,
                                   fields,
                                   use_gpu(opt),
                                   checkpoint,
                                   prev_emb_w=emb_word_lut)
        logger.info(second)
        return first, second

    joint = build_joint_model(model_opt, fields, use_gpu(opt), checkpoint)
    logger.info(joint)
    return joint
Ejemplo n.º 7
0
def main(opt, device_id):
    """Single-process training entry point.

    Loads (or resumes from checkpoint) the vocab and model options,
    builds the model(s), optimizer(s) and saver(s), then runs the trainer.

    Args:
        opt: parsed options; must already be validated and updated.
        device_id: GPU rank to train on, or -1 for CPU.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file, from_scratch=opt.from_scratch)
    # Copy the config next to the log file for reproducibility.
    # NOTE(review): os.system('cp ...') is Unix-only and breaks on paths
    # with spaces; shutil.copy would be safer — confirm before changing.
    os.system('cp %s %s' % (opt.config, os.path.dirname(opt.log_file) + '/'))

    logger.warning(opt.description)
    logger.warning('Joint learning' if opt.joint else 'Pipeline learning')
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # The options stored in the checkpoint take precedence.
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        logger.warning('Using old style vocab')
        fields = load_old_vocab(vocab)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tmpl', 'src2', 'tgt']:
        f = fields[side]
        try:
            # Multi-field entries are iterable as (name, field) pairs.
            f_iter = iter(f)
        except TypeError:
            # Plain single field: wrap it so the loop below stays uniform.
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.

    if not opt.joint:
        # Pipeline mode: two separate models, optimizers and savers.
        _check_save_model_path(opt.save_model1)
        _check_save_model_path(opt.save_model2)

        model, model2 = build_model(model_opt, opt, fields, checkpoint)

        n_params, enc, dec = _tally_parameters(model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)

        n_params2, enc2, dec2 = _tally_parameters(model2)
        logger.info('encoder: %d' % enc2)
        logger.info('decoder: %d' % dec2)
        logger.info('* number of parameters: %d' % n_params2)

        # Build optimizer.
        optim = build_optim(model, opt, checkpoint)
        optim2 = build_optim(model2, opt, checkpoint)

        # Build model saver
        # NOTE(review): model2_saver is constructed with ``optim`` rather
        # than ``optim2`` — looks like a bug; confirm against the
        # build_model_saver signature before changing.
        model1_saver = build_model_saver(model_opt,
                                         opt,
                                         model,
                                         fields,
                                         optim,
                                         save_path=opt.save_model1)
        model2_saver = build_model_saver(model_opt,
                                         opt,
                                         model2,
                                         fields,
                                         optim,
                                         save_path=opt.save_model2)

    else:
        # Joint mode: a single model saved under the shared prefix
        # (save_model1/save_model2 are expected to differ only in a
        # two-character suffix).
        assert opt.save_model1[:-2] == opt.save_model2[:-2]
        _save_model = opt.save_model1[:-2]
        _check_save_model_path(_save_model)

        model = build_model(model_opt, opt, fields, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)

        optim = build_optim(model, opt, checkpoint)

        model1_saver = build_model_saver(model_opt,
                                         opt,
                                         model,
                                         fields,
                                         optim,
                                         save_path=_save_model)
        # No second model in joint mode.
        model2 = None
        optim2 = None
        model2_saver = None

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            model2,
                            fields,
                            optim,
                            optim2,
                            model1_saver=model1_saver,
                            model2_saver=model2_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps

    trainer.train(train_iter,
                  train_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Ejemplo n.º 8
0
    def train(self,
              train_iter,
              train_steps,
              valid_iter=None,
              valid_steps=10000):
        """The main training loop: iterate over `train_iter`, accumulating
        ``grad_accum_count`` batches per optimizer step, and possibly run
        validation on `valid_iter`.

        Args:
            train_iter: A generator that returns the next training batch.
            train_steps: Run training for this many iterations.
            valid_iter: A generator that returns the next validation batch.
            valid_steps: Run evaluation every this many iterations.

        Returns:
            The gathered statistics (for model 1).
        """
        if valid_iter is None:
            logger.info('Start training loop without validation...')
        else:
            logger.info('Start training loop and validate every %d steps...',
                        valid_steps)

        # Resume step counting from optimizer 1's state (checkpoint restart).
        step = self.optim1._step + 1
        true_batchs = []
        accum = 0
        # Separate loss-normalization totals for the two decoders
        # (tmpl -> loss1, tgt -> loss2).
        normalization1 = 0
        normalization2 = 0

        total_stats = monmt.utils.Statistics()
        report_stats = monmt.utils.Statistics()
        total_stats2 = monmt.utils.Statistics()
        report_stats2 = monmt.utils.Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        # Outer loop re-traverses train_iter (one pass per "epoch") until
        # the requested number of optimizer steps has been taken.
        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                # Shard batches across GPUs round-robin; on CPU (n_gpu == 0)
                # every batch is processed.
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                    if self.gpu_verbose_level > 1:
                        logger.info("GpuRank %d: index: %d accum: %d" %
                                    (self.gpu_rank, i, accum))

                    true_batchs.append(batch)

                    # Normalize loss by non-padding token count (skipping
                    # the BOS row via [1:]) or by batch size.
                    if self.norm_method == "tokens":
                        num_tokens = batch.tmpl[1:].ne(
                            self.train_loss1.padding_idx).sum()
                        normalization1 += num_tokens.item()
                    else:
                        normalization1 += batch.batch_size

                    if self.norm_method == "tokens":
                        num_tokens = batch.tgt[1:].ne(
                            self.train_loss2.padding_idx).sum()
                        normalization2 += num_tokens.item()
                    else:
                        normalization2 += batch.batch_size

                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.gpu_verbose_level > 0:
                            logger.info("GpuRank %d: reduce_counter: %d \
                                        n_minibatch %d" %
                                        (self.gpu_rank, reduce_counter,
                                         len(true_batchs)))
                        # Multi-GPU: sum normalization terms across workers
                        # so all ranks scale gradients identically.
                        if self.n_gpu > 1:
                            normalization1 = sum(
                                monmt.utils.distributed.all_gather_list(
                                    normalization1))
                            normalization2 = sum(
                                monmt.utils.distributed.all_gather_list(
                                    normalization2))
                        # Backprop through the accumulated batches.
                        if self.joint:
                            self._gradient_accumulation_joint(
                                true_batchs, normalization1, total_stats,
                                report_stats, normalization2, total_stats2,
                                report_stats2)
                        else:
                            self._gradient_accumulation(
                                true_batchs, normalization1, total_stats,
                                report_stats, normalization2, total_stats2,
                                report_stats2)

                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optim1.learning_rate,
                            report_stats)

                        # Joint mode has a single optimizer, so optim1's
                        # learning rate is reported for both stat streams.
                        if not self.joint:
                            report_stats2 = self._maybe_report_training(
                                step, train_steps, self.optim2.learning_rate,
                                report_stats2)
                        else:
                            report_stats2 = self._maybe_report_training(
                                step, train_steps, self.optim1.learning_rate,
                                report_stats2)

                        # Reset the accumulation window.
                        true_batchs = []
                        accum = 0
                        normalization1 = 0
                        normalization2 = 0
                        if (step % valid_steps == 0):
                            # NOTE(review): this fires even when valid_iter
                            # is None, so validate() would receive None —
                            # confirm callers always supply a valid_iter.
                            if self.gpu_verbose_level > 0:
                                logger.info('GpuRank %d: validate step %d' %
                                            (self.gpu_rank, step))

                            if self.joint:
                                valid_stats1, valid_stats2 = self.validate_joint(
                                    valid_iter)
                            else:
                                valid_stats1, valid_stats2 = self.validate(
                                    valid_iter)

                            if self.gpu_verbose_level > 0:
                                logger.info('GpuRank %d: gather valid stat \
                                            step %d' % (self.gpu_rank, step))
                            valid_stats1 = self._maybe_gather_stats(
                                valid_stats1)

                            if self.gpu_verbose_level > 0:
                                logger.info('GpuRank %d: report stat step %d' %
                                            (self.gpu_rank, step))

                            self._report_step(self.optim1.learning_rate,
                                              step,
                                              valid_stats=valid_stats1)

                            valid_stats2 = self._maybe_gather_stats(
                                valid_stats2)
                            if self.joint:
                                self._report_step(self.optim1.learning_rate,
                                                  step,
                                                  valid_stats=valid_stats2)
                            else:
                                self._report_step(self.optim2.learning_rate,
                                                  step,
                                                  valid_stats=valid_stats2)

                        # Only rank 0 writes checkpoints.
                        if self.gpu_rank == 0:
                            self._maybe_save(step)
                        step += 1
                        if step > train_steps:
                            break

            if self.gpu_verbose_level > 0:
                logger.info('GpuRank %d: we completed an epoch \
                            at step %d' % (self.gpu_rank, step))

        return total_stats
Ejemplo n.º 9
0
def build_vocab(train_dataset_files, fields, share_vocab, src_vocab_path,
                src_vocab_size, src_words_min_frequency, tgt_vocab_path,
                tgt_vocab_size, tgt_words_min_frequency):
    """Build the vocabulary for every field from the training shards.

    Args:
        train_dataset_files: a list of train dataset pt file paths.
        fields (dict): fields to build vocab for.
        share_vocab (bool): share source and target vocabulary?
        src_vocab_path (str): optional path to a src vocabulary file.
        src_vocab_size (int): size of the source vocabulary.
        src_words_min_frequency (int): the minimum frequency needed to
            include a source word in the vocabulary.
        tgt_vocab_path (str): optional path to a tgt vocabulary file.
        tgt_vocab_size (int): size of the target vocabulary.
        tgt_words_min_frequency (int): the minimum frequency needed to
            include a target word in the vocabulary.

    Returns:
        Dict of Fields with their ``.vocab`` attributes built.
    """
    # Prop src from field to get lower memory using when training with image
    counters = {k: Counter() for k in fields}

    # Load vocabulary
    if src_vocab_path:
        src_vocab = load_vocabulary(src_vocab_path, "src")
        src_vocab_size = len(src_vocab)
        logger.info('Loaded source vocab has %d tokens.' % src_vocab_size)
        for i, token in enumerate(src_vocab):
            # keep the order of tokens specified in the vocab file by
            # adding them to the counter with decreasing counting values
            counters['src'][token] = src_vocab_size - i
    else:
        src_vocab = None

    if tgt_vocab_path:
        tgt_vocab = load_vocabulary(tgt_vocab_path, "tgt")
        tgt_vocab_size = len(tgt_vocab)
        # Fixed log message: this reports the *target* vocab (was "source").
        logger.info('Loaded target vocab has %d tokens.' % tgt_vocab_size)
        for i, token in enumerate(tgt_vocab):
            counters['tgt'][token] = tgt_vocab_size - i
    else:
        tgt_vocab = None

    for i, path in enumerate(train_dataset_files):
        dataset = torch.load(path)
        logger.info(" * reloading %s." % path)
        for ex in dataset.examples:
            for k in fields:
                # Skip counting for sides whose vocab was pre-loaded above.
                has_vocab = (k == 'src' and src_vocab) or \
                            (k == 'tgt' and tgt_vocab)
                if fields[k].sequential and not has_vocab:
                    val = getattr(ex, k, None)
                    counters[k].update(val)

        # Drop the none-using from memory but keep the last
        if i < len(train_dataset_files) - 1:
            dataset.examples = None
            gc.collect()
            del dataset.examples
            gc.collect()
            del dataset
            gc.collect()

    _build_field_vocab(fields["tgt"],
                       counters["tgt"],
                       max_size=tgt_vocab_size,
                       min_freq=tgt_words_min_frequency)
    logger.info(" * tgt vocab size: %d." % len(fields["tgt"].vocab))

    # All datasets have same num of n_tgt_features,
    # getting the last one is OK.
    n_tgt_feats = sum('tgt_feat_' in k for k in fields)
    for j in range(n_tgt_feats):
        key = "tgt_feat_" + str(j)
        _build_field_vocab(fields[key], counters[key])
        logger.info(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

    _build_field_vocab(fields["src"],
                       counters["src"],
                       max_size=src_vocab_size,
                       min_freq=src_words_min_frequency)
    logger.info(" * src vocab size: %d." % len(fields["src"].vocab))

    # All datasets have same num of n_src_features,
    # getting the last one is OK.
    n_src_feats = sum('src_feat_' in k for k in fields)
    for j in range(n_src_feats):
        key = "src_feat_" + str(j)
        _build_field_vocab(fields[key], counters[key])
        logger.info(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

    # The template side shares the target vocab limits.
    _build_field_vocab(fields["tmpl"],
                       counters["tmpl"],
                       max_size=tgt_vocab_size,
                       min_freq=tgt_words_min_frequency)
    logger.info(" * tmpl vocab size: %d." % len(fields["tmpl"].vocab))

    n_tmpl_feats = sum('tmpl_feat_' in k for k in fields)
    for j in range(n_tmpl_feats):
        key = "tmpl_feat_" + str(j)
        _build_field_vocab(fields[key], counters[key])
        logger.info(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

    # The second source side shares the source vocab limits.
    _build_field_vocab(fields["src2"],
                       counters["src2"],
                       max_size=src_vocab_size,
                       min_freq=src_words_min_frequency)
    logger.info(" * src2 vocab size: %d." % len(fields["src2"].vocab))

    n_src2_feats = sum('src2_feat_' in k for k in fields)
    for j in range(n_src2_feats):
        key = "src2_feat_" + str(j)
        _build_field_vocab(fields[key], counters[key])
        logger.info(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

    if share_vocab:
        # `tgt_vocab_size` is ignored when sharing vocabularies
        logger.info(" * merging src and tgt vocab...")
        _merge_field_vocabs(fields["src"],
                            fields["tgt"],
                            fields['tmpl'],
                            fields['src2'],
                            vocab_size=src_vocab_size,
                            min_freq=src_words_min_frequency)
        logger.info(" * merged vocab size: %d." % len(fields["src"].vocab))

    return fields
Ejemplo n.º 10
0
 def log(self, *args, **kwargs):
     """Forward positional and keyword arguments to the module-level
     logger at INFO level."""
     logger.info(*args, **kwargs)