def inference_random():
    """Evaluate the saved classifier on the test set and print its mistakes.

    Builds a ``ClassificationModel`` sized by ``cfg.char2idx``, loads weights from
    ``cfg.save_model_path`` and moves it to ``cfg.device``. Then reads
    ``cfg.test_data_path`` (UTF-8) line by line. Each non-blank line is split on
    tabs; column 0 is the integer gold label and column 1 is the text.

    For every sample whose argmax prediction differs from the gold label, prints
    the score assigned to the gold label, the label, the text and a separator.
    Finally prints the total number of misclassified samples.
    """
    # Load the model and evaluate it on the held-out data.
    model = ClassificationModel(len(cfg.char2idx))
    model = load_custom_model(model, cfg.save_model_path).to(cfg.device)
    # Inference only: switch off dropout / batch-norm statistic updates.
    model.eval()

    tokenizer = Tokenizer(cfg.char2idx)
    error = 0
    with open(cfg.test_data_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        # Drop the trailing newline so it is not tokenized as part of the text.
        line = line.rstrip('\n')
        if not line:
            # Skip blank lines (e.g. a trailing empty line) instead of crashing on pairs[1].
            continue
        pairs = line.split('\t')
        label, text = pairs[0], pairs[1]
        gold = int(label)
        input_index, _ = tokenizer.encode(text, max_length=cfg.max_seq_len)
        # Inputs must live on the same device as the model's parameters.
        inputs = torch.tensor(input_index).unsqueeze(0).to(cfg.device)
        # Non-zero token ids are treated as real tokens; 0 is assumed to be padding.
        inputs_mask = (inputs > 0).to(torch.float32)
        with torch.no_grad():
            scores = model(inputs, inputs_mask)
            prediction = scores.argmax(-1).item()
        if prediction != gold:
            print(scores[:, gold].item())
            print(label)
            print(text)
            print('-' * 50)
            error += 1
    print(error)
Example #2
0
def main(feature_type: str, language: str, domain: str, main_dir: str, seq_len: int,
         batch_size: int, lstm_dim: int, character_level: bool = False):
    """
    Sample random windows from each text, run them through a trained stateful
    LSTM and pickle the resulting hidden states, seeds and predictions.

    Parameters
    ----------
    feature_type: the name of the feature
    language: language of the text.
    domain: optional domain name; when truthy it is appended (as ``_<domain>``)
        to both the model file name and the output file names.
    main_dir: base directory; the model weights are read from ``<main_dir>/models``.
    seq_len: sequence length, in characters if ``character_level`` else in
        whitespace-separated words.
    batch_size: number of random windows sampled per text; also the batch size
        of the prediction model.
    lstm_dim: lstm hidden dimension
    character_level: whether tokenizer should be on character level.

    Outputs are written (relative to the current working directory) to
    ``data/hidden_states``, ``data/seeds`` and ``data/predictions``, each as a
    pickled dict keyed by book. Texts shorter than ``seq_len`` are skipped with a
    warning.
    """
    texts = get_texts(main_dir, language, feature_type, character_level, domain)

    tokenizer = Tokenizer(texts.values(), character_level=character_level)

    samples = {}

    for book in texts:
        text = texts[book]
        logger.debug(f"{book}: {len(text)} characters")
        # Split once and reuse: characters for character-level, words otherwise.
        units = text if character_level else text.split()
        len_text = len(units)

        if len_text < seq_len:
            logger.warning(f"Requested seq_len larger than text length: {len_text} / {seq_len} "
                           f"for {book} and feature type {feature_type}.")
            continue
        # randint's upper bound is exclusive; +1 makes the last full window
        # (start == len_text - seq_len) reachable and avoids randint(0, 0)
        # raising when len_text == seq_len.
        rand_idx = np.random.randint(0, len_text - seq_len + 1, batch_size)

        if character_level:
            windows = [text[i: i + seq_len] for i in rand_idx]
        else:
            windows = [" ".join(units[i: i + seq_len]) for i in rand_idx]
        samples[book] = tokenizer.encode(windows)

    # Build one batch only to log its shape and a decoded example as a sanity check.
    test_generator = DataGenerator(tokenizer,
                                   tokenizer.full_text,
                                   seq_len=seq_len,
                                   batch_size=batch_size,
                                   with_embedding=True,
                                   train=False)

    sample_batch = next(iter(test_generator))

    logger.info(f"X batch shape: {sample_batch[0].shape}, y batch shape: {sample_batch[1].shape}")
    logger.info(f"Sample batch text: {tokenizer.decode(sample_batch[0][0])}")

    file_path = os.path.join(main_dir, 'models',
                             f'{feature_type}_{language}_lstm_{lstm_dim}')

    if domain:
        file_path += '_' + domain

    if character_level:
        # NOTE: the model file uses "_character_level" (underscore) while the
        # output pickles below use "_character-level" (hyphen). Kept as-is so
        # names of files already on disk still match.
        file_path += '_character_level'

    file_path += '.h5'

    logger.info(f"Loading {file_path}")

    # Stateful, one-step-at-a-time model used to extract hidden states.
    prediction_model = lstm_model(num_words=tokenizer.num_words,
                                  lstm_dim=lstm_dim,
                                  seq_len=1,
                                  batch_size=batch_size,
                                  stateful=True,
                                  return_state=True)

    prediction_model.load_weights(file_path)

    hiddens = {}
    seeds = {}
    predictions = {}

    for book in samples:
        seed = np.stack(samples[book])
        logger.debug(f"{book}: seed shape {seed.shape}")
        hf, preds = generate_text(prediction_model, tokenizer, seed, get_hidden=True)
        logger.debug(f"{book}: hidden shape {hf.shape}")
        hiddens[book] = hf
        seeds[book] = seed
        predictions[book] = [tokenizer.ix_to_word[pred] for pred in preds]

    file_name = f'{feature_type}_{language}_lstm_{lstm_dim}_seq_len_{seq_len}'

    if domain:
        file_name += '_' + domain

    if character_level:
        file_name += '_character-level'
    file_name += '.pkl'

    outputs = (
        ('hidden_states', hiddens, 'hidden dimensions'),
        ('seeds', seeds, 'seeds'),
        ('predictions', predictions, 'predictions'),
    )
    for sub_dir, obj, description in outputs:
        path_out = os.path.join('data', sub_dir, file_name)
        # Create the output directory if needed so the dump does not fail.
        os.makedirs(os.path.dirname(path_out), exist_ok=True)
        with open(path_out, 'wb') as f:
            pickle.dump(obj, f)
        logger.info(f"Succesfully saved {description} to {path_out}")
Example #3
0
            else:
                with open("datasets/seq2seq/test.pkl", 'rb') as f:
                    data = pickle.load(f)
            
        mode = arg[7]
        batch_size = int(arg[6])
        l = len(data) 
        if l%batch_size==0:    
            bl = l//batch_size
        else:
            bl = l//batch_size+1
        
        data_batches = [data.collate_fn([data[j] for j in range(i*batch_size,min((i+1)*batch_size,l))]) for i in range(bl)]
        #print(data_batches[0]['text'])
        
        print(tokenizer.encode("."))
        mymodel = S2S(emb_w.size(0),emb_w.size(1),256,len(emb_w),device,layer=int(arg[8]),attention=attention).to(device)
        mymodel.embedding.from_pretrained(emb_w)
        if os.path.isfile(arg[4]):
            mymodel.load_state_dict(torch.load(arg[4],map_location= device))
        else:
            print("Model File Not Found.")
        #solver.test
        #result,ids = solver.test(mymodel,data_batches,device,tokenizer,attention=attention,batch_size=batch_size,mode=mode)
        with torch.no_grad():
            result,ids = solver.test_beam_search(mymodel,data_batches,device,tokenizer,beam_size=1,attention=attention,batch_size=batch_size,mode=mode)

        post = Postprocessing()
        dict_result = []
        dict_result = post.indiesToSentences(result,dict_result,ids,vocab,tokenizer,mode=mode)
        post.toJson(arg[5],dict_result)