def __init__(self,
                 dataset,
                 hparams,
                 collate_fn=None,
                 sampler=None,
                 is_test=False,
                 is_train=False):
        """Build a batched index reader over ``dataset``.

        @param : dataset — the dataset to iterate over (must support ``len``)
        @param : hparams — config providing ``sort_pool_size``, ``shuffle``,
                 ``batch_size`` and ``use_data_distributed``
        @param : collate_fn — callable used later to collate raw examples
        @param : sampler — optional pre-built sampler; a default is chosen
                 from ``hparams.shuffle`` / ``is_test`` when omitted
        @param : is_test — disables shuffling and sorted pooling
        @param : is_train — enables distributed batch sharding
        """
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sort_pool_size = hparams.sort_pool_size

        # Pick a default sampler when the caller supplied none:
        # random order for shuffled (non-test) runs, sequential otherwise.
        if sampler is None:
            shuffle_indices = hparams.shuffle and not is_test
            sampler_cls = RandomSampler if shuffle_indices else SequentialSampler
            sampler = sampler_cls(dataset)

        # Outside of testing, optionally wrap with a sorted pool sampler.
        if not is_test and self.sort_pool_size > 0:
            sampler = SortedSampler(sampler, self.sort_pool_size)

        def index_reader():
            # Re-yield sampler indices so paddle.batch can group them.
            yield from sampler

        self.reader = paddle.batch(index_reader,
                                   batch_size=hparams.batch_size,
                                   drop_last=False)
        self.num_batches = math.ceil(len(dataset) / hparams.batch_size)

        # Under multi-GPU data-parallel training, shard batches across
        # ranks and scale the per-rank batch count down accordingly.
        if hparams.use_data_distributed and parallel.Env(
        ).nranks > 1 and is_train:
            self.reader = fluid.contrib.reader.distributed_batch_reader(
                self.reader)
            self.num_batches //= fluid.dygraph.parallel.Env().nranks
# Example #2
def main():
    """Command-line entry point.

    Parses hyper-parameters, builds the tokenizer/generator and the
    requested data loaders, then runs training, evaluation and/or
    inference inside a Paddle dygraph guard.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--do_train",
                        type=str2bool,
                        default=False,
                        help="Whether to run training.")
    parser.add_argument("--do_test",
                        type=str2bool,
                        default=False,
                        help="Whether to run evaluation on the test dataset.")
    parser.add_argument("--do_infer",
                        type=str2bool,
                        default=False,
                        help="Whether to run inference on the test dataset.")
    parser.add_argument("--num_infer_batches",
                        type=int,
                        default=None,
                        help="The number of batches need to infer.\n"
                        "Stay 'None': infer on entire test dataset.")
    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        help="Loading hparams setting from file(.json format).")
    BPETextField.add_cmdline_argument(parser)
    Dataset.add_cmdline_argument(parser)
    Trainer.add_cmdline_argument(parser)
    ModelBase.add_cmdline_argument(parser)
    Generator.add_cmdline_argument(parser)

    hparams = parse_args(parser)

    # An hparams JSON file, when given and present, overrides CLI settings.
    if hparams.hparams_file and os.path.exists(hparams.hparams_file):
        print(f"Loading hparams from {hparams.hparams_file} ...")
        hparams.load(hparams.hparams_file)
        print(f"Loaded hparams from {hparams.hparams_file}")

    print(json.dumps(hparams, indent=2))

    # Persist the resolved hparams next to checkpoints.
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(hparams.save_dir, exist_ok=True)
    hparams.save(os.path.join(hparams.save_dir, "hparams.json"))

    bpe = BPETextField(hparams.BPETextField)
    hparams.Model.num_token_embeddings = bpe.vocab_size

    generator = Generator.create(hparams.Generator, bpe=bpe)

    COLLATE_FN = {
        "multi": bpe.collate_fn_multi_turn,
        "multi_knowledge": bpe.collate_fn_multi_turn_with_knowledge
    }
    collate_fn = COLLATE_FN[hparams.data_type]

    def tokenized_file(stem):
        """Return the tokenized data file for raw file ``dial.<stem>``,
        raising if it has not been generated yet."""
        raw_file = os.path.join(hparams.data_dir, f"dial.{stem}")
        data_file = raw_file + f".{hparams.tokenizer_type}.jsonl"
        if not os.path.exists(data_file):
            # Explicit raise instead of assert: asserts vanish under `-O`.
            raise FileNotFoundError(f"{data_file} does not exist")
        return data_file

    # Loading datasets
    if hparams.do_train:
        train_dataset = LazyDataset(tokenized_file("train"))
        train_loader = DataLoader(train_dataset,
                                  hparams.Trainer,
                                  collate_fn=collate_fn,
                                  is_train=True)
        valid_dataset = LazyDataset(tokenized_file("valid"))
        valid_loader = DataLoader(valid_dataset,
                                  hparams.Trainer,
                                  collate_fn=collate_fn)

    if hparams.do_infer or hparams.do_test:
        test_dataset = LazyDataset(tokenized_file("test"))
        test_loader = DataLoader(test_dataset,
                                 hparams.Trainer,
                                 collate_fn=collate_fn,
                                 is_test=hparams.do_infer)

    def to_tensor(array):
        """Convert a numpy array into a Paddle dygraph Variable."""
        return fluid.dygraph.to_variable(array)

    if hparams.use_data_distributed:
        place = fluid.CUDAPlace(parallel.Env().dev_id)
    else:
        place = fluid.CUDAPlace(0)
        # NOTE: switching to fluid.CPUPlace() does not help —
        # CPU execution is not supported here. (2020-04-23)

    with fluid.dygraph.guard(place):
        # Construct Model
        model = ModelBase.create("Model", hparams, generator=generator)

        # Construct Trainer
        trainer = Trainer(model, to_tensor, hparams.Trainer)

        if hparams.do_train:
            # Training process
            for _ in range(hparams.num_epochs):
                trainer.train_epoch(train_loader, valid_loader)

        if hparams.do_test:
            # Validation process
            trainer.evaluate(test_loader, need_save=False)

        if hparams.do_infer:
            # Inference process
            def split(xs, sep, pad):
                """Split id list ``xs`` on ``sep``, dropping ``pad`` ids;
                empty segments are skipped."""
                out, o = [], []
                for x in xs:
                    if x == pad:
                        continue
                    if x != sep:
                        o.append(x)
                    else:
                        if len(o) > 0:
                            out.append(list(o))
                            o = []
                if len(o) > 0:
                    out.append(list(o))
                assert (all(len(o) > 0 for o in out))
                return out

            def parse_context(batch):
                """ Parse context. """
                return bpe.denumericalize([
                    split(xs, bpe.eos_id, bpe.pad_id) for xs in batch.tolist()
                ])

            def parse_text(batch):
                """ Parse text. """
                return bpe.denumericalize(batch.tolist())

            infer_parse_dict = {
                "src": parse_context,
                "tgt": parse_text,
                "preds": parse_text
            }
            trainer.infer(test_loader,
                          infer_parse_dict,
                          num_batches=hparams.num_infer_batches)
    def evaluate(self, data_iter, need_save=True):
        """
        Evaluation interface

        Runs the model over ``data_iter`` in inference mode, aggregates
        batch-level and token-level metrics, logs one summary line, and
        (when ``need_save``) updates the best model / checkpoint / summary.

        @param : data_iter
        @type : DataLoader

        @param : need_save
        @type : bool
        """
        # Under data-parallel execution only local rank 0 saves, so the
        # same checkpoint is not written once per GPU.
        if isinstance(self.model, parallel.DataParallel):
            need_save = need_save and parallel.Env().local_rank == 0

        # Evaluation
        begin_time = time.time()
        batch_metrics_tracker = MetricsTracker()
        token_metrics_tracker = MetricsTracker()
        for batch, batch_size in data_iter:
            # Convert every field of the batch to a tensor while keeping
            # the batch's original mapping type.
            batch = type(batch)(map(lambda kv: (kv[0], self.to_tensor(kv[1])),
                                    batch.items()))
            metrics = self.model(batch, is_training=False)
            token_num = int(metrics.pop("token_num"))
            # Metrics whose key contains "token" are averaged per token;
            # all others are averaged per example (weighted by batch_size).
            batch_metrics = {
                k: v
                for k, v in metrics.items() if "token" not in k
            }
            token_metrics = {k: v for k, v in metrics.items() if "token" in k}
            batch_metrics_tracker.update(batch_metrics, batch_size)
            token_metrics_tracker.update(token_metrics, token_num)
        batch_metrics_message = batch_metrics_tracker.summary()
        token_metrics_message = token_metrics_tracker.summary()
        message_prefix = f"[Valid][{self.epoch}]"
        time_cost = f"TIME-{time.time() - begin_time:.3f}"
        message = "   ".join([
            message_prefix, batch_metrics_message, token_metrics_message,
            time_cost
        ])
        self.logger.info(message)

        if need_save:
            # Check valid metric
            cur_valid_metric = batch_metrics_tracker.get(
                self.valid_metric_name)
            # "Better" means smaller for decreasing metrics (e.g. a loss)
            # and larger otherwise.
            if self.is_decreased_valid_metric:
                is_best = cur_valid_metric < self.best_valid_metric
            else:
                is_best = cur_valid_metric > self.best_valid_metric
            if is_best:
                # Save current best model
                self.best_valid_metric = cur_valid_metric
                best_model_path = os.path.join(self.save_dir, "best.model")
                save(self.model, best_model_path)
                self.logger.info(
                    f"Saved best model to '{best_model_path}' with new best valid metric "
                    f"{self.valid_metric_name.upper()}-{self.best_valid_metric:.3f}"
                )

            # Save checkpoint
            if self.save_checkpoint:
                model_file = os.path.join(self.save_dir,
                                          f"epoch_{self.epoch}.model")
                save(self.model, model_file)

            if self.save_summary:
                # NOTE(review): these loops read the instance attributes
                # self.batch_metrics_tracker / self.token_metrics_tracker,
                # NOT the local trackers computed above — so the summary may
                # record training-time metrics rather than this evaluation's.
                # Confirm whether the local trackers were intended here.
                with self.summary_logger.mode("valid"):
                    for k, v in self.batch_metrics_tracker.items():
                        if k not in self.valid_summary:
                            self.valid_summary[k] = self.summary_logger.scalar(
                                k)
                        scalar = self.valid_summary[k]
                        scalar.add_record(self.batch_num, v)
                    for k, v in self.token_metrics_tracker.items():
                        if k not in self.valid_summary:
                            self.valid_summary[k] = self.summary_logger.scalar(
                                k)
                        scalar = self.valid_summary[k]
                        scalar.add_record(self.batch_num, v)

        return