Example #1
def create_model(args, vocab_size, device):
    """
    Creates the classifier and encoder model.
    """
    config_path = join(args.model_dir, 'config.json')

    if not exists(config_path):
        config = XLNetConfig.from_pretrained('xlnet-base-cased')

        # build a throwaway model just to resize the embeddings, which
        # also updates config.vocab_size before the config is saved
        generator = XLNetGenerator(config)
        generator.resize_token_embeddings(vocab_size)

        config.save_pretrained(args.model_dir)

    # TODO: the Hugging Face output bias layer is bugged when the
    # embedding size is modified; reloading the model with the new
    # config fixes the problem
    config = XLNetConfig.from_pretrained(args.model_dir)

    generator = XLNetGenerator(config)

    generator = generator.to(device)

    return generator
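The reload-after-resize workaround above can be reproduced with stock Hugging Face classes. A minimal sketch, assuming the `transformers` library, with `XLNetLMHeadModel` standing in for the custom `XLNetGenerator`:

from transformers import XLNetConfig, XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
tokenizer.add_tokens(['<speaker1>', '<speaker2>'])  # hypothetical extra tokens

model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')
model.resize_token_embeddings(len(tokenizer))  # also updates config.vocab_size

# persist the resized config and weights, then reload so the output
# bias matches the new embedding size
model.save_pretrained('model_dir')
config = XLNetConfig.from_pretrained('model_dir')
model = XLNetLMHeadModel.from_pretrained('model_dir', config=config)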
Example #2
def xlnet_feature_extractor(examples):
    config = XLNetConfig.from_pretrained(model_name)
    tokenizer = XLNetTokenizer.from_pretrained(model_name)
    model = XLNetForFeatureExtraction(config)
    # input_ids = torch.tensor(tokenizer.encode(utterance_list[0])).unsqueeze(0)# Batch size 1

    features = convert_examples_to_features(
        examples,
        MAX_SEQ_LEN,
        tokenizer,
        cls_token_at_end=True,  # xlnet has a cls token at the end
        cls_token=tokenizer.cls_token,
        sep_token=tokenizer.sep_token,
        cls_token_segment_id=2,  # XLNet uses segment id 2 for the CLS token
        pad_on_left=True,  # pad on the left for xlnet
        pad_token_segment_id=4)  # XLNet uses segment id 4 for padding

    input_ids_list = []
    input_mask_list = []
    segment_ids_list = []
    for feature in features:
        input_ids_list.append(feature.input_ids)
        input_mask_list.append(feature.input_mask)
        segment_ids_list.append(feature.segment_ids)

    input_ids_tensor = torch.tensor(input_ids_list)
    input_mask_tensor = torch.tensor(input_mask_list)
    segment_ids_tensor = torch.tensor(segment_ids_list)

    transformer_outputs = model(input_ids=input_ids_tensor,
                                attention_mask=input_mask_tensor,
                                token_type_ids=segment_ids_tensor)
    # keep the representation at the last position (the CLS token,
    # since cls_token_at_end=True)
    feature = transformer_outputs[:, -1]
    return feature
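Note the XLNet-specific flags above: unlike BERT, XLNet puts the CLS token at the end of the sequence and pads on the left. A small sketch of what the tokenizer produces, assuming the `transformers` library:

from transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
ids = tokenizer.encode('hello world')  # special tokens added automatically
print(tokenizer.convert_ids_to_tokens(ids))
# expected: ['▁hello', '▁world', '<sep>', '<cls>'] -- the CLS token comes last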
Example #3
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    config = XLNetConfig.from_pretrained(args.config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = Dataloader(args,
                            load_dataset(args, 'valid', shuffle=False),
                            args.batch_size,
                            device,
                            shuffle=False,
                            is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
Example #4
def train(args, device):
    args.dataset_name = "MNLI"  # TODO: parametrize

    model_name = args.model_name
    log = get_train_logger(args)
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    log.info(f'Using device {device}')
    tokenizer = XLNetTokenizer.from_pretrained(model_name, do_lower_case=True)
    xlnet_config = XLNetConfig.from_pretrained(
        model_name,
        output_hidden_states=True,
        output_attentions=True,
        num_labels=3,
        finetuning_task=args.dataset_name)

    model = XLNetForSequenceClassification.from_pretrained(model_name,
                                                           config=xlnet_config)

    model.to(device)

    # Load features from datasets
    data_loader = MNLIDatasetReader(args, tokenizer, log)
    train_file = os.path.join(args.base_path, args.train_file)
    val_file = os.path.join(args.base_path, args.val_file)
    train_dataloader = data_loader.load_train_dataloader(train_file)
    val_dataloader = data_loader.load_val_dataloader(val_file)

    trainer = TrainModel(train_dataloader, val_dataloader, log)
    trainer.train(model, device, args)
Example #5
def main():
    torch.cuda.empty_cache()
    parser = setup_parser()
    args = parser.parse_args()
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir
    ) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory already exists and is not empty.")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.n_gpu = torch.cuda.device_count()
    args.device = device
    set_seed(args)
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: {}".format(args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    ## Load models
    config = XLNetConfig.from_pretrained(args.config_name)
    print('config: {}'.format(config))
    tokenizer = XLNetTokenizer.from_pretrained(
        args.text_encoder_checkpoint, do_lower_case=args.do_lower_case)
    text_encoder = XLNetModel.from_pretrained(args.text_encoder_checkpoint,
                                              config=config)
    graph_encoder = GraphEncoder(args.n_hidden, args.min_score)
    if args.graph_encoder_checkpoint:
        graph_encoder.gcnnet.load_state_dict(
            torch.load(args.graph_encoder_checkpoint))

    medsts_classifier = PairClassifier(config.hidden_size + args.n_hidden, 1)
    medsts_c_classifier = PairClassifier(config.hidden_size + args.n_hidden, 5)
    medsts_type_classifier = PairClassifier(config.hidden_size + args.n_hidden,
                                            4)
    model = MedstsNet(text_encoder, graph_encoder, medsts_classifier,
                      medsts_c_classifier, medsts_type_classifier, config)
    model.to(args.device)

    args.n_gpu = 1

    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                args.task_name,
                                                tokenizer,
                                                evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info('global step = {}, average loss = {}'.format(
            global_step, tr_loss))
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        logger.info("saving model checkpoint to {}".format(args.output_dir))
        model_to_save = model.module if hasattr(model, 'module') else model
        # model_to_save.save_pretrained(args.output_dir)
        torch.save(model_to_save.state_dict(),
                   os.path.join(args.output_dir, 'saved_model.pth'))
        tokenizer.save_pretrained(args.output_dir)
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
Example #6
    def __init__(
        self,
        language=Language.ENGLISHCASED,
        num_labels=5,
        cache_dir=".",
        num_gpus=None,
        num_epochs=1,
        batch_size=8,
        lr=5e-5,
        adam_eps=1e-8,
        warmup_steps=0,
        weight_decay=0.0,
        max_grad_norm=1.0,
    ):
        """Initializes the classifier and the underlying pretrained model.

        Args:
            language (Language, optional): The pretrained model's language.
                                           Defaults to 'xlnet-base-cased'.
            num_labels (int, optional): The number of unique labels in the
                training data. Defaults to 5.
            cache_dir (str, optional): Location of XLNet's cache directory.
                Defaults to ".".
            num_gpus (int, optional): The number of gpus to use.
                                      If None is specified, all available GPUs
                                      will be used. Defaults to None.
            num_epochs (int, optional): Number of training epochs.
                Defaults to 1.
            batch_size (int, optional): Training batch size. Defaults to 8.
            lr (float): Learning rate of the Adam optimizer. Defaults to 5e-5.
            adam_eps (float, optional): term added to the denominator to improve
                                        numerical stability. Defaults to 1e-8.
            warmup_steps (int, optional): Number of steps in which to increase
                                        learning rate linearly from 0 to 1. Defaults to 0.
            weight_decay (float, optional): Weight decay. Defaults to 0.
            max_grad_norm (float, optional): Maximum norm for the gradients. Defaults to 1.0
        """

        if num_labels < 2:
            raise ValueError("Number of labels should be at least 2.")

        self.language = language
        self.num_labels = num_labels
        self.cache_dir = cache_dir

        self.num_gpus = num_gpus
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.adam_eps = adam_eps
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        self.max_grad_norm = max_grad_norm

        # create classifier
        self.config = XLNetConfig.from_pretrained(self.language.value,
                                                  num_labels=num_labels,
                                                  cache_dir=cache_dir)
        self.model = XLNetForSequenceClassification(self.config)
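`XLNetForSequenceClassification(self.config)` builds the classifier with randomly initialized weights. A sketch of starting from the pretrained checkpoint instead, assuming `language.value` resolves to a model id such as 'xlnet-base-cased':

from transformers import XLNetConfig, XLNetForSequenceClassification

config = XLNetConfig.from_pretrained('xlnet-base-cased', num_labels=5)
# from_pretrained loads the checkpoint weights; calling the constructor
# with a config alone leaves them randomly initialized
model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', config=config)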
Example #7
    def __init__(
        self,
        gpu=-1,
        check_for_lowercase=True,
        embeddings_dim=0,
        verbose=True,
        path_to_pretrained="xlnet-base-cased",
        model_frozen=True,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        sep_token="<sep>",
        pad_token="<pad>",
        cls_token="<cls>",
        mask_token="<mask>",
    ):
        SeqIndexerBaseEmbeddings.__init__(
            self,
            gpu=gpu,
            check_for_lowercase=check_for_lowercase,
            zero_digits=True,
            bos_token=bos_token,
            eos_token=eos_token,
            pad=pad_token,
            unk=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            mask_token=mask_token,
            load_embeddings=True,
            embeddings_dim=embeddings_dim,
            verbose=verbose,
            isBert=False,
            isXlNet=True)

        print("create seq indexer Transformers from Model {}".format(
            path_to_pretrained))

        self.xlnet = True

        self.path_to_pretrained = path_to_pretrained
        self.tokenizer = XLNetTokenizer.from_pretrained(path_to_pretrained)
        self.config = XLNetConfig.from_pretrained(path_to_pretrained)
        self.emb = XLNetModel.from_pretrained(path_to_pretrained)
        self.frozen = model_frozen
        # freeze the whole pretrained encoder (this already covers
        # word_embedding, the transformer layers and dropout)
        for param in self.emb.parameters():
            param.requires_grad = False

        if not self.frozen:
            for param in self.emb.pooler.parameters():
                param.requires_grad = True
        self.emb.eval()
        print("XLNet model loaded successfully")
Example #8
def main():
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # ------------------ Determine CUDA mode ----------------------
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = 1

    # produce the data iterators
    train_batch_size = args.per_gpu_train_batch_size * max(1, n_gpu)
    eval_batch_size = args.per_gpu_eval_batch_size * max(1, n_gpu)

    train_iter = load_and_cache_examples(mode='train',
                                         train_batch_size=train_batch_size,
                                         eval_batch_size=eval_batch_size)
    eval_iter = load_and_cache_examples(mode='dev',
                                        train_batch_size=train_batch_size,
                                        eval_batch_size=eval_batch_size)

    #epoch_size = num_train_steps * train_batch_size * args.gradient_accumulation_steps / args.num_train_epochs

    # pbar = ProgressBar(epoch_size=epoch_size,
    #                    batch_size=train_batch_size)

    if args.model_type == 'bert':
        model = Bert_SenAnalysis.from_pretrained(args.bert_model, num_tag=len(args.labels))
    elif args.model_type == 'xlnet':
        config = XLNetConfig.from_pretrained(args.xlnet_model, num_labels=len(args.labels))
        model = XLNet_SenAnalysis.from_pretrained(args.xlnet_model, config=config)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    train_iter = cycle(train_iter)
    fit(model = model,
        training_iter=train_iter,
        eval_iter=eval_iter,
        #train_steps=args.train_steps,
        #pbar=pbar,
        num_train_steps=args.train_steps,#num_train_steps,
        device=device,
        n_gpu=n_gpu,
        verbose=1)
Example #9
        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length],
                                     self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length],
                                     self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length],
                                     self.type_vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length],
                                    2).float()

            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1],
                                     self.vocab_size)
            perm_mask = torch.zeros(self.batch_size,
                                    self.seq_length + 1,
                                    self.seq_length + 1,
                                    dtype=torch.float)
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = torch.zeros(self.batch_size,
                                         1,
                                         self.seq_length + 1,
                                         dtype=torch.float)
            target_mapping[:, 0, -1] = 1.0  # predict last token

            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length],
                                       self.vocab_size)
                sequence_labels = ids_tensor([self.batch_size],
                                             self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size)

            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                    input_mask, target_mapping, segment_ids, lm_labels,
                    sequence_labels, is_impossible_labels)
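The `perm_mask` and `target_mapping` tensors built above are how XLNet is told which position to predict. A minimal sketch of the same setup against a pretrained model, assuming the `transformers` library:

import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')

input_ids = torch.tensor(tokenizer.encode('Hello, my dog is very')).unsqueeze(0)
seq_len = input_ids.shape[1]

# no token may attend to the last position ...
perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
perm_mask[:, :, -1] = 1.0
# ... and the last position is the single prediction target
target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
target_mapping[0, 0, -1] = 1.0

outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0]  # shape (1, 1, vocab_size)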
Example #10
def run(args):
    nli_model_path = 'saved_models/xlnet-base-cased/'
    model_file = os.path.join(nli_model_path, 'pytorch_model.bin')
    config_file = os.path.join(nli_model_path, 'config.json')
    log = get_logger('conduct_test')
    model_name = 'xlnet-base-cased'
    tokenizer = XLNetTokenizer.from_pretrained(model_name, do_lower_case=True)
    xlnet_config = XLNetConfig.from_pretrained(config_file)
    model = XLNetForSequenceClassification.from_pretrained(model_file,
                                                           config=xlnet_config)
    dataset_reader = ConductDatasetReader(args, tokenizer, log)
    file_lines = dataset_reader.get_file_lines('data/dados.tsv')

    results = []
    softmax_fn = torch.nn.Softmax(dim=1)

    model.eval()
    with torch.no_grad():
        for line in tqdm(file_lines):
            premise, hypothesys, conflict = dataset_reader.parse_line(line)
            pair_word_ids, input_mask, pair_segment_ids = dataset_reader.convert_text_to_features(
                premise, hypothesys)
            tensor_word_ids = torch.tensor([pair_word_ids],
                                           dtype=torch.long,
                                           device=args.device)
            tensor_input_mask = torch.tensor([input_mask],
                                             dtype=torch.long,
                                             device=args.device)
            tensor_segment_ids = torch.tensor([pair_segment_ids],
                                              dtype=torch.long,
                                              device=args.device)
            model_input = {
                'input_ids': tensor_word_ids,  # word ids
                'attention_mask': tensor_input_mask,  # input mask
                'token_type_ids': tensor_segment_ids
            }
            outputs = model(**model_input)
            logits = outputs[0]
            nli_scores, nli_class = get_scores_and_class(logits, softmax_fn)
            nli_scores = nli_scores.detach().cpu().numpy()
            results.append({
                "conduct": premise,
                "complaint": hypothesys,
                "nli_class": nli_class,
                "nli_contradiction_score": nli_scores[0],
                "nli_entailment_score": nli_scores[1],
                "nli_neutral_score": nli_scores[2],
                "conflict": conflict
            })

    df = pd.DataFrame(results)
    df.to_csv('results/final_results.tsv', sep='\t', index=False)
Example #11
    def load(cls,
             config_path: Path,
             model_path: Path,
             cache_model: bool = True) -> XLNetModel:
        if str(model_path) in cls._cache:
            return cls._cache[str(model_path)]

        config = XLNetConfig.from_pretrained(str(config_path))
        model = XLNetModel.from_pretrained(str(model_path), config=config)
        if cache_model:
            cls._cache[str(model_path)] = model

        return model
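A usage sketch for the cached loader; the paths are hypothetical placeholders:

from pathlib import Path

# the second call returns the cached instance instead of reloading
model = PretrainedXLNetModel.load(Path('xlnet/config.json'), Path('xlnet/'))
model = PretrainedXLNetModel.load(Path('xlnet/config.json'), Path('xlnet/'))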
Example #12
def load_model(output_dir, model_type):
    # Load a trained model that you have fine-tuned
    output_model_file = os.path.join(output_dir, "pytorch_model.bin")
    model_state_dict = torch.load(output_model_file)
    if model_type == 'bert':
        model = Bert_SenAnalysis.from_pretrained(args.bert_model,
                                                 state_dict=model_state_dict,
                                                 num_tag=len(args.labels))
    elif model_type == 'xlnet':
        config = XLNetConfig.from_pretrained(args.xlnet_model,
                                             num_labels=len(args.labels))
        model = XLNet_SenAnalysis.from_pretrained(args.xlnet_model,
                                                  config=config)
    else:
        raise ValueError('Unsupported model_type: {}'.format(model_type))
    return model
Example #13
    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.n_gpu = torch.cuda.device_count()

        print('Distributed rank: ', args.distributed_rank)
        print('Number of used GPU: ', self.n_gpu)

        # if self.n_gpu > 1:
        #     torch.distributed.barrier()
        if args.distributed_rank not in [-1, 0]:
            # make sure only the first process in distributed training
            # downloads the model & vocab
            torch.distributed.barrier()

        # Load pre-trained model (weights)
        config = XLNetConfig.from_pretrained(args.xlnet_model)
        self.xlnet = XLNetModel.from_pretrained(args.xlnet_model,
                                                config=config)

        if args.distributed_rank == 0:
            # make sure only the first process in distributed training
            # downloads the model & vocab
            torch.distributed.barrier()

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions,
            embed_dim,
            self.padding_idx,
            left_pad=left_pad,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args) for i in range(args.encoder_layers)
        ])
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
Example #14
    def init_params(self, model_name, pre_trained_model, f_lr=5e-5, f_eps=1e-8):
        MODEL_CLASSES = {"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer)}
        # self._config, self._model_class, self._tokenizer = MODEL_CLASSES[model_name]
        self._tokenizer = XLNetTokenizer.from_pretrained(pre_trained_model, do_lower_case=True)
        # do_lower_case is a tokenizer option, not a config one
        self._config = XLNetConfig.from_pretrained(pre_trained_model)
        self._model = XLNetForQuestionAnswering.from_pretrained(pre_trained_model, config=self._config)
        self._model.to(self._device)

        no_decay = ['bias', 'LayerNorm.weight']
        weight_decay = 0.0 # Author's default parameter
        optimizer_grouped_parameters = [
                  {'params': [p for n, p in self._model.named_parameters() \
                          if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
                  {'params': [p for n, p in self._model.named_parameters() \
                          if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                  ]
        # warmup_steps = 0.0
        self._optimizer = AdamW(optimizer_grouped_parameters, lr=f_lr, eps=f_eps)
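The commented-out `warmup_steps` suggests a warmup schedule that was never wired in. A sketch of attaching one, assuming the pytorch-transformers-era `WarmupLinearSchedule` API and a placeholder `t_total`:

from pytorch_transformers import WarmupLinearSchedule

# t_total should be the real number of training steps
scheduler = WarmupLinearSchedule(self._optimizer, warmup_steps=0, t_total=1000)

# inside the training loop:
#     loss.backward(); self._optimizer.step(); scheduler.step(); self._optimizer.zero_grad()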
Example #15
    def __init__(self, model_path):
        super(OnmtXLNetEncoder, self).__init__()
        config = XLNetConfig.from_json_file(
            os.path.join(model_path, "config.json"))
        pretrained_dict = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(pretrained_dict):
            model = XLNetModel.from_pretrained(
                pretrained_model_name_or_path=pretrained_dict, config=config)
            print("init XLNet model with {} weights".format(
                len(model.state_dict())))
        else:
            model = XLNetModel(config)

        model.word_embedding = expandEmbeddingByN(model.word_embedding, 4)
        model.word_embedding = expandEmbeddingByN(model.word_embedding,
                                                  2,
                                                  last=True)
        self.encoder = model
        #print(model)
        print("***" * 20)
Example #16
    def __init__(self, args, task_name, weight_file=None, config_file=None):
        self.args = args
        self.device = args.device
        self.log = self.get_train_logger(args, task_name)
        self.softmax = Softmax(dim=1)
        self.tokenizer = XLNetTokenizer.from_pretrained(args.model_name,
                                                        do_lower_case=True)
        self.dataset_reader = init_dataset_reader(task_name, args,
                                                  self.tokenizer, self.log)

        config = args.model_name if config_file is None else config_file
        model_weights = args.model_name if weight_file is None else weight_file
        xlnet_config = XLNetConfig.from_pretrained(config,
                                                   output_hidden_states=True,
                                                   output_attentions=True,
                                                   num_labels=3,
                                                   finetuning_task=task_name)

        model = XLNetForSequenceClassification.from_pretrained(
            model_weights, config=xlnet_config)
        self.model = model.to(args.device)
Example #17
    def __init__(self, args, device, checkpoint):
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = XLNet(args.large, args.temp_dir, args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)
        if (args.encoder == 'baseline'):
            bert_config = XLNetConfig(self.bert.model.config.vocab_size,
                                      hidden_size=args.ext_hidden_size,
                                      num_hidden_layers=args.ext_layers,
                                      num_attention_heads=args.ext_heads,
                                      intermediate_size=args.ext_ff_size)
            self.bert.model = XLNetModel(bert_config)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(
                    args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)
Example #18
def validate_on_test_set(args, device):
    log = get_logger("test-results")
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    log.info(f'Using device {device}')

    model_name = 'xlnet-base-cased'
    tokenizer = XLNetTokenizer.from_pretrained(model_name, do_lower_case=True)
    xlnet_config = XLNetConfig.from_pretrained(args.config_file)
    data_reader = KaggleMNLIDatasetReader(args, tokenizer, log)
    model = XLNetForSequenceClassification.from_pretrained(args.model_file, config=xlnet_config)

    model.to(device)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    log.info(f'Running on {args.n_gpu} GPUS')

    test_executor = KaggleTest(tokenizer, log, data_reader)
    write_kaggle_results("matched", args.test_matched_file, test_executor, device, model)
    write_kaggle_results("mismatched", args.test_mismatched_file, test_executor, device, model)
Example #19
    def __init__(self, num_labels=2, model_type='xlnet-base-cased',
                 token_layer='token-cls', output_logits=True):
        super(XLNetForWSD, self).__init__()

        self.config = XLNetConfig()
        self.token_layer = token_layer
        self.num_labels = num_labels
        self.xlnet = XLNetModel.from_pretrained(model_type)
        # XLNetConfig exposes `dropout` (it has no `hidden_dropout_prob`)
        self.dropout = nn.Dropout(self.config.dropout)
        self.output_logits = output_logits

        # Define which token selection layer to use
        if token_layer == 'token-cls':
            self.tokenselectlayer = TokenClsLayer()
        elif token_layer in ['sent-cls', 'sent-cls-ws']:
            self.tokenselectlayer = SentClsLayer()
        else:
            raise ValueError("Unidentified parameter for token selection layer")

        self.classifier = nn.Linear(768, num_labels)  # 768 = hidden size of xlnet-base
        if not output_logits:
            self.softmax = nn.Softmax(dim=1)  # to be checked!!!

        nn.init.xavier_normal_(self.classifier.weight)
Example #20
def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path,
                                        bert_config_file,
                                        pytorch_dump_folder_path,
                                        finetuning_task=None):
    # Initialise PyTorch model
    config = XLNetConfig.from_json_file(bert_config_file)

    finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
    if finetuning_task in GLUE_TASKS_NUM_LABELS:
        print(
            "Building PyTorch XLNetForSequenceClassification model from configuration: {}"
            .format(str(config)))
        config.finetuning_task = finetuning_task
        config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
        model = XLNetForSequenceClassification(config)
    elif 'squad' in finetuning_task:
        config.finetuning_task = finetuning_task
        model = XLNetForQuestionAnswering(config)
    else:
        model = XLNetLMHeadModel(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path,
                                             WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path,
                                            CONFIG_NAME)
    print("Save PyTorch model to {}".format(
        os.path.abspath(pytorch_weights_dump_path)))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(
        os.path.abspath(pytorch_config_dump_path)))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
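A usage sketch for the converter above; the checkpoint, config and output paths are hypothetical placeholders:

convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path='xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt',
    bert_config_file='xlnet_cased_L-12_H-768_A-12/xlnet_config.json',
    pytorch_dump_folder_path='xlnet-pytorch/',
    finetuning_task='sts-b')  # builds XLNetForSequenceClassification with 1 label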
Example #21
    def __init__(self,
                 args,
                 device,
                 checkpoint=None,
                 bert_from_extractive=None):
        super(AbsSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = XLNet(args.large, args.temp_dir, args.finetune_bert)

        if bert_from_extractive is not None:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_from_extractive.items()
                      if n.startswith('bert.model')]),
                strict=True)

        if (args.encoder == 'baseline'):
            bert_config = XLNetConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.enc_hidden_size,
                num_hidden_layers=args.enc_layers,
                num_attention_heads=8,
                intermediate_size=args.enc_ff_size,
                hidden_dropout_prob=args.enc_dropout,
                attention_probs_dropout_prob=args.enc_dropout)
            self.bert.model = XLNetModel(bert_config)

        if (args.max_pos > 512):
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = \
                self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = \
                self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat(
                    args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(self.vocab_size,
                                      self.bert.model.config.hidden_size,
                                      padding_idx=0)
        if (self.args.share_emb):
            tgt_embeddings.weight = copy.deepcopy(
                self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)

        self.generator = get_generator(self.vocab_size,
                                       self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            for module in self.decoder.modules():
                if isinstance(module, (nn.Linear, nn.Embedding)):
                    module.weight.data.normal_(mean=0.0, std=0.02)
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            for p in self.generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
                else:
                    p.data.zero_()
            if (args.use_bert_emb):
                tgt_embeddings = nn.Embedding(
                    self.vocab_size,
                    self.bert.model.config.hidden_size,
                    padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.bert.model.word_embedding.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator[0].weight = self.decoder.embeddings.weight
        self.to(device)
Example #22
def main(_):
    if FLAGS.server_ip and FLAGS.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(FLAGS.server_ip, FLAGS.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    tf.set_random_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)

    tf.logging.set_verbosity(tf.logging.INFO)

    #### Validate flags
    if FLAGS.save_steps is not None:
        FLAGS.log_step_count_steps = min(FLAGS.log_step_count_steps,
                                         FLAGS.save_steps)

    if FLAGS.do_predict:
        predict_dir = FLAGS.predict_dir
        if not tf.gfile.Exists(predict_dir):
            tf.gfile.MakeDirs(predict_dir)

    processors = {
        "mnli_matched": MnliMatchedProcessor,
        "mnli_mismatched": MnliMismatchedProcessor,
        'sts-b': StsbProcessor,
        'imdb': ImdbProcessor,
        "yelp5": Yelp5Processor
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval`, `do_predict` or "
            "`do_submit` must be True.")

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

#   ########################### LOAD PT model
#   ########################### LOAD PT model
#   import torch
#   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification

#   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
#   tf.logging.info("Model loaded from path: {}".format(save_path))

#   device = torch.device("cuda", 4)
#   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b')
#   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
#   config.to_json_file(config_path)
#   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True, num_labels=1)
#   pt_model.to(device)
#   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

#   from torch.optim import Adam
#   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
#                     eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
#                     amsgrad=False)
#   ########################### LOAD PT model
#   ########################### LOAD PT model

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels() if not FLAGS.is_regression else None

    sp = spm.SentencePieceProcessor()
    sp.Load(FLAGS.spiece_model_file)

    def tokenize_fn(text):
        text = preprocess_text(text, lower=FLAGS.uncased)
        return encode_ids(sp, text)

    # run_config = model_utils.configure_tpu(FLAGS)


#   model_fn = get_model_fn(len(label_list) if label_list is not None else None)

    spm_basename = os.path.basename(FLAGS.spiece_model_file)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    # estimator = tf.estimator.Estimator(
    #     model_fn=model_fn,
    #     config=run_config)

    if FLAGS.do_train:
        train_file_base = "{}.len-{}.train.tf_record".format(
            spm_basename, FLAGS.max_seq_length)
        train_file = os.path.join(FLAGS.output_dir, train_file_base)
        tf.logging.info("Use tfrecord file {}".format(train_file))

        train_examples = processor.get_train_examples(FLAGS.data_dir)
        tf.logging.info("Num of train samples: {}".format(len(train_examples)))

        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, train_file,
                                                FLAGS.num_passes)

        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)

        # estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

        ##### Create input tensors / placeholders
        bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host

        params = {
            "batch_size": FLAGS.train_batch_size  # the whole batch
        }
        train_set = train_input_fn(params)

        example = train_set.make_one_shot_iterator().get_next()
        if FLAGS.num_core_per_host > 1:
            examples = [{} for _ in range(FLAGS.num_core_per_host)]
            for key in example.keys():
                vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
                for device_id in range(FLAGS.num_core_per_host):
                    examples[device_id][key] = vals[device_id]
        else:
            examples = [example]

        ##### Create computational graph
        tower_losses, tower_grads_and_vars, tower_inputs, tower_hidden_states, tower_logits = [], [], [], [], []

        for i in range(FLAGS.num_core_per_host):
            reuse = True if i > 0 else None
            with tf.device(assign_to_gpu(i, "/gpu:0")), \
                tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

                loss_i, grads_and_vars_i, inputs_i, hidden_states_i, logits_i = single_core_graph(
                    is_training=True,
                    features=examples[i],
                    label_list=label_list)

                tower_losses.append(loss_i)
                tower_grads_and_vars.append(grads_and_vars_i)
                tower_inputs.append(inputs_i)
                tower_hidden_states.append(hidden_states_i)
                tower_logits.append(logits_i)

        ## average losses and gradients across towers
        if len(tower_losses) > 1:
            loss = tf.add_n(tower_losses) / len(tower_losses)
            grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
            inputs = dict((n, tf.concat([t[n] for t in tower_inputs], 0))
                          for n in tower_inputs[0])
            hidden_states = list(
                tf.concat(t, 0) for t in zip(*tower_hidden_states))
            logits = tf.concat(tower_logits, 0)
        else:
            loss = tower_losses[0]
            grads_and_vars = tower_grads_and_vars[0]
            inputs = tower_inputs[0]
            hidden_states = tower_hidden_states[0]
            logits = tower_logits[0]

        # Summaries
        merged = tf.summary.merge_all()

        ## get train op
        train_op, learning_rate, gnorm = model_utils.get_train_op(
            FLAGS, None, grads_and_vars=grads_and_vars)
        global_step = tf.train.get_global_step()

        ##### Training loop
        saver = tf.train.Saver(max_to_keep=FLAGS.max_save)

        gpu_options = tf.GPUOptions(allow_growth=True)

        #### load pretrained models
        model_utils.init_from_checkpoint(FLAGS, global_vars=True)

        writer = tf.summary.FileWriter(logdir=FLAGS.model_dir,
                                       graph=tf.get_default_graph())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            #########
            ##### PYTORCH
            import torch
            from torch.optim import Adam
            from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification, BertAdam

            save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME + '-00')
            saver.save(sess, save_path)
            tf.logging.info("Model saved in path: {}".format(save_path))

            device = torch.device("cuda", 4)
            config = XLNetConfig.from_pretrained('xlnet-large-cased',
                                                 finetuning_task=u'sts-b',
                                                 num_labels=1)
            tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

            # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            pt_model = XLNetForSequenceClassification.from_pretrained(
                save_path, from_tf=True, config=config)
            pt_model.to(device)
            pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])

            optimizer = Adam(pt_model.parameters(),
                             lr=0.001,
                             betas=(0.9, 0.999),
                             eps=FLAGS.adam_epsilon,
                             weight_decay=FLAGS.weight_decay,
                             amsgrad=False)
            # optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                      eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)
            ##### PYTORCH
            #########

            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            total_logits = None
            total_labels = None
            while True:
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                fetched = sess.run(fetches)

                loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np, inputs_np, hidden_states_np, logits_np = fetched
                total_loss += loss_np

                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)

                #########
                ##### PYTORCH
                f_inp = torch.tensor(inputs_np["input_ids"],
                                     dtype=torch.long,
                                     device=device)
                f_seg_id = torch.tensor(inputs_np["segment_ids"],
                                        dtype=torch.long,
                                        device=device)
                f_inp_mask = torch.tensor(inputs_np["input_mask"],
                                          dtype=torch.float,
                                          device=device)
                f_label = torch.tensor(inputs_np["label_ids"],
                                       dtype=torch.float,
                                       device=device)

                # with torch.no_grad():
                #   _, hidden_states_pt, _ = pt_model.transformer(f_inp, f_seg_id, f_inp_mask)
                # logits_pt, _ = pt_model(f_inp, token_type_ids=f_seg_id, input_mask=f_inp_mask)

                pt_model.train()
                outputs = pt_model(f_inp,
                                   token_type_ids=f_seg_id,
                                   input_mask=f_inp_mask,
                                   labels=f_label)
                loss_pt = outputs[0]
                loss_pt = loss_pt.mean()
                total_loss_pt += loss_pt.item()

                # # hidden_states_pt = list(t.detach().cpu().numpy() for t in hidden_states_pt)
                # # special_pt = special_pt.detach().cpu().numpy()

                # # Optimizer pt
                pt_model.zero_grad()
                loss_pt.backward()
                gnorm_pt = torch.nn.utils.clip_grad_norm_(
                    pt_model.parameters(), FLAGS.clip)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate_np
                optimizer.step()
                ##### PYTORCH
                #########

                if curr_step > 0 and curr_step % FLAGS.log_step_count_steps == 0:
                    curr_loss = total_loss / (curr_step - prev_step)
                    curr_loss_pt = total_loss_pt / (curr_step - prev_step)
                    tf.logging.info(
                        "[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, gnorm_np, learning_rate_np, curr_loss,
                            math.exp(curr_loss), curr_loss / math.log(2)))

                    #########
                    ##### PYTORCH
                    tf.logging.info(
                        "  PT [{}] | gnorm PT {:.2f} lr PT {:8.6f} "
                        "| loss PT {:.2f} | pplx PT {:>7.2f}, bpc PT {:>7.4f}".
                        format(curr_step, gnorm_pt, learning_rate_np,
                               curr_loss_pt, math.exp(curr_loss_pt),
                               curr_loss_pt / math.log(2)))
                    ##### PYTORCH
                    #########

                    total_loss, total_loss_pt, prev_step = 0., 0., curr_step
                    writer.add_summary(summary_np, global_step=curr_step)

                if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                    save_path = os.path.join(FLAGS.model_dir,
                                             "model.ckpt-{}".format(curr_step))
                    saver.save(sess, save_path)
                    tf.logging.info(
                        "Model saved in path: {}".format(save_path))

                    #########
                    ##### PYTORCH
                    # Save a trained model, configuration and tokenizer
                    model_to_save = pt_model.module if hasattr(
                        pt_model,
                        'module') else pt_model  # Only save the model it-self
                    # If we save using the predefined names, we can load using `from_pretrained`
                    output_dir = os.path.join(
                        FLAGS.output_dir, "pytorch-ckpt-{}".format(curr_step))
                    if not tf.gfile.Exists(output_dir):
                        tf.gfile.MakeDirs(output_dir)
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    tf.logging.info(
                        "PyTorch Model saved in path: {}".format(output_dir))
                    ##### PYTORCH
                    #########

                if curr_step >= FLAGS.train_steps:
                    break

    if FLAGS.do_eval:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        #
        # Modified in XL: We also adopt the same mechanism for GPUs.
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
            eval_examples.append(PaddingInputExample())

        eval_file_base = "{}.len-{}.{}.eval.tf_record".format(
            spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
        eval_file = os.path.join(FLAGS.output_dir, eval_file_base)

        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenize_fn, eval_file)

        assert len(eval_examples) % FLAGS.eval_batch_size == 0
        eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=True)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            ########################### LOAD PT model
            #   import torch
            #   from pytorch_transformers import CONFIG_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, XLNetTokenizer, XLNetConfig, XLNetForSequenceClassification, BertAdam

            #   save_path = os.path.join(FLAGS.model_dir, TF_WEIGHTS_NAME)
            #   saver.save(sess, save_path)
            #   tf.logging.info("Model saved in path: {}".format(save_path))

            #   device = torch.device("cuda", 4)
            #   config = XLNetConfig.from_pretrained('xlnet-large-cased', finetuning_task=u'sts-b', num_labels=1)
            #   tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
            #   config_path = os.path.join(FLAGS.model_dir, CONFIG_NAME)
            #   config.to_json_file(config_path)
            #   # pt_model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased', num_labels=1)
            #   pt_model = XLNetForSequenceClassification.from_pretrained(FLAGS.model_dir, from_tf=True)
            #   pt_model.to(device)
            #   pt_model = torch.nn.DataParallel(pt_model, device_ids=[4, 5, 6, 7])
            #   from torch.optim import Adam
            #   optimizer = Adam(pt_model.parameters(), lr=0.001, betas=(0.9, 0.999),
            #                    eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay,
            #                    amsgrad=False)
            #   optimizer = BertAdam(pt_model.parameters(), lr=FLAGS.learning_rate, t_total=FLAGS.train_steps, warmup=FLAGS.warmup_steps / FLAGS.train_steps,
            #                        eps=FLAGS.adam_epsilon, weight_decay=FLAGS.weight_decay)

            ##### PYTORCH
            #########

            fetches = [
                loss, global_step, gnorm, learning_rate, train_op, merged,
                inputs, hidden_states, logits
            ]

            total_loss, total_loss_pt, prev_step, gnorm_pt = 0., 0., -1, 0.0
            total_logits = None
            total_labels = None
            while True:
                feed_dict = {}
                # for i in range(FLAGS.num_core_per_host):
                #   for key in tower_mems_np[i].keys():
                #     for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                #       feed_dict[m] = m_np

                fetched = sess.run(fetches)

                loss_np, curr_step, gnorm_np, learning_rate_np, _, summary_np, inputs_np, hidden_states_np, logits_np = fetched
                total_loss += loss_np

                if total_logits is None:
                    total_logits = logits_np
                    total_labels = inputs_np['label_ids']
                else:
                    total_logits = np.append(total_logits, logits_np, axis=0)
                    total_labels = np.append(total_labels,
                                             inputs_np['label_ids'],
                                             axis=0)