Example #1
0
def finetune(args):
    """Fine-tune a PaddleHub image-classification module on the Flowers
    dataset, evaluate on dev, keep the best model, and report dev accuracy.
    """
    # Load the pretrained PaddleHub module named on the command line
    # (defaults to mobilenet per the CLI definition).
    pretrained = hub.Module(name=args.module)
    inputs, outputs, main_program = pretrained.context(trainable=True)

    # Reader over the Flowers dataset, using the module's expected image
    # geometry and its pretraining normalization statistics.
    flowers = hub.dataset.Flowers()
    reader = hub.reader.ImageClassificationReader(
        image_width=pretrained.get_expected_image_width(),
        image_height=pretrained.get_expected_image_height(),
        images_mean=pretrained.get_pretrained_images_mean(),
        images_std=pretrained.get_pretrained_images_std(),
        dataset=flowers)

    # Backbone feature tensor that the classifier head is built on.
    feature = outputs["feature_map"]
    feed_list = [inputs["image"].name]

    # Fine-tuning strategy and run configuration.
    strategy = hub.DefaultFinetuneStrategy(learning_rate=args.learning_rate)
    run_config = hub.RunConfig(
        use_cuda=True,
        num_epoch=args.epochs,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=strategy)

    # Transfer-learning task: classification head on top of the feature map.
    task = hub.ImageClassifierTask(
        data_reader=reader,
        feed_list=feed_list,
        feature=feature,
        num_classes=flowers.num_labels,
        config=run_config)

    # Optionally warm-start from a previously saved model directory.
    if args.model_path != "":
        with task.phase_guard(phase="train"):
            task.init_if_necessary()
            task.load_parameters(args.model_path)
            logger.info("PaddleHub has loaded model from %s" % args.model_path)

    # Train, then evaluate through PaddleHub's task API.
    task.finetune()
    run_states = task.eval()
    # NOTE(review): _calculate_metrics is a private Task method; the first
    # returned value is assumed to be a dict holding an "acc" entry.
    eval_avg_score, eval_avg_loss, eval_run_speed = task._calculate_metrics(
        run_states)

    # Promote ckpt/best_model into the requested save directory, then drop
    # the working checkpoint directory.
    best_model_dir = os.path.join(run_config.checkpoint_dir, "best_model")
    if is_path_valid(args.saved_params_dir) and os.path.exists(best_model_dir):
        shutil.copytree(best_model_dir, args.saved_params_dir)
        shutil.rmtree(run_config.checkpoint_dir)

    # Dev accuracy is the objective consumed by the auto-finetune driver.
    hub.report_final_result(eval_avg_score["acc"])
Example #2
0
# Swap in the patched context implementation for TransformerModule.
TransformerModule.context = TransformerModule_pat.context

# Load the pretrained RoBERTa-wwm-ext-large module; fix the program's random
# seed so repeated runs are reproducible.
module = hub.Module(name="chinese-roberta-wwm-ext-large")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128)
program.random_seed = 1


# Reader that tokenizes examples with the module's own vocabulary.
reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=128,
    random_seed=1)

# SECURITY FIX: the original parsed the learning rate with eval(args.lr),
# which executes arbitrary code passed on the command line (and did so twice).
# float() accepts every numeric literal (e.g. "0.001", "5e-5") safely.
learning_rate = float(args.lr)
print("learning rate: ", learning_rate)
print("max epoch: ", args.max_epoch)
strategy = hub.DefaultFinetuneStrategy(learning_rate=learning_rate,
                                       optimizer_name="sgd")

# Run configuration: GPU training with periodic logging/eval/checkpointing.
config = hub.RunConfig(use_cuda=True,
                       num_epoch=args.max_epoch,
                       batch_size=32,
                       strategy=strategy,
                       log_interval=100,
                       eval_interval=1400,
                       save_ckpt_interval=1400,
                       checkpoint_dir='./checkpoint_aug')

# Sentence-level embedding used as the classifier input feature.
pooled_output = outputs["pooled_output"]

feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name
]
cls_task = hub.TextClassifierTask(
Example #3
0
                                      predict_file_with_header=False,
                                      label_list=["Yes", "Depends", "No"])


# Custom dataset defined earlier in this file.
dataset = Dataset()

# Pretrained RoBERTa backbone, kept frozen (trainable=False); seed the
# program so results are reproducible, and echo the seed for the log.
module = hub.Module(name="chinese-roberta-wwm-ext-large")
inputs, outputs, program = module.context(trainable=False, max_seq_len=128)
program.random_seed = 1
print(program.random_seed)

# Reader that tokenizes examples with the module's own vocabulary.
reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=128)

# Plain SGD fine-tuning strategy with a fixed learning rate.
strategy = hub.DefaultFinetuneStrategy(
    learning_rate=0.002, optimizer_name="sgd")

# Run configuration: GPU training with periodic logging/eval/checkpointing.
config = hub.RunConfig(
    use_cuda=True,
    num_epoch=3,
    batch_size=32,
    strategy=strategy,
    log_interval=100,
    eval_interval=1400,
    save_ckpt_interval=1400,
    checkpoint_dir='./checkpoint_aug')

# Sentence-level embedding used as the classifier input feature.
pooled_output = outputs["pooled_output"]

feed_list = [
    inputs["input_ids"].name, inputs["position_ids"].name,
    inputs["segment_ids"].name, inputs["input_mask"].name