예제 #1
0
파일: link_predict.py 프로젝트: Yelrose/PGL
def train(config):
    """Train and evaluate an ERNIESage link-prediction model.

    Builds a training dataset (repeated ``config.epochs`` times) and an
    evaluation dataset from the same ``BatchGraphGenerator``, resolves the
    pretrained ERNIE config and parameter directory, configures warm
    starting, and runs ``propeller.train.train_and_eval``.

    Args:
        config: run configuration, passed both as the model ``params`` and
            as the ``run_config``. Must support ``key in config`` membership
            tests and attribute access. Mutated in place: ``config.cls_id``
            is always set, and ``config.ernie_config`` is filled in when it
            is absent.
    """
    # Build Train Data
    data = TrainData(config.graph_work_path)
    train_iter = BatchGraphGenerator(graph_wrappers=[1],
                                     batch_size=config.batch_size,
                                     data=data,
                                     samples=config.samples,
                                     num_workers=config.sample_workers,
                                     feed_name_list=None,
                                     use_pyreader=False,
                                     phase="train",
                                     graph_data_path=config.graph_work_path,
                                     shuffle=True,
                                     neg_type=config.neg_type)
    train_ds = Dataset.from_generator_func(train_iter).repeat(config.epochs)
    # NOTE(review): evaluation draws from the same training generator
    # (phase="train"); confirm this is intended.
    dev_ds = Dataset.from_generator_func(train_iter)

    ernie_cfg_dict, ernie_param_path = PretrainedModelLoader.from_pretrained(
        config.ernie_name)

    # Warm-start directory: an explicit ``warm_start_from`` in the config
    # takes precedence; otherwise fall back to the pretrained ERNIE params.
    # Both branches must bind ``warm_start_from`` because it is used below.
    if "warm_start_from" in config:
        warm_start_from = config.warm_start_from
    else:
        warm_start_from = ernie_param_path

    if "ernie_config" not in config:
        config.ernie_config = ernie_cfg_dict

    # Only variables that have a saved file under the warm-start directory
    # are selected for warm starting.
    ws = propeller.WarmStartSetting(
        predicate_fn=lambda v: os.path.exists(
            os.path.join(warm_start_from, v.name)),
        from_dir=warm_start_from)

    # Name the datasets and replace the first dimension of every data shape
    # with -1 (presumably making the batch dimension variable — TODO confirm).
    train_ds.name = "train"
    train_ds.data_shapes = [[-1] + list(shape[1:])
                            for shape in train_ds.data_shapes]
    dev_ds.name = "dev"
    dev_ds.data_shapes = [[-1] + list(shape[1:])
                          for shape in dev_ds.data_shapes]

    tokenizer = load_tokenizer(config.ernie_name)
    config.cls_id = tokenizer.cls_id

    propeller.train.train_and_eval(
        model_class_or_model_fn=ERNIESageLinkPredictModel,
        params=config,
        run_config=config,
        train_dataset=train_ds,
        eval_dataset={"eval": dev_ds},
        warm_start_setting=ws,
    )
예제 #2
0
                                       .padded_batch(hparams.batch_size) \

        train_ds.data_shapes = shapes
        train_ds.data_types = types
        dev_ds.data_shapes = shapes
        dev_ds.data_types = types
        test_ds.data_shapes = shapes
        test_ds.data_types = types

        varname_to_warmstart = re.compile(
            r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$'
        )

        ws = propeller.WarmStartSetting(
            predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.
            path.exists(os.path.join(param_path, v.name)),
            from_dir=param_path,
        )

        best_exporter = propeller.train.exporter.BestExporter(
            os.path.join(run_config.model_dir, 'best'),
            cmp_fn=lambda old, new: new['dev']['acc'] > old['dev']['acc'])
        propeller.train.train_and_eval(model_class_or_model_fn=model_fn,
                                       params=hparams,
                                       run_config=run_config,
                                       train_dataset=train_ds,
                                       eval_dataset={
                                           'dev': dev_ds,
                                           'test': test_ds
                                       },
                                       warm_start_setting=ws,
예제 #3
0
def train(args, pretrained_model_config=None):
    """Train an MgfModel on an OGB graph-property-prediction dataset.

    Loads ``args.dataset_name``, copies its task metadata onto ``args``,
    builds a multi-epoch training loader plus a single-pass loader over the
    same training data for evaluation, sets up optional warm starting and a
    best-result exporter, then runs ``propeller.train.train_and_eval``.

    Args:
        args: run configuration, also passed as ``run_config``. Mutated in
            place: ``num_class``, ``eval_metric``, ``task_type`` and
            ``eval_steps`` are assigned here.
        pretrained_model_config: forwarded unchanged as the model ``params``.
    """
    log.info("loading data")
    dataset = GraphPropPredDataset(name=args.dataset_name)
    args.num_class = dataset.num_tasks
    args.eval_metric = dataset.eval_metric
    args.task_type = dataset.task_type

    mol_ds = MolDataset(args, dataset)

    # Number of batches in one pass over the training set (rounded up).
    args.eval_steps = math.ceil(len(mol_ds) / args.batch_size)
    log.info("Total %s steps (eval_steps) every epoch." % (args.eval_steps))

    collate = MgfCollateFn(args)
    loader = Dataloader(
        mol_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=args.shuffle,
        stream_shuffle_size=args.shuffle_size,
        collate_fn=collate,
    )

    # Single-pass view of the training loader, used for evaluation.
    eval_ds = PDataset.from_generator_func(loader)
    # Training view repeats the loader for args.epochs passes.
    train_ds = PDataset.from_generator_func(
        multi_epoch_dataloader(loader, args.epochs))

    ws = None
    if args.warm_start_from is not None:

        def _fn(v):
            # Warm-start only real parameters that have a saved file.
            return isinstance(v, F.framework.Parameter) and os.path.exists(
                os.path.join(args.warm_start_from, v.name))

        ws = propeller.WarmStartSetting(predicate_fn=_fn,
                                        from_dir=args.warm_start_from)

    def cmp_fn(old, new):
        # A new result counts as best when the old value of args.metrics is
        # strictly greater than the new one.
        # NOTE(review): this treats lower values as better; confirm that
        # holds for the metric in use (args.eval_metric is set separately).
        if not old['eval'][args.metrics] - new['eval'][args.metrics] > 0:
            return False
        log.info("best %s eval result: %s" % (args.metrics, new['eval']))
        return True

    # NOTE(review): with a log_id, only log_id 5 saves the best model; the
    # meaning of 5 is not visible here.
    save_best_model = True if args.log_id is None else int(args.log_id) == 5
    best_exporter = propeller.exporter.BestResultExporter(
        args.output_dir, (cmp_fn, save_best_model))

    propeller.train.train_and_eval(
        model_class_or_model_fn=MgfModel,
        params=pretrained_model_config,
        run_config=args,
        train_dataset=train_ds,
        eval_dataset={"eval": eval_ds},
        warm_start_setting=ws,
        exporters=[best_exporter],
    )