Example #1
            # NOTE: the opening of this forward pass was cut off in the source snippet;
            # a plausible reconstruction (the exact keyword arguments are an assumption):
            return_dict = model(**inputs, labels=labels, return_dict=True)
            lang_loss, dec_output, encoder_hidden = return_dict.loss, return_dict.logits, return_dict.encoder_last_hidden_state

            tot_val_loss += lang_loss * len(inputs['input_ids'])
            n_val += len(inputs['input_ids'])

    print("n_val", n_val)
    avg_val_loss = tot_val_loss.item() / n_val
    return n_val, avg_val_loss


tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
if pretrained:
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base', dropout=args.dropout)
else:
    config = BartConfig.from_pretrained('facebook/bart-base')
    config.dropout = args.dropout
    model = BartForConditionalGeneration(config)
model.to(DEVICE)
optimizer = AdamW(list(model.parameters()), lr=args.lr)
print("Loaded model")

# TODO load data
dataset = load_data(args.data, ["walkthrough0"] + [f"randcmd{i}" for i in range(100)], tokenizer, max_seq_len, max_data_size=4000)
print("Loaded train data")
dev_dataset = load_data(args.data, [f"randcmd{i}" for i in range(100,200)], tokenizer, max_seq_len, max_data_size=500)
print("Loaded dev data")

# initial eval
print("Initial eval")
n_val, avg_val_loss = eval_model(args, model, dev_dataset, tokenizer, eval_batchsize)
Example #2
print(f"number of valid examples: {len(valid_idxs)}")

train_loader = DataLoader(train_idxs, batch_size=args.train_batch_size, shuffle=True)
valid_loader = DataLoader(valid_idxs, batch_size=args.valid_batch_size, shuffle=False)

print("==== preparing data ====")
make_path(args.cache_dir)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)

with open('synt_vocab.pkl', 'rb') as f:
    synt_vocab = pickle.load(f)

dataset = prepare_dataset(para_data, tokenizer, num)

print("==== loading model ====")
config = BartConfig.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
config.word_dropout = args.word_dropout
config.max_sent_len = args.max_sent_len
config.max_synt_len = args.max_synt_len

bart = BartModel.from_pretrained('facebook/bart-base', cache_dir=args.cache_dir)
model = ParaBart(config)
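# warm-start ParaBart from the pretrained BART weights; strict=False skips keys that don't match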
model.load_state_dict(bart.state_dict(), strict=False)
model.zero_grad()
del bart


no_decay_params = []
no_decay_fast_params = []
fast_params = []
all_other_params = []
Example #3
def create_model(encoder_name="xlm-roberta-base",
                 dec_vocabsize=None,
                 dec_layers=6,
                 dec_dim=640,
                 dec_heads=8,
                 dropout=0.,
                 dropoutdec=0.,
                 maxlen=50,
                 smoothing=0.,
                 numbeam=1,
                 tensor2tree=None,
                 statesimweight=0.,
                 probsimweight=0.,
                 projmode="simple"):
    # if encoder_name != "bert-base-uncased":
    #     raise NotImplemented(f"encoder '{encoder_name}' not supported yet.")
    pretrained = AutoModel.from_pretrained(encoder_name)
    encoder = pretrained

    class BertEncoderWrapper(torch.nn.Module):
        def __init__(self, model, dropout=0., **kw):
            super(BertEncoderWrapper, self).__init__(**kw)
            self.model = model
            self.proj = self.create_proj(pretrained.config.hidden_size,
                                         dec_dim,
                                         mode=projmode)
            self.dropout = torch.nn.Dropout(dropout)

        @classmethod
        def create_proj(cls, indim, outdim, mode="simple"):
            if mode == "simple" or mode == "simpleshared":
                proj = torch.nn.Linear(indim, outdim, bias=False)
            elif mode == "twolayer":
                proj = torch.nn.Sequential(torch.nn.Linear(indim, indim * 4),
                                           torch.nn.LeakyReLU(0.1),
                                           torch.nn.Linear(indim * 4, outdim))
            else:
                raise ValueError(f"unknown projection mode: '{mode}'")
            return proj

        def forward(self, input_ids, attention_mask=None):
            ret, _ = self.model(input_ids, attention_mask=attention_mask)
            # if pretrained.config.hidden_size != dec_dim:
            ret = self.proj(ret)
            # ret = self.dropout(ret)
            ret = (ret, None, None)
            return ret

    encoder = BertEncoderWrapper(encoder, dropout=dropout)

    decoder_config = BartConfig(
        d_model=dec_dim,
        pad_token_id=0,
        bos_token_id=1,
        vocab_size=dec_vocabsize,
        decoder_attention_heads=dec_heads // 2,
        decoder_layers=dec_layers,
        dropout=dropoutdec,
        attention_dropout=min(0.1, dropout / 2),
        decoder_ffn_dim=dec_dim * 4,
        encoder_attention_heads=dec_heads,
        encoder_layers=dec_layers,
        encoder_ffn_dim=dec_dim * 4,
        relative_position=True,
    )
    model = BartGenerator(decoder_config, encoder.model.config)
    model.model.encoder = encoder
    if "shared" in projmode:
        model2 = model
    else:
        model2 = q.copy(model)
        model2.model = q.copy(model.model)
        model2.model.encoder = q.copy(model.model.encoder)
        model2.model.encoder.proj = BertEncoderWrapper.create_proj(
            pretrained.config.hidden_size, dec_dim, mode=projmode)

    orderless = {"op:and", "SW:concat"}

    trainmodel = BartGeneratorTrain(model,
                                    model2,
                                    smoothing=smoothing,
                                    tensor2tree=tensor2tree,
                                    orderless=orderless,
                                    statesimweight=statesimweight,
                                    probsimweight=probsimweight)
    testmodel = BartGeneratorTest(model,
                                  model2,
                                  maxlen=maxlen,
                                  numbeam=numbeam,
                                  tensor2tree=tensor2tree,
                                  orderless=orderless)
    return trainmodel, testmodel
Example #4
    def test_xsum_config_generation_params(self):
        config = BartConfig.from_pretrained("facebook/bart-large-xsum")
        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
        config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
        self.assertDictEqual(expected_params, config_params)
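As a quick companion to the test above, the same defaults can be read straight off the hub config (model id and expected values are taken from the test itself):

from transformers import BartConfig

config = BartConfig.from_pretrained("facebook/bart-large-xsum")
print(config.num_beams, config.do_sample, config.early_stopping, config.length_penalty)
# per the test above: 6 False True 1.0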
Example #5
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses(
        )

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_bart_dlm",
                           model_args,
                           data_args,
                           framework="flax")

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(
                training_args.output_dir).absolute().name,
                                           token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = BartConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = BartConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Use Punkt Sentence Tokenizer to divide a document into a list of sentences
    nltk.download("punkt")
    sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")

    def sentence_split_function(example):
        sents = sentence_tokenizer.tokenize(example["text"])
        # use pad token as end of sentence indicator
        new_text = tokenizer.bos_token + tokenizer.pad_token.join(
            sents) + tokenizer.eos_token
        return {"text": new_text}

    split_datasets = datasets.map(
        sentence_split_function,
        batched=False,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name],
                         add_special_tokens=False,
                         return_attention_mask=False)

    tokenized_datasets = split_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=text_column_name,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # max_seq_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {
            k: list(chain(*examples[k]))
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= max_seq_length:
            total_length = (total_length // max_seq_length) * max_seq_length
        # Split by chunks of max_len.
        result = {
            k: [
                t[i:i + max_seq_length]
                for i in range(0, total_length, max_seq_length)
            ]
            for k, t in concatenated_examples.items()
        }
        return result
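    # e.g. with max_seq_length=1024 and 2050 concatenated tokens, group_texts
    # returns two chunks of 1024 tokens and silently drops the trailing 2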

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed. "
            "Please run `pip install tensorboard` to enable.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
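    # one dropout RNG per local device, consumed by the pmap-ed train step below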
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxBartForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxBartForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens and permuting the sentences.
    data_collator = FlaxDataCollatorForBartDenoisingLM(
        tokenizer=tokenizer,
        decoder_start_token_id=model.config.decoder_start_token_id,
        mask_ratio=data_args.mlm_probability,
        poisson_lambda=data_args.poisson_lambda,
        permute_sentence_ratio=data_args.permute_sentence_ratio,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])
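    # net effect: lr ramps from 0 to learning_rate over warmup_steps, then
    # decays linearly back to 0 over the remaining training steps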

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = set([
            layer[-2:] for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        ])
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in layer_norm_named_params)
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)
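    # For intuition, with hypothetical params
    #   {"encoder": {"layer_norm": {"scale": s, "bias": b}, "dense": {"kernel": k, "bias": b2}}}
    # only ('encoder', 'dense', 'kernel') is marked True: biases and
    # LayerNorm scales are excluded from weight decay.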

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn, )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # compute loss, ignore padded input tokens and special tokens
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])) * label_mask

            # take average
            loss = loss.sum() / label_mask.sum()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss, ignore padded input tokens and special tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(
            logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {
            "loss": loss.sum(),
            "accuracy": accuracy.sum(),
            "normalizer": label_mask.sum()
        }
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
                    f" {train_metric['learning_rate']})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                # Avoid using jax.numpy here in case of TPU training
                eval_samples_idx = np.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size,
                                                       drop_last=False)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples)

                    # Model forward
                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                        state.params,
                        model_inputs.data,
                        min_device_batch=per_device_eval_batch_size)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
                eval_normalizer = eval_metrics.pop("normalizer")
                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                            eval_metrics)

                # Update progress bar
                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size,
                                               drop_last=False)

        eval_metrics = []
        for _, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params,
                model_inputs.data,
                min_device_batch=per_device_eval_batch_size)
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(),
                                    eval_metrics)
        eval_normalizer = eval_metrics.pop("normalizer")
        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                    eval_metrics)

        try:
            perplexity = math.exp(eval_metrics["loss"])
        except OverflowError:
            perplexity = float("inf")
        eval_metrics["perplexity"] = perplexity

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
Example #6
def bart_or_mbart_model_heuristic(
        model_config: transformers.BartConfig) -> ModelArchitectures:
    if model_config.is_valid_mbart():
        return ModelArchitectures.MBART
    else:
        return ModelArchitectures.BART
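A minimal usage sketch; note that `is_valid_mbart` is attached to BartConfig by this project rather than by stock transformers, so this assumes that patching has already happened:

config = transformers.BartConfig.from_pretrained("facebook/bart-large")
arch = bart_or_mbart_model_heuristic(config)  # ModelArchitectures.BART for a plain BART config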
Example #7
    def __init__(
        self,
        pretrained_model=None,
        additional_special_tokens_encoder=None,
        additional_special_tokens_decoder=None,
        model_config=None,
        vocab_file=None,
        args=None,
        use_cuda=True,
        cuda_device=-1,
        **kwargs,
    ):
        self.args = self._load_model_args()
        if isinstance(args, dict):
            self.args.update_from_dict(args)
        elif isinstance(args, Seq2SeqArgs):
            self.args = args

        if "sweep_config" in kwargs:
            self.is_sweeping = True
            sweep_config = kwargs.pop("sweep_config")
            sweep_values = sweep_config_to_sweep_values(sweep_config)
            self.args.update_from_dict(sweep_values)
        else:
            self.is_sweeping = False

        if self.args.manual_seed:
            random.seed(self.args.manual_seed)
            np.random.seed(self.args.manual_seed)
            torch.manual_seed(self.args.manual_seed)
            if self.args.n_gpu > 0:
                torch.cuda.manual_seed_all(self.args.manual_seed)

        if use_cuda:
            if torch.cuda.is_available():
                if cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{cuda_device}")
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable. "
                    "Make sure CUDA is available or set `use_cuda=False`.")
        else:
            self.device = "cpu"

        self.results = {}

        if not use_cuda:
            self.args.fp16 = False

        # BartConfig, BartForConditionalGeneration, BartTokenizer
        # config = EncoderDecoderConfig.from_encoder_decoder_configs(config, config)
        model_config = BartConfig.from_json_file(model_config)
        if pretrained_model is None:
            self.model = BartForConditionalGeneration(config=model_config)
            self.encoder_tokenizer = BartTokenizer.from_pretrained(vocab_file)

        else:
            self.model = BartForConditionalGeneration.from_pretrained(
                pretrained_model)
            self.encoder_tokenizer = BartTokenizer.from_pretrained(vocab_file)
        self.decoder_tokenizer = self.encoder_tokenizer

        # special AST token
        # additional_special_tokens_encoder = {'additional_special_tokens': ['Assertion', 'RegExp', 'Repetition', 'Quantifier',  'ClassRange', 'CharacterClass']}
        # additional_special_tokens_decoder = {'additional_special_tokens': ['Assertion', 'RegExp', 'Repetition', 'Quantifier',  'ClassRange', 'CharacterClass']}

        self.config = self.model.config

        if additional_special_tokens_encoder is not None:
            self.encoder_tokenizer.add_special_tokens(
                additional_special_tokens_encoder)

        if additional_special_tokens_decoder is not None:
            self.decoder_tokenizer.add_special_tokens(
                additional_special_tokens_decoder)

        if self.args.wandb_project and not wandb_available:
            warnings.warn(
                "wandb_project specified but wandb is not available. Wandb disabled."
            )
            self.args.wandb_project = None

        self.args.model_type = 'bart'
        self.args.model_name = 'ExplainREGEX'
Example #8
def create_model(encoder_name="xlm-roberta-base",
                 dec_vocabsize=None,
                 dec_layers=6,
                 dec_dim=640,
                 dec_heads=8,
                 dropout=0.,
                 dropoutdec=0.,
                 maxlen=50,
                 smoothing=0.,
                 numbeam=1,
                 tensor2tree=None):
    # if encoder_name != "bert-base-uncased":
    #     raise NotImplemented(f"encoder '{encoder_name}' not supported yet.")
    pretrained = AutoModel.from_pretrained(encoder_name)
    encoder = pretrained

    class BertEncoderWrapper(torch.nn.Module):
        def __init__(self, model, dropout=0., **kw):
            super(BertEncoderWrapper, self).__init__(**kw)
            self.model = model
            self.proj = torch.nn.Linear(pretrained.config.hidden_size,
                                        dec_dim,
                                        bias=False)
            self.dropout = torch.nn.Dropout(dropout)

        def forward(self, input_ids, attention_mask=None):
            ret, _ = self.model(input_ids, attention_mask=attention_mask)
            if pretrained.config.hidden_size != dec_dim:
                ret = self.proj(ret)
            # ret = self.dropout(ret)
            ret = (ret, None, None)
            return ret

    encoder = BertEncoderWrapper(encoder, dropout=dropout)

    decoder_config = BartConfig(
        d_model=dec_dim,
        pad_token_id=0,
        bos_token_id=1,
        vocab_size=dec_vocabsize,
        decoder_attention_heads=dec_heads // 2,
        decoder_layers=dec_layers,
        dropout=dropoutdec,
        attention_dropout=min(0.1, dropout / 2),
        decoder_ffn_dim=dec_dim * 4,
        encoder_attention_heads=dec_heads,
        encoder_layers=dec_layers,
        encoder_ffn_dim=dec_dim * 4,
        relative_position=True,
    )
    model = BartGenerator(decoder_config, encoder.model.config)
    model.model.encoder = encoder

    orderless = {"op:and", "SW:concat"}

    trainmodel = BartGeneratorTrain(model,
                                    smoothing=smoothing,
                                    tensor2tree=tensor2tree,
                                    orderless=orderless)
    testmodel = BartGeneratorTest(model,
                                  maxlen=maxlen,
                                  numbeam=numbeam,
                                  tensor2tree=tensor2tree,
                                  orderless=orderless)
    return trainmodel, testmodel
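A hedged usage sketch of the factory above (the vocab size and tensor2tree callback are placeholders; BartGenerator, BartGeneratorTrain and BartGeneratorTest come from this project, not from transformers):

trainmodel, testmodel = create_model(encoder_name="xlm-roberta-base",
                                     dec_vocabsize=1000,
                                     dec_layers=6,
                                     dec_dim=640,
                                     tensor2tree=lambda x: x)  # placeholder callback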
Example #9
import sys, io
import numpy as np
import torch
from transformers import BartTokenizer, BartConfig, BartModel
from tqdm import tqdm
from sklearn.metrics import f1_score, roc_auc_score
import pickle, random
from parabart import ParaBart

print("==== loading model ====")
config = BartConfig.from_pretrained('facebook/bart-base',
                                    cache_dir='../para-data/bart-base')

model = ParaBart(config)

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base',
                                          cache_dir='../para-data/bart-base')

model.load_state_dict(torch.load("./model/model.pt", map_location='cpu'))

model = model.cuda()


def build_embeddings(model, tokenizer, sents):
    model.eval()
    embeddings = torch.ones((len(sents), model.config.d_model))
    with torch.no_grad():
        for i, sent in enumerate(sents):
            sent_inputs = tokenizer(sent, return_tensors="pt")
            sent_token_ids = sent_inputs['input_ids']
Example #10
from flask import Flask, request, render_template
from regression import model1, tokenizer_new, tokenize_new
import numpy as np

from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
import torch
config = BartConfig.from_json_file('output_model/hate/config.json')
model = BartForConditionalGeneration.from_pretrained('output_model/hate/')
tok = BartTokenizer.from_pretrained('output_model/hate/')

app = Flask(__name__)
app.debug = True


@app.route("/", methods=['GET', 'POST'])
def index():
    if request.method == "POST":
        name = request.form["name"]
        hate = " "
        if len(name) > 0:
            if name.split(" ")[-1] == '':
                a, b, c = tokenize_new([name], tokenizer_new)
                out = np.round(model1.predict([a, b])[0][0])
                if out <= 3:
                    hate = "No Hate detected"
                elif out <= 5:
                    hate = "LOW"
                elif out <= 7:
                    hate = "MEDIUM"
                else:
                    hate = "HIGH"
Example #11
    def build_graph(self):
        """构建模型"""
        if self.config.bart_pre_training:
            self.model = CustomBartGeneration.from_pretrained(
                self.config.bart_pre_training)
            if self.model.config.vocab_size != self.config.vocab_size:
                # the vocabulary changed relative to the pretrained model; resize the embedding table
                self.model.resize_token_embeddings(self.config.vocab_size)
        else:
            bart_config = BartConfig()
            bart_config.activation_function = self.config.activate_func
            bart_config.vocab_size = self.config.vocab_size
            bart_config.d_model = self.config.embed_size
            bart_config.max_position_embeddings = self.config.embed_size
            bart_config.max_length = self.config.max_generate_length
            bart_config.num_labels = self.config.num_labels
            bart_config.image_para_freeze = self.config.image_para_freeze
            bart_config.encoder_layers = self.config.n_layers
            bart_config.decoder_layers = self.config.n_layers
            bart_config.encoder_attention_heads = self.config.n_head
            bart_config.decoder_attention_heads = self.config.n_head
            bart_config.encoder_ffn_dim = self.config.ffn_dim
            bart_config.decoder_ffn_dim = self.config.ffn_dim
            bart_config.pad_token_id = PAD_ID
            bart_config.bos_token_id = BOS_ID
            bart_config.eos_token_id = EOS_ID
            self.model = CustomBartGeneration(config=bart_config)

            # multi-task
            # bart_config.summary_use_proj = True
            # bart_config.summary_activation = None
            # bart_config.summary_first_dropout = True
            # bart_config.summary_proj_to_labels = 0.1
            # bart_config.summary_type = "cls_index"
            # self.model = CustomBartGenerationDoubleHeads(config=bart_config)

        if torch.cuda.is_available():
            self.model.to(self.config.device)

        if self.config.checkpoint:
            self.checkpoint_dict = self.load_model(self.config.checkpoint)

        if self.is_train:
            no_decay = ['bias', 'layer_norm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                self.config.weight_decay
            }, {
                'params': [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=self.config.learning_rate,
                                   eps=self.config.adam_epsilon)

            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config.num_warmup_steps,
                num_training_steps=self.config.num_training_steps)
            if self.config.checkpoint and self.checkpoint_dict:
                self.optimizer.load_state_dict(
                    self.checkpoint_dict["optimizer"])  # restore optimizer state
                self.scheduler.load_state_dict(
                    self.checkpoint_dict["lr_scheduler"])  # restore lr_scheduler state
Example #12
from transformers import BertModel, BertTokenizer, BertConfig, BartConfig

rubert_ckpt_name = 'DeepPavlov/rubert-base-cased'
tokenizer = BertTokenizer.from_pretrained(
    rubert_ckpt_name, do_lower_case=False)  # do_lower_case=False is crucial
assert tokenizer.padding_side == 'right'
test_text_sample = 'Ай да Пушкин! синхрофазотрон'
assert tokenizer.get_vocab().get('Пушкин') is not None
assert tokenizer.tokenize(test_text_sample) == [
    'Ай', 'да', 'Пушкин', '!', 'синх', '##роф', '##аз', '##отрон'
]
enc_txt = encode_text(tokenizer, test_text_sample, max_len=32)
assert decode_text(tokenizer, enc_txt) == test_text_sample

config = BartConfig.from_pretrained('bart-large-cnn')
rubert_config = BertConfig.from_pretrained(rubert_ckpt_name)
config.model_type = 'rubart'
config.task_specific_params = None
config.vocab_size = rubert_config.vocab_size
config.pad_token_id = rubert_config.pad_token_id
config.bos_token_id = tokenizer.convert_tokens_to_ids('[CLS]')
config.eos_token_id = tokenizer.convert_tokens_to_ids('[SEP]')
config.prefix = None
config.decoder_start_token_id = config.bos_token_id
config.max_position_embeddings = rubert_config.max_position_embeddings

# TODO choose CLS/<S>
print(tokenizer.convert_ids_to_tokens([100, 101, 102, 103, 104, 105, 106,
                                       107]))
Example #13
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # n_sample for evaluating the models during training
    training_args.k_out = data_args.k_out
    training_args.data_dir = data_args.data_dir

    # Ensure the output dir does not already exist
    if (
        os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir)
        and training_args.do_train and not training_args.overwrite_output_dir
    ):
        raise ValueError(f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    config = BartConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
    for p in extra_model_params:
        if getattr(training_args, p, None):
            assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
            setattr(config, p, getattr(training_args, p))

    tokenizer = BartTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    ## TODO special token format: <E1>, <E2>, ... <P1>, <P2> ... 
    special_tokens = ['<E{}>'.format(i) for i in range(data_args.n_special_tokens)] + ['<P{}>'.format(i) for i in range(10)]
    tokenizer.add_tokens(special_tokens)

    model = BartForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
    )
    model.resize_token_embeddings(len(tokenizer))

    # use task specific params, e.g., data_args.task = 'summarization'
    use_task_specific_params(model, data_args.task)

    # set num_beams for evaluation
    if data_args.eval_beams is None:
        data_args.eval_beams = model.config.num_beams

    # set decoder_start_token_id for MBart
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
        assert (
            data_args.tgt_lang is not None and data_args.src_lang is not None
        ), "mBart requires --tgt_lang and --src_lang"
        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]

    if model_args.freeze_embeds:
        freeze_embeds(model)
    if model_args.freeze_encoder:
        freeze_params(model.get_encoder())
        assert_all_frozen(model.get_encoder())

    # Get datasets
    train_dataset = (
        LegacySeq2SeqDataset(
            tokenizer=tokenizer,
            type_path="train",
            data_dir=data_args.data_dir,
            n_obs=data_args.n_train,
            max_target_length=data_args.max_target_length,
            max_source_length=data_args.max_source_length,
            prefix=model.config.prefix or "",
        )
        if training_args.do_train
        else None
    )

    eval_dataset = (
        LegacySeq2SeqDataset(
            tokenizer=tokenizer,
            type_path="val",
            data_dir=data_args.data_dir,
            n_obs=data_args.n_val,  
            max_target_length=data_args.val_max_target_length,
            max_source_length=data_args.max_source_length,
            prefix=model.config.prefix or "",
        )
        if training_args.do_eval or training_args.evaluation_strategy != EvaluationStrategy.NO
        else None
    )

    test_dataset = (
        LegacySeq2SeqDataset(
            tokenizer=tokenizer,
            type_path="test",
            data_dir=data_args.data_dir,
            n_obs=data_args.n_test,
            max_target_length=data_args.test_max_target_length,
            max_source_length=data_args.max_source_length,
            prefix=model.config.prefix or "",
        )
        if training_args.do_predict
        else None
    )

    trainer = Seq2SeqTrainer(
        model=model,
        config=config,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
        data_args=data_args,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)

    # Evaluation (on dev set)
    eval_results = {}
    if training_args.do_eval:

        output = trainer.evaluate()
        predictions = output.predictions.tolist()

        out_pred_path = training_args.output_dir + '/output_pred_dev.txt'
        out_pred_metric = training_args.output_dir + '/output_metric_dev.json'
        out_pred_ref = data_args.data_dir + '/val.target'

        with open(out_pred_path, 'w') as eval_out:
            for pred in predictions:
                output_line = tokenizer.decode(pred, 
                        skip_special_tokens=True, clean_up_tokenization_spaces=False)
                eval_out.write(output_line + '\n')

        metrics = {'epoch': 'eval_mode'}
        metrics.update(eval_top1_acc(out_pred_path, out_pred_ref, data_args.k_out)) ## top1_metrics
        metrics.update(eval_topk_acc(out_pred_path, out_pred_ref, data_args.k_out))  ## topk_metrics
        metrics.update(eval_diversity(out_pred_path, data_args.k_out)) ## diversity_metrics

        with open(out_pred_metric, 'w') as metric_out:
            json.dump(metrics, metric_out, indent=1)

    # Prediction (on test set)
    if training_args.do_predict:
        logging.info("*** Test ***")

        test_output = trainer.predict(test_dataset=test_dataset)
        predictions = test_output.predictions.tolist()

        out_pred_path = training_args.output_dir + '/output_pred_test.txt'
        out_pred_metric = training_args.output_dir + '/output_metric_test.json'
        out_pred_ref = data_args.data_dir + '/test.target'

        with open(out_pred_path, 'w') as eval_out:
            for pred in predictions:
                output_line = tokenizer.decode(pred, 
                        skip_special_tokens=True, clean_up_tokenization_spaces=False)
                eval_out.write(output_line + '\n')

        metrics = {'epoch': 'test_mode'}
        metrics.update(eval_top1_acc(out_pred_path, out_pred_ref, data_args.k_out)) ## top1_metrics
        metrics.update(eval_topk_acc(out_pred_path, out_pred_ref, data_args.k_out))  ## topk_metrics
        metrics.update(eval_diversity(out_pred_path, data_args.k_out)) ## diversity_metrics

        with open(out_pred_metric, 'w') as metric_out:
            json.dump(metrics, metric_out, indent=1)
Example #14
    def __init__(
        self,
        model_name_or_path, # teacher
        tokenizer_name,
        model_cache_dir,
        input_max_length,
        target_max_length,
        summary_column_name,
        document_column_name,
        wandb_project,
        wandb_run_name,
        student_encoder_layers,
        student_decoder_layers,
        **kwargs,
    ):
        super().__init__(
            input_max_length,
            target_max_length,
            summary_column_name,
            document_column_name,
            wandb_project,
            wandb_run_name,
        )
        self.tokenizer = BartTokenizer.from_pretrained(
            tokenizer_name if tokenizer_name else model_name_or_path,
            cache_dir=model_cache_dir,
        )
        teacher = BartForConditionalGeneration.from_pretrained(
            model_name_or_path, cache_dir=model_cache_dir,
        ).eval()

        student_updates = {
            "decoder_layers": student_decoder_layers,
            "encoder_layers": student_encoder_layers,
        }
        d_layers_to_copy = self._get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
        e_layers_to_copy: List = self._get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
        kw = teacher.config.to_diff_dict()
        kw.update(student_updates)
        # Copy weights
        student_cfg = BartConfig(**kw)
        student = BartForConditionalGeneration(student_cfg)
        student, _ = self._init_student(student, teacher)
        self._copy_to_student(d_layers_to_copy, e_layers_to_copy, student_encoder_layers, student_decoder_layers, student, teacher)
        self.model = student
        print(student)
        inputs = self.tokenizer.encode_plus("TEXT TO SUMMARIZE", max_length=1024, return_tensors="pt")

        # Summarize
        outputs = self.model.generate(
            input_ids=inputs['input_ids'], 
            attention_mask=inputs['attention_mask'], 
            max_length=400, 
            min_length=150, 
            length_penalty=2.0, 
            num_beams=4, 
            early_stopping=True
        )

        # Decode
        summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(summary)
Example #15
if __name__ == "__main__":

    pd.set_option('display.width', None)
    pd.set_option('display.max_colwidth', None)

    # LOADING MODEL & DATA

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)
    model_created = False

    if args.checkpoint is not None:
        model_created = True
        if args.bart:
            config = BartConfig.from_json_file(args.checkpoint +
                                               "/config.json")
            model = BartForConditionalGeneration.from_pretrained(
                args.checkpoint + "/pytorch_model.bin", config=config)
        else:
            config = EncoderDecoderConfig.from_json_file(args.checkpoint +
                                                         "/config.json")
            model = EncoderDecoderModel.from_pretrained(args.checkpoint +
                                                        "/pytorch_model.bin",
                                                        config=config)

    if args.language == 'fr':
        if args.bart:
            model_name = "WikinewsSum/bart-large-multi-fr-wiki-news"
            #config = BartConfig.from_pretrained(model_name)
            tokenizer = BartTokenizer.from_pretrained(model_name)
            if not model_created:
Example #16
    def __init__(self):
        self.config = BartConfig(vocab_size=50264, output_past=True)
        self.bart = load_hf_model(config=self.config, pretrained=True)
Example #17
    'batch_size': 64,
    'tenacity': 5,
    'epoch_size': 4
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        default='bart-large',
                        help='model name or path')
    args = parser.parse_args()

    config = BartConfig.from_pretrained(args.model)
    model = BartModel.from_pretrained(args.model, config=config)
    tokenizer = BartTokenizer.from_pretrained(args.model)

    params_senteval['model'] = model.cuda().eval()
    params_senteval['tokenizer'] = tokenizer
    params_senteval['config'] = config

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'MR', 'CR', 'MPQA',
        'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SICKEntailment',
        'SICKRelatedness', 'STSBenchmark', 'Length', 'WordContent', 'Depth',
        'TopConstituents', 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
        'OddManOut', 'CoordinationInversion', 'ImageCaptionRetrieval', 'SNLI'
    ]
Example #18
def create_model(encoder_name="bert-base-uncased",
                 fl_vocab=None,
                 abs_fl_vocab=None,
                 dec_layers=6,
                 dec_dim=640,
                 dec_heads=8,
                 dropout=0.,
                 maxlen=20,
                 smoothing=0.,
                 numbeam=1,
                 abs_id=-100,
                 entropycontrib=1.):

    if encoder_name != "bert-base-uncased":
        raise NotImplementedError(
            f"encoder '{encoder_name}' not supported yet.")
    pretrained = AutoModel.from_pretrained(encoder_name)
    encoder = pretrained

    class BertEncoderWrapper(torch.nn.Module):
        def __init__(self, model, dropout=0., **kw):
            super(BertEncoderWrapper, self).__init__(**kw)
            self.model = model
            self.proj = torch.nn.Linear(pretrained.config.hidden_size,
                                        dec_dim,
                                        bias=False)
            self.dropout = torch.nn.Dropout(dropout)

        def forward(self, input_ids, attention_mask=None):
            ret, _ = self.model(input_ids, attention_mask=attention_mask)
            if pretrained.config.hidden_size != dec_dim:
                ret = self.proj(ret)
            ret = self.dropout(ret)
            ret = (ret, None, None)
            return ret

    class DummyEncoder(torch.nn.Module):
        def __init__(self, dim, **kw):
            super(DummyEncoder, self).__init__(**kw)
            self.dim = dim

        def forward(self, input_ids, attention_mask=None):
            return torch.zeros(input_ids.size(0),
                               1,
                               self.dim,
                               device=input_ids.device)

    encoder = BertEncoderWrapper(encoder, dropout=dropout)
    abs_dec_vocabsize = abs_fl_vocab.number_of_ids()
    dec_vocabsize = fl_vocab.number_of_ids()
    def make_decoder_config(vocab_size):
        # The three decoder configs below are identical except for vocab size.
        return BartConfig(
            d_model=dec_dim,
            pad_token_id=0,
            bos_token_id=1,
            vocab_size=vocab_size,
            decoder_attention_heads=dec_heads // 2,
            decoder_layers=dec_layers,
            dropout=dropout,
            attention_dropout=min(0.1, dropout / 2),
            decoder_ffn_dim=dec_dim * 4,
            encoder_attention_heads=dec_heads,
            encoder_layers=dec_layers,
            encoder_ffn_dim=dec_dim * 4,
        )

    decoder_config = make_decoder_config(abs_dec_vocabsize)
    adv_decoder_config = make_decoder_config(dec_vocabsize)
    decoder_lm_config = make_decoder_config(dec_vocabsize)

    model = BartGenerator(decoder_config)
    model.model.encoder = encoder

    advmodel = BartGenerator(adv_decoder_config)
    advmodel.model.encoder = encoder

    # Dummy transition table over decoder vocabulary ids: rows for ids that
    # also occur in the abstract vocabulary allow any next id; other rows
    # only allow non-abstract ids that are not "@...@" placeholders.
    dummy_probs = torch.zeros(dec_vocabsize, dec_vocabsize)
    for k, v in fl_vocab.D.items():
        if k in abs_fl_vocab.D:
            dummy_probs[v, :] = 1
        else:
            for j, w in fl_vocab.D.items():
                if j not in abs_fl_vocab.D and not re.match("@.+@", j):
                    dummy_probs[v, w] = 1
    # keepdim=True so that rows (not columns) are normalized into distributions
    dummy_probs = dummy_probs / dummy_probs.sum(1, keepdim=True)
    decoder_lm = DummyLM(dummy_probs)  # TODO

    orderless = {"op:and", "SW:concat"}

    tensor2tree = partial(_tensor2tree, D=fl_vocab)
    abstensor2tree = partial(_tensor2tree, D=abs_fl_vocab)
    trainmodel = GeneratorTrain(model,
                                advmodel,
                                decoder_lm,
                                smoothing=smoothing,
                                tensor2tree=abstensor2tree,
                                orderless=orderless,
                                abs_id=abs_id,
                                entropycontrib=entropycontrib)
    advtrainmodel = AdversaryTrain(advmodel,
                                   smoothing=smoothing,
                                   tensor2tree=tensor2tree,
                                   orderless=orderless)
    testmodel = GeneratorTest(model,
                              maxlen=maxlen,
                              numbeam=numbeam,
                              tensor2tree=abstensor2tree,
                              orderless=orderless)
    return trainmodel, advtrainmodel, testmodel
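A hedged usage sketch: the vocab objects are project-specific and must expose number_of_ids() and a token-to-id dict D, as the function above requires; the hyperparameter values here are illustrative only.

# Hypothetical call; fl_vocab / abs_fl_vocab come from the surrounding project.
trainmodel, advtrainmodel, testmodel = create_model(
    encoder_name="bert-base-uncased",
    fl_vocab=fl_vocab,          # full formal-language vocabulary
    abs_fl_vocab=abs_fl_vocab,  # abstracted subset of the vocabulary
    dec_layers=6,
    dec_dim=640,
    dec_heads=8,
    dropout=0.1,
    maxlen=20,
    numbeam=1)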
Example #19
def __init__(self, config):
    BartConfig.__init__(self, config)
    torch.nn.Module.__init__(self)

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_data_aug",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    processors = {"rte": RteProcessor}

    output_modes = {"rte": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    # label_list = processor.get_labels() #["entailment", "neutral", "contradiction"]
    # label_list = ['How_do_I_create_a_profile_v4', 'Profile_Switch_v4', 'Deactivate_Active_Devices_v4', 'Ads_on_Hulu_v4', 'Watching_Hulu_with_Live_TV_v4', 'Hulu_Costs_and_Commitments_v4', 'offline_downloads_v4', 'womens_world_cup_v5', 'forgot_username_v4', 'confirm_account_cancellation_v4', 'Devices_to_Watch_HBO_on_v4', 'remove_add_on_v4', 'Internet_Speed_for_HD_and_4K_v4', 'roku_related_questions_v4', 'amazon_related_questions_v4', 'Clear_Browser_Cache_v4', 'ads_on_ad_free_plan_v4', 'inappropriate_ads_v4', 'itunes_related_questions_v4', 'Internet_Speed_Recommendations_v4', 'NBA_Basketball_v5', 'unexpected_charges_v4', 'change_billing_date_v4', 'NFL_on_Hulu_v5', 'How_to_delete_a_profile_v4', 'Devices_to_Watch_Hulu_on_v4', 'Manage_your_Hulu_subscription_v4', 'cancel_hulu_account_v4', 'disney_bundle_v4', 'payment_issues_v4', 'home_network_location_v4', 'Main_Menu_v4', 'Resetting_Hulu_Password_v4', 'Update_Payment_v4', 'I_need_general_troubleshooting_help_v4', 'What_is_Hulu_v4', 'sprint_related_questions_v4', 'Log_into_TV_with_activation_code_v4', 'Game_of_Thrones_v4', 'video_playback_issues_v4', 'How_to_edit_a_profile_v4', 'Watchlist_Remove_Video_v4', 'spotify_related_questions_v4', 'Deactivate_Login_Sessions_v4', 'Transfer_to_Agent_v4', 'Use_Hulu_Internationally_v4']

    train_examples, dev_examples, eval_examples, label_list = load_CLINC150_with_specific_domain(
        'banking', 1, augment=args.do_data_aug)
    num_labels = len(label_list)

    # train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        # train_examples = processor.get_RTE_as_train('/export/home/Dataset/glue_data/RTE/train.tsv') #train_pu_half_v1.txt
        # train_examples = get_data_hulu_fewshot('train', 5)

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    # cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_TRANSFORMERS_CACHE), 'distributed_{}'.format(args.local_rank))

    # pretrain_model_dir = 'roberta-large-mnli' #'roberta-large' , 'roberta-large-mnli'
    # pretrain_model_dir = '/export/home/Dataset/BERT_pretrained_mine/crossdataentail/trainMNLItestRTE/0.8772563176895307'

    # pretrain_model_dir is otherwise undefined (its assignments above are
    # commented out); a BART MNLI checkpoint is assumed here.
    pretrain_model_dir = 'facebook/bart-large-mnli'
    model_config = BartConfig.from_pretrained(pretrain_model_dir)
    model_config.num_labels = num_labels
    model = BartForSequenceClassification.from_pretrained(pretrain_model_dir,
                                                          config=model_config)
    # print('after:', model.classification_head.out_proj.out_features)
    # exit(0)

    # tokenizer = RobertaTokenizer.from_pretrained(pretrain_model_dir, do_lower_case=args.do_lower_case)
    tokenizer = BartTokenizer.from_pretrained(pretrain_model_dir,
                                              do_lower_case=args.do_lower_case)

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    max_test_acc = 0.0
    max_dev_acc = 0.0
    if args.do_train:
        def featurize(examples):
            # Shared featurization for BART (RoBERTa-style inputs): no
            # XLNet-style trailing CLS token, an extra separator between
            # sentence pairs, right-side padding.
            return convert_examples_to_features(
                examples,
                label_list,
                args.max_seq_length,
                tokenizer,
                output_mode,
                cls_token_at_end=False,
                cls_token=tokenizer.cls_token,
                cls_token_segment_id=0,
                sep_token=tokenizer.sep_token,
                sep_token_extra=True,
                pad_on_left=False,
                pad_token=tokenizer.convert_tokens_to_ids(
                    [tokenizer.pad_token])[0],
                pad_token_segment_id=0)

        train_features = featurize(train_examples)
        '''load dev set'''
        # dev_examples = processor.get_RTE_as_dev('/export/home/Dataset/glue_data/RTE/dev.tsv')
        # dev_examples = get_data_hulu('dev')
        dev_features = featurize(dev_examples)

        dev_all_input_ids = torch.tensor([f.input_ids for f in dev_features],
                                         dtype=torch.long)
        dev_all_input_mask = torch.tensor([f.input_mask for f in dev_features],
                                          dtype=torch.long)
        dev_all_segment_ids = torch.tensor(
            [f.segment_ids for f in dev_features], dtype=torch.long)
        dev_all_label_ids = torch.tensor([f.label_id for f in dev_features],
                                         dtype=torch.long)

        dev_data = TensorDataset(dev_all_input_ids, dev_all_input_mask,
                                 dev_all_segment_ids, dev_all_label_ids)
        dev_sampler = SequentialSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.eval_batch_size)
        '''load test set'''
        # eval_examples = processor.get_RTE_as_test('/export/home/Dataset/RTE/test_RTE_1235.txt')
        # eval_examples = get_data_hulu('test')
        eval_features = featurize(eval_examples)

        eval_all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                          dtype=torch.long)
        eval_all_input_mask = torch.tensor(
            [f.input_mask for f in eval_features], dtype=torch.long)
        eval_all_segment_ids = torch.tensor(
            [f.segment_ids for f in eval_features], dtype=torch.long)
        eval_all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                          dtype=torch.long)

        eval_data = TensorDataset(eval_all_input_ids, eval_all_input_mask,
                                  eval_all_segment_ids, eval_all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        iter_co = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                outputs = model(input_ids, input_mask, labels=label_ids)
                # the model computes the cross-entropy loss internally when
                # labels are passed; outputs = (loss, logits, ...)
                loss = outputs[0]

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                iter_co += 1
                # if iter_co %20==0:
                if iter_co % len(train_dataloader) == 0:
                    # evaluate on the dev and test sets at the end of each epoch
                    model.eval()

                    for idd, dev_or_test_dataloader in enumerate(
                        [dev_dataloader, eval_dataloader]):

                        if idd == 0:
                            logger.info("***** Running dev *****")
                            logger.info("  Num examples = %d",
                                        len(dev_examples))
                        else:
                            logger.info("***** Running test *****")
                            logger.info("  Num examples = %d",
                                        len(eval_examples))
                        # logger.info("  Batch size = %d", args.eval_batch_size)

                        eval_loss = 0
                        nb_eval_steps = 0
                        preds = []
                        gold_label_ids = []
                        # print('Evaluating...')
                        for input_ids, input_mask, segment_ids, label_ids in dev_or_test_dataloader:
                            input_ids = input_ids.to(device)
                            input_mask = input_mask.to(device)
                            segment_ids = segment_ids.to(device)
                            label_ids = label_ids.to(device)
                            gold_label_ids += list(
                                label_ids.detach().cpu().numpy())

                            with torch.no_grad():
                                logits = model(input_ids,
                                               input_mask,
                                               labels=None)
                            # print('logits:', logits)
                            logits = logits[0]

                            loss_fct = CrossEntropyLoss()
                            tmp_eval_loss = loss_fct(
                                logits.view(-1, num_labels),
                                label_ids.view(-1))

                            eval_loss += tmp_eval_loss.mean().item()
                            nb_eval_steps += 1
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)

                        eval_loss = eval_loss / nb_eval_steps
                        preds = preds[0]
                        # preds: (num_examples, num_labels) logits; apply a
                        # softmax so that each row is a probability vector
                        pred_probs = softmax(preds, axis=1)
                        pred_label_ids = list(np.argmax(pred_probs, axis=1))
                        assert len(pred_label_ids) == len(gold_label_ids)
                        hit_co = 0
                        for k in range(len(pred_label_ids)):
                            if pred_label_ids[k] == gold_label_ids[k]:
                                hit_co += 1
                        test_acc = hit_co / len(gold_label_ids)

                        if idd == 0:  # this is dev
                            if test_acc > max_dev_acc:
                                max_dev_acc = test_acc
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                            else:
                                print('\ndev acc:', test_acc, ' max_dev_acc:',
                                      max_dev_acc, '\n')
                                # dev accuracy did not improve: skip the
                                # test-set evaluation for this epoch
                                break
                        else:  # this is test
                            if test_acc > max_test_acc:
                                max_test_acc = test_acc
                            print('\ntest acc:', test_acc, ' max_test_acc:',
                                  max_test_acc, '\n')