# Example #1
def get_transformer(encoder_num,
                    input_layer,
                    head_num,
                    hidden_dim,
                    attention_activation=None,
                    feed_forward_activation='relu',
                    dropout_rate=0.0,
                    trainable=True,
                    name=''):
    """Stack ``encoder_num`` transformer encoder blocks and pool the output.

    A normalization layer is applied to ``input_layer`` first, then the
    encoder components are chained (named ``<name>-Encoder-1`` ... ``-N``),
    and the sequence output is reduced with a GlobalAveragePooling1D layer
    named ``<name>Feature``, which is returned.
    """
    layer = get_normalize_layer(input_layer, dropout_rate, trainable, name)
    for idx in range(1, encoder_num + 1):
        layer = get_encoder_component(
            name='%s-Encoder-%d' % (name, idx),
            input_layer=layer,
            head_num=head_num,
            hidden_dim=hidden_dim,
            attention_activation=attention_activation,
            feed_forward_activation=feed_forward_activation,
            dropout_rate=dropout_rate,
            trainable=trainable,
        )
    return GlobalAveragePooling1D(name=name + 'Feature')(layer)
def build_model(args):
    """Build (or reload) a BERT-based sigmoid classifier, train it, and save it.

    Args:
        args: parsed CLI namespace. Fields read here: load_model, bert_config,
            init_checkpoint, seq_len, train, dev, batch_size, eval_batch_size,
            epochs, lr, gpus, patience, checkpoint_interval, output_file.

    Side effects: replaces the global Keras/TF session, trains the model via
    ``fit_generator``, and writes the trained model to ``args.output_file``.
    """

    # Let GPU memory grow on demand instead of grabbing it all up front,
    # and install this session as the global Keras session before any
    # graph construction happens.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=config))

    if args.load_model:
        print("Loading previously saved model..")
        if args.bert_config:
            print("Warning: --bert_config ignored when loading previous Keras model.", file=sys.stderr)
        # keras-bert's custom layers must be registered for deserialization.
        custom_objects = get_custom_objects()
        model = load_model(args.load_model, custom_objects=custom_objects)
    
    else:
        print("Building model..")
        # Load the pretrained BERT graph (inference topology, weights trainable).
        bert = load_trained_model_from_checkpoint(args.bert_config, args.init_checkpoint,
                                                    training=False, trainable=True,
                                                    seq_len=args.seq_len)

        # NOTE(review): transformer_output is never connected to the model's
        # output below — this extra 13th encoder appears to be a leftover
        # experiment. Confirm before deleting.
        transformer_output = get_encoder_component(name="Encoder-13", input_layer=bert.layers[-1].output,
                                                head_num=12, hidden_dim=3072, feed_forward_activation=gelu)

        # Identity Lambda; presumably detaches Keras mask metadata from
        # bert.output so downstream layers ignore it — TODO confirm.
        drop_mask = Lambda(lambda x: x, name="drop_mask")(bert.output)

        # First sequence position — the [CLS] token representation.
        slice_CLS = Lambda(lambda x: K.slice(x, [0, 0, 0], [-1, 1, -1]), name="slice_CLS")(drop_mask)
        flatten_CLS = Flatten()(slice_CLS)

        # Needed to avoid a json serialization error when saving the model.
        last_position = args.seq_len-1
        # Last sequence position; assumes [SEP] sits at seq_len-1 (i.e. inputs
        # are padded/truncated to full length) — TODO confirm against the
        # data generator.
        slice_SEP = Lambda(lambda x: K.slice(x, [0, last_position, 0], [-1, 1, -1]), name="slice_SEP")(drop_mask)
        flatten_SEP = Flatten()(slice_SEP)

        # Average/max pooling over the sequence axis (Permute swaps
        # (seq, features) -> (features, seq) so pooling runs over positions).
        permute_layer = Permute((2, 1))(drop_mask)
        permute_average = GlobalAveragePooling1D()(permute_layer)
        permute_maximum =  GlobalMaxPooling1D()(permute_layer)

        # NOTE(review): concat is unused — output_layer below reads only
        # flatten_CLS, so the pooled/SEP features never reach the output.
        # Looks like an abandoned feature-combination experiment.
        concat = Concatenate()([permute_average, permute_maximum, flatten_CLS, flatten_SEP])

        # Multi-label head: one sigmoid per label over the [CLS] features.
        output_layer = Dense(get_label_dim(args.train), activation='sigmoid', name="label_out")(flatten_CLS)

        model = Model(bert.input, output_layer)
        
        # NOTE(review): total_steps/warmup_steps are only needed by the
        # AdamWarmup optimizer, which is commented out below.
        total_steps, warmup_steps =  calc_train_steps(num_example=get_example_count(args.train),
                                                    batch_size=args.batch_size, epochs=args.epochs,
                                                    warmup_proportion=0.01)

        # optimizer = AdamWarmup(total_steps, warmup_steps, lr=args.lr)
        optimizer = keras.optimizers.Adam(lr=args.lr)

        model.compile(loss=["binary_crossentropy"], optimizer=optimizer, metrics=[])

    if args.gpus > 1:
        # Keep a handle to the single-device model: weights are shared, and
        # only the template can be saved portably (see bottom of function).
        template_model = model
        # Set cpu_merge=False for better performance on NVLink connected GPUs.
        model = multi_gpu_model(template_model, gpus=args.gpus, cpu_merge=False)
        # TODO: need to compile this model as well when doing multigpu!

    callbacks = [Metrics(model)]

    if args.patience > -1:
        # Stop training after `patience` epochs without improvement.
        callbacks.append(EarlyStopping(patience=args.patience, verbose=1))

    if args.checkpoint_interval > 0:
        # Periodic checkpoints; {epoch} is filled in by Keras.
        callbacks.append(ModelCheckpoint(args.output_file + ".checkpoint-{epoch}",  period=args.checkpoint_interval))


    print(model.summary(line_length=118))
    print("Number of GPUs in use:", args.gpus)
    print("Batch size:", args.batch_size)
    print("Learning rate:", K.eval(model.optimizer.lr))
    # print("Dropout:", args.dropout)

    # steps_per_epoch/validation_steps must cover the whole dataset, hence
    # the ceil over example_count / batch_size.
    model.fit_generator(data_generator(args.train, args.batch_size, seq_len=args.seq_len),
                        steps_per_epoch=ceil( get_example_count(args.train) / args.batch_size ),
                        use_multiprocessing=True, epochs=args.epochs, callbacks=callbacks,
                        validation_data=data_generator(args.dev, args.eval_batch_size, seq_len=args.seq_len),
                        validation_steps=ceil( get_example_count(args.dev) / args.eval_batch_size ))

    print("Saving model:", args.output_file)
    if args.gpus > 1:
        # Save the single-device template, not the multi-GPU wrapper, so the
        # saved file can be reloaded without the same GPU topology.
        template_model.save(args.output_file)
    else:
        model.save(args.output_file)