Example #1
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))

    # Check whether the model already exists
    model_save_dir = args.buckets + args.checkpoint_dir
    if tf.gfile.Exists(model_save_dir + "/checkpoint"):
        raise ValueError("Model %s already exists, please delete it and retry" % model_save_dir)

    helper.dump_args(model_save_dir, args)

    convnet_model = model.SWEModel(
        model_configs=model.SWEModel.ModelConfigs(
            pooling=args.pooling,
            dropout=args.dropout,
            dim_word_embedding=args.word_dimension,
            init_with_w2v=False,
            hidden_layers=args.hidden_size,
            word_count=args.word_count,
            kid_count=args.kid_classes
        ),
        train_configs=model.TrainConfigs(
            learning_rate=args.learning_rate
        ),
        predict_configs=None,
        run_configs=model.RunConfigs(log_every=200)
    )

    estimator = tf.estimator.Estimator(
        model_fn=convnet_model.model_fn,
        model_dir=model_save_dir,
        config=tf.estimator.RunConfig(
            session_config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)),
            save_checkpoints_steps=args.snap_shot,
            keep_checkpoint_max=100
        )
    )

    print("Start training......")
    estimator.train(
        loader.OdpsDataLoader(
            table_name=args.tables,
            max_length=args.query_length,
            mode=args.mode
        ).input_fn,
        steps=args.max_steps
    )
Example #2
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))

    # Check whether the model already exists
    model_save_dir = args.checkpoint_dir
    if tf.gfile.Exists(model_save_dir + "/checkpoint"):
        raise ValueError(
            "Model %s already exists, please delete it and retry" %
            model_save_dir)

    convnet_model = model.TextConvNet(
        model_configs=model.TextConvNet.ModelConfigs(
            kernels=args.kernels,
            dropout=args.dropout,
            dim_word_embedding=100,
            init_with_w2v=False,
            hidden_layers=args.hidden_size,
            word_count=178422),
        train_configs=model.TrainConfigs(learning_rate=args.learning_rate,
                                         batch_size=args.batch_size),
        predict_configs=None,
        run_configs=model.RunConfigs(log_every=200))

    estimator = tf.estimator.Estimator(
        model_fn=convnet_model.model_fn,
        model_dir=model_save_dir,
        config=tf.estimator.RunConfig(
            session_config=tf.ConfigProto(
                gpu_options=tf.GPUOptions(allow_growth=True)),
            save_checkpoints_steps=args.snapshot,
            keep_checkpoint_max=20))

    print("Start training......")
    estimator.train(
        loader.LocalFileDataLoader(
            file_path=args.file_path,
            mode=tf.estimator.ModeKeys.TRAIN,
            hist_length=args.hist_length,
            target_length=args.target_length).input_fn,
        steps=args.max_steps)
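
Example #2 trains from local files rather than ODPS tables, but the Estimator only ever sees the loader's input_fn. Below is a minimal tf.data sketch of the contract such a loader has to satisfy (TF 1.x API; the class body and the tab-separated line format are assumptions, not the project's actual loader).

import tensorflow as tf

class LocalFileDataLoaderSketch(object):
    """Hypothetical stand-in for loader.LocalFileDataLoader."""

    def __init__(self, file_path, mode, hist_length, target_length,
                 batch_size=32):
        self._file_path = file_path
        self._mode = mode
        self._batch_size = batch_size
        # hist_length/target_length would drive padding/truncation,
        # omitted here to keep the sketch short.

    def input_fn(self):
        # Assumed format: one "history<TAB>target" sample per line.
        dataset = tf.data.TextLineDataset(self._file_path)
        if self._mode == tf.estimator.ModeKeys.TRAIN:
            dataset = dataset.shuffle(10000).repeat()

        def _parse(line):
            fields = tf.string_split([line], "\t").values
            return {"hist": fields[0]}, fields[1]

        return dataset.map(_parse).batch(self._batch_size)
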
Example #3
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))
    # tf.enable_eager_execution()
    transformer_model = model.TextTransformerNet(
        model_configs=model.TextTransformerNet.ModelConfigs(
            dropout_rate=args.dropout_rate,
            num_vocabulary=args.num_vocabulary,
            feed_forward_in_dim=args.feed_forward_in_dim,
            model_dim=args.model_dim,
            num_blocks=args.num_blocks,
            num_heads=args.num_heads,
            enable_date_time_emb=args.enable_date_time_emb,
            word_emb_dim=args.word_emb_dim,
            date_span=args.date_span),
        train_configs=model.TrainConfigs(learning_rate=args.learning_rate,
                                         batch_size=args.batch_size,
                                         dropout_rate=args.dropout_rate),
        predict_configs=None,
        run_configs=model.RunConfigs(log_every=10))

    estimator = tf.estimator.Estimator(
        model_fn=transformer_model.model_fn,
        model_dir=args.checkpoint_dir,
        config=tf.estimator.RunConfig(
            session_config=tf.ConfigProto(
                gpu_options=tf.GPUOptions(allow_growth=True)),
            save_checkpoints_steps=args.snapshot,
            keep_checkpoint_max=20))

    print("Start training......")
    estimator.train(
        loader.LocalFileDataLoader(
            file_path=args.file_path,
            mode=tf.estimator.ModeKeys.TRAIN,
            hist_length=args.max_length,
            target_length=args.target_length).input_fn,
        steps=args.max_steps)
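
Each example hands model_fn=<model>.model_fn to the Estimator. The model classes are not shown, but the TF 1.x contract they satisfy is fixed: the function receives (features, labels, mode) and returns a tf.estimator.EstimatorSpec for that mode. A bare-bones sketch with a placeholder architecture (not the project's transformer or convnet):

import tensorflow as tf

def model_fn(features, labels, mode):
    # Placeholder network: a single dense layer over a float feature "x".
    logits = tf.layers.dense(features["x"], units=2)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions={"scores": tf.nn.softmax(logits)})

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                  logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(1e-3).minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode, loss=loss)  # EVAL
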
Example #4
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))

    config = parse_config('MiniBERT')

    # Set up distributed inference
    dist_params = {
        "task_index": args.task_index,
        "ps_hosts": args.ps_hosts,
        "worker_hosts": args.worker_hosts,
        "job_name": args.job_name
    }
    slice_count, slice_id = env.set_dist_env(dist_params)
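    # slice_count/slice_id shard the input table so each distributed worker
    # scores a disjoint slice (both are passed to OdpsDataLoader below).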
    # Load model arguments
    model_save_dir = args.buckets + args.checkpoint_dir
    model_args = helper.load_args(model_save_dir)

    transformer_model = model.TextTransformerNet(
        bert_config=config,
        model_configs=model.TextTransformerNet.ModelConfigs(
            dropout_rate=model_args.dropout_rate,
            num_vocabulary=model_args.num_vocabulary,
            feed_forward_in_dim=model_args.feed_forward_in_dim,
            model_dim=model_args.model_dim,
            num_blocks=model_args.num_blocks,
            num_heads=model_args.num_heads,
            enable_date_time_emb=model_args.enable_date_time_emb,
            word_emb_dim=model_args.word_emb_dim,
            date_span=model_args.date_span
        ),
        train_configs=model.TrainConfigs(
            learning_rate=model_args.learning_rate,
            batch_size=model_args.batch_size,
            dropout_rate=model_args.dropout_rate
        ),
        predict_configs=model.PredictConfigs(
            separator=args.separator
        ),
        run_configs=model.RunConfigs(
            log_every=200
        )
    )

    estimator = tf.estimator.Estimator(
        model_fn=transformer_model.model_fn,
        model_dir=model_save_dir,
        config=tf.estimator.RunConfig(
            session_config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)),
            save_checkpoints_steps=model_args.max_steps,
            keep_checkpoint_max=1
        )
    )

    checkpoint_path = None
    if args.step > 0:
        checkpoint_path = model_save_dir + "/model.ckpt-{}".format(args.step)

    result_iter = estimator.predict(
        loader.OdpsDataLoader(
            table_name=args.tables,
            mode=tf.estimator.ModeKeys.PREDICT,
            hist_length=model_args.max_length,
            target_length=model_args.target_length,
            batch_size=model_args.batch_size,
            slice_id=slice_id,
            slice_count=slice_count,
            shuffle=0,
            repeat=1
        ).input_fn,
        checkpoint_path=checkpoint_path
    )

    odps_writer = dumper.get_odps_writer(
        args.outputs,
        slice_id=slice_id
    )
    _do_prediction(result_iter, odps_writer, args, model_args)
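
Estimator.predict returns a lazy generator of per-example prediction dicts, which _do_prediction (not shown) drains into the ODPS writer. A hedged sketch of that loop; the "scores" key and the writer's write_record method are assumptions about project-specific code.

def _do_prediction(result_iter, odps_writer, args, model_args):
    # Hypothetical consumer: real key names depend on what model_fn
    # put into EstimatorSpec.predictions.
    for i, prediction in enumerate(result_iter):
        odps_writer.write_record([str(prediction["scores"])])
        if i % 1000 == 0:
            print("Predicted {} records".format(i))
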
Example #5
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))

    # Config
    config = parse_config('MiniBERT')
    config["init_checkpoint"] = (
        args.buckets + args.init_ckt_dir +
        "/model.ckpt-{}".format(args.init_ckt_step))

    # Check whether the model already exists
    model_save_dir = args.buckets + args.checkpoint_dir
    if tf.gfile.Exists(model_save_dir + "/checkpoint"):
        raise ValueError(
            "Model %s already exists, please delete it and retry" %
            model_save_dir)

    helper.dump_args(model_save_dir, args)

    transformer_model = model.TextTransformerNet(
        bert_config=config,
        model_configs=model.TextTransformerNet.ModelConfigs(
            dropout_rate=args.dropout_rate,
            num_vocabulary=args.num_vocabulary,
            feed_forward_in_dim=args.feed_forward_in_dim,
            model_dim=args.model_dim,
            num_blocks=args.num_blocks,
            num_heads=args.num_heads,
            enable_date_time_emb=args.enable_date_time_emb,
            word_emb_dim=args.word_emb_dim,
            date_span=args.date_span),
        train_configs=model.TrainConfigs(learning_rate=args.learning_rate,
                                         batch_size=args.batch_size,
                                         dropout_rate=args.dropout_rate),
        predict_configs=None,
        run_configs=model.RunConfigs(log_every=50))
    # checkpoint_path = None
    # if args.step > 0:
    #     checkpoint_path = model_save_dir + "/model.ckpt-{}".format(args.step)
    # warm_start_settings = tf.estimator.WarmStartSettings(checkpoint_path,
    #                                                      vars_to_warm_start='(.*Embedding|Conv-[1-4]|MlpLayer-1)')
    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl')
    distribution = tf.contrib.distribute.MirroredStrategy(
        num_gpus=4, cross_tower_ops=cross_tower_ops, all_dense=False)
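    # Assumed semantics (tf.contrib, TF 1.x): MirroredStrategy replicates the
    # graph on 4 in-process GPUs and keeps variables in sync via NCCL
    # all-reduce; RunConfig.train_distribute hands the strategy to the
    # Estimator below.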

    estimator = tf.estimator.Estimator(
        model_fn=transformer_model.model_fn,
        model_dir=model_save_dir,
        config=tf.estimator.RunConfig(
            session_config=tf.ConfigProto(
                gpu_options=tf.GPUOptions(allow_growth=False),
                allow_soft_placement=True),
            save_checkpoints_steps=args.snapshot,
            keep_checkpoint_max=20,
            train_distribute=distribution))
    print("Start training......")
    estimator.train(
        input_fn=loader.OdpsDataLoader(
            table_name=args.tables,
            mode=tf.estimator.ModeKeys.TRAIN,
            hist_length=args.max_length,
            target_length=args.target_length,
            batch_size=args.batch_size).input_fn,
        max_steps=args.max_steps)
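
When training should be paired with periodic evaluation, the TF 1.x entry point is tf.estimator.train_and_evaluate with a TrainSpec/EvalSpec pair (tf.estimator itself has no plain train() function). A sketch reusing the names from Example #5; args.eval_tables is a hypothetical flag and EVAL-mode support in OdpsDataLoader is an assumption.

train_spec = tf.estimator.TrainSpec(
    input_fn=loader.OdpsDataLoader(
        table_name=args.tables,
        mode=tf.estimator.ModeKeys.TRAIN,
        hist_length=args.max_length,
        target_length=args.target_length,
        batch_size=args.batch_size).input_fn,
    max_steps=args.max_steps)
eval_spec = tf.estimator.EvalSpec(
    input_fn=loader.OdpsDataLoader(
        table_name=args.eval_tables,  # hypothetical eval table flag
        mode=tf.estimator.ModeKeys.EVAL,  # assumes the loader supports EVAL
        hist_length=args.max_length,
        target_length=args.target_length,
        batch_size=args.batch_size).input_fn,
    steps=100)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)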