Example 1
    def test_search_model(self):
        """Test code for configurtion search."""
        configuration_search_model, config = build_model_from_config(
            './bert_config.json',
            output_dim=2,
            seq_len=64,
            LAMBDA=3e-3,
            FLAG_EXTRACT_LAYER=1,
            TASK='cola')

        decay_steps, warmup_steps = calc_train_steps(
            8550,
            batch_size=128,
            epochs=3,
        )
        configuration_search_model.compile(
            AdamWarmup(decay_steps=decay_steps,
                       warmup_steps=warmup_steps,
                       lr=3e-5,
                       lr_mult=None),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])

        print("Configuration search model summary: ",
              configuration_search_model.summary())
        del configuration_search_model
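For reference, calc_train_steps comes from keras-bert and derives both step counts from the dataset size. A minimal sketch of the arithmetic, assuming keras-bert's documented default warmup ratio of 0.1 (the helper name approx_train_steps is ours, for illustration only):

    import math

    def approx_train_steps(num_example, batch_size, epochs, warmup_ratio=0.1):
        # Steps per epoch, rounded up, times the number of epochs.
        total_steps = math.ceil(num_example / batch_size) * epochs
        # A fixed fraction of the schedule is spent on warmup.
        warmup_steps = int(total_steps * warmup_ratio)
        return total_steps, warmup_steps

    # With the values above: ceil(8550 / 128) * 3 = 201 total steps, 20 warmup steps.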
Example 2
    def test_retrained_model(self):
        """Test code for retrained model."""
        retrained_model, config = build_model_from_config(
            './bert_config.json',
            output_dim=2,
            seq_len=64,
            retention_configuration=[
                64, 64, 64, 32, 32, 32, 16, 16, 16, 8, 8, 8
            ],
            FLAG_EXTRACT_LAYER=2,
            TASK='cola')

        decay_steps, warmup_steps = calc_train_steps(
            8550,
            batch_size=128,
            epochs=3,
        )
        retrained_model.compile(AdamWarmup(decay_steps=decay_steps,
                                           warmup_steps=warmup_steps,
                                           lr=3e-5,
                                           lr_mult=None),
                                loss='sparse_categorical_crossentropy',
                                metrics=['accuracy'])

        print("Retrained model summary: ", retrained_model.summary())
        del retrained_model
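The retention_configuration list has one entry per transformer layer and apparently gives the number of word vectors retained at each layer, shrinking as depth increases. A minimal sanity check under that interpretation (check_retention_config is our name, not part of the repository):

    def check_retention_config(retention_config, num_layers, seq_len):
        # One retained-token count per encoder layer.
        assert len(retention_config) == num_layers
        # No layer can keep more tokens than the input sequence length.
        assert all(1 <= k <= seq_len for k in retention_config)
        # Retained counts should not increase from one layer to the next.
        assert all(a >= b for a, b in zip(retention_config, retention_config[1:]))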
Example 3
    def test_finetuned_model(self):
        """Test code for finetuning task."""

        fine_tuned_model, config = build_model_from_config(
            './bert_config.json',
            output_dim=2,
            seq_len=64,
            FLAG_EXTRACT_LAYER=0,
            TASK='cola')

        decay_steps, warmup_steps = calc_train_steps(
            8550,
            batch_size=128,
            epochs=3,
        )

        fine_tuned_model.compile(AdamWarmup(decay_steps=decay_steps,
                                            warmup_steps=warmup_steps,
                                            lr=3e-5,
                                            lr_mult=None),
                                 loss='sparse_categorical_crossentropy',
                                 metrics=['accuracy'])

        print("Fine-tuned model summary: ", fine_tuned_model.summary())
        del fine_tuned_model
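Taken together, the three tests show how FLAG_EXTRACT_LAYER selects the variant that build_model_from_config returns: 0 builds the plain fine-tuned model, 1 the configuration-search model (with the LAMBDA regularizer), and 2 the retrained model, which additionally requires a retention_configuration.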
Example 4
def eval(args, dev_x, dev_y, num_layers, num_classes, seq_len):

    ## Parse the retention configuration if provided; otherwise default every layer to the full sequence length
    if args.RETENTION_CONFIG is None:
        retention_config = [seq_len] * num_layers
        warnings.warn("Retention config not provided. Evaluation will be performed on the default config: "
                      + ','.join(map(str, retention_config)))
    else:
        retention_config = retention_config_parser(args.RETENTION_CONFIG, 
                                                   num_layers, 
                                                   seq_len)

    ## Model definition and evaluation
    model, config = build_model_from_config(
                                args.BERT_CONFIG_PATH,
                                output_dim=num_classes,
                                seq_len=seq_len,
                                retention_configuration=retention_config,
                                FLAG_EXTRACT_LAYER=2,
                                TASK=args.TASK)

    if args.TASK == "sts-b":
        loss = 'mean_squared_error'
        metrics = [metric_cor, 'mae']
    else:
        loss = 'sparse_categorical_crossentropy'
        metrics = ['accuracy']

    model.compile(
            AdamWarmup(decay_steps=1,
                       warmup_steps=1,
                       lr=2e-5,
                       lr_mult=None),
            loss=loss,
            metrics=metrics,
    )
    if args.MODEL_FORMAT == "CKPT":
        load_checkpoint(model,
                        config,
                        args.CHECKPOINT_PATH)
    elif args.MODEL_FORMAT == "HDF5":
        model.load_weights(args.CHECKPOINT_PATH, by_name=True)
    else:
        print ("Model format not supported")
        exit(-1)

    return model.evaluate(dev_x, 
                          dev_y, 
                          verbose=1, 
                          batch_size=args.BATCH_SIZE)
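retention_config_parser is not shown in these examples. A plausible sketch, assuming it turns a comma-separated string such as "64,64,32" into one integer per layer capped at the sequence length; this body is an assumption, not the repository's actual implementation:

    def retention_config_parser(config_str, num_layers, seq_len):
        # Hypothetical: parse "64,64,32,..." into a list of ints.
        values = [int(v) for v in config_str.split(',')]
        if len(values) != num_layers:
            raise ValueError('Expected %d values, got %d' % (num_layers, len(values)))
        # No layer may retain more tokens than the sequence length.
        return [min(v, seq_len) for v in values]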