def translate_rnn(line: str, line_number: int,
                  s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                  src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                  device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])
    line = line.split()

    lang_token = line[1]
    assert lang_token.startswith("<") and lang_token.endswith(">")

    # inputs: (input_length,)
    inputs = torch.tensor([src_vocab.get_index(token) for token in line],
                          device=device)

    # inputs: (input_length, 1)
    inputs = inputs.view(-1, 1)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # add language vector
        # input_embedding: (input_length, 1, embedding_size)
        # lang_encoding: (embedding_size, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        input_embedding = s2s.encoder.embedding(inputs) + lang_encoding
    else:
        input_embedding = s2s.encoder.embedding(inputs)
        print("line {} does not add language embedding".format(line_number))

    encoder_output, encoder_hidden_state = s2s.encoder.rnn(input_embedding)

    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)

    max_length = (inputs.size(0) - 2) * 3

    pred_line = []

    for i in range(max_length):
        # decoder_output: (1, 1, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, 1)
        pred = torch.argmax(decoder_output, dim=2)

        if tgt_vocab.get_token(pred[0, 0].item()) == tgt_vocab.end_token:
            break

        decoder_input = pred
        pred_line.append(tgt_vocab.get_token(pred[0, 0].item()))

    return pred_line
def greedy_decoding_transformer(s2s: transformer.S2S,
                                data_tensor: torch.Tensor, tgt_vocab: Vocab,
                                device: torch.device,
                                tgt_prefix: List[str] = None):

    # src: (batch_size, input_length)
    src = data_tensor

    src_mask = s2s.make_src_mask(src)

    batch_size = src.size(0)

    encoder_src = s2s.encoder(src, src_mask)

    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)
    tgt = tgt.expand(batch_size, -1)

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    if tgt_prefix is not None:
        # tgt_prefix_tensor: (batch_size, )
        tgt_prefix = [
            tgt_vocab.get_index(prefix_token) for prefix_token in tgt_prefix
        ]
        tgt_prefix_tensor = torch.tensor(tgt_prefix, device=device)

        # tgt_prefix_tensor: (batch_size, 1)
        tgt_prefix_tensor = tgt_prefix_tensor.unsqueeze(1)

        # tgt: (batch_size, 2)
        tgt = torch.cat([tgt, tgt_prefix_tensor], dim=1)

        pred_list.append(tgt_prefix)

    max_length = src.size(1) * 3

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(0 if tgt_prefix is None else 1, max_length):
        # tgt: (batch_size, i + 1)
        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (batch_size, input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # output: (batch_size, vocab_size)
        output = output[:, -1, :]

        # pred: (batch_size, )
        pred = torch.argmax(output, dim=-1)

        if torch.all(pred == end_token_index).item():
            break

        tgt = torch.cat([tgt, pred.unsqueeze(1)], dim=1)
        pred_list.append(pred.tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
def train(cfg):
    """
    Run training.
    :param cfg: config object
    :return:
    """
    datasets = build_dataset(cfg)
    algo = TFIDFClustring(cfg)
    vocab = Vocab(cfg)
    summary = SummaryTxt(cfg)
    keyword = Keyword(cfg, summary)
    processed_news_num = 0
    batch_size = cfg.SOLVER.BATCH_SIZE

    print('start training:')
    for seg_id in trange(0, datasets.file_num, batch_size):
        seg = []
        for batch_idx in range(batch_size):
            batch, seg_size = datasets.getitem(seg_id + batch_idx)
            seg.extend(batch)
            processed_news_num += seg_size
        algo.run(segments=seg,
                 vocab=vocab,
                 seg_id=seg_id,
                 keyword=keyword,
                 summary=summary)
        # keyword.update_per_seg(new_updated_topic=new_updated_topic)
        print("seg idx: {}. processed news: {}".format(seg_id,
                                                       processed_news_num))
def convert_data_to_index(data: List[str], vocab: Vocab):
    data2index = []
    for sentence in data:
        sentence = " ".join([vocab.start_token, sentence, vocab.end_token])
        data2index.append(
            [vocab.get_index(token) for token in sentence.split()])
    return data2index
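# A minimal usage sketch for convert_data_to_index, assuming a loaded Vocab
# whose start/end tokens are "<s>" and "</s>"; the vocab path and sentence
# below are illustrative only, not values fixed by this project.
#
#   src_vocab = Vocab.load("src_vocab.pkl")            # hypothetical path
#   indexed = convert_data_to_index(["ein kleiner test"], src_vocab)
#   # indexed -> [[idx("<s>"), idx("ein"), idx("kleiner"), idx("test"), idx("</s>")]]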
def greedy_decoding_rnn(s2s: Union[S2S_basic.S2S, S2S_attention.S2S],
                        data_tensor: torch.Tensor, tgt_vocab: Vocab,
                        device: torch.device):

    # inputs: (input_length, batch_size)
    inputs = data_tensor

    batch_size = inputs.size(1)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, batch_size)

    max_length = inputs.size(0) * 3

    # pred_list: List[List] (tgt_length, batch_size)
    pred_list = []

    end_token_index = tgt_vocab.get_index(tgt_vocab.end_token)

    for i in range(max_length):
        # decoder_output: (1, batch_size, vocab_size)
        # decoder_hidden_state: (num_layers * num_directions, batch_size, hidden_size)
        decoder_output, decoder_hidden_state = decode_batch(
            s2s, decoder_input, decoder_hidden_state, encoder_output)

        # pred: (1, batch_size)
        pred = torch.argmax(decoder_output, dim=2)

        if torch.all(pred == end_token_index).item():
            break

        decoder_input = pred
        pred_list.append(pred.squeeze(0).tolist())

    return convert_index_to_token(pred_list, tgt_vocab, batch_size,
                                  end_token_index)
def load_corpus_data(data_path, language_name, start_token, end_token,
                     mask_token, vocab_path, rebuild_vocab, unk="UNK",
                     threshold=0):

    if rebuild_vocab:
        v = Vocab(language_name, start_token, end_token, mask_token,
                  threshold=threshold)

    corpus = []

    with open(data_path) as f:
        data = f.read().strip().split("\n")

        for line in data:
            line = line.strip()
            line = " ".join([start_token, line, end_token])
            if rebuild_vocab:
                v.add_sentence(line)
            corpus.append(line)

    data2index = []

    if rebuild_vocab:
        v.add_unk(unk)
        v.save(vocab_path)
    else:
        v = Vocab.load(vocab_path)

    for line in corpus:
        data2index.append([v.get_index(token) for token in line.split()])

    return data2index, v
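# A minimal usage sketch for load_corpus_data; the file names, special tokens
# and threshold below are assumptions for illustration, not values fixed by
# this project.
#
#   data2index, src_vocab = load_corpus_data(
#       "train.src",            # hypothetical corpus file, one sentence per line
#       "en", "<s>", "</s>", "<mask>",
#       "src_vocab.pkl",        # where the rebuilt vocab would be saved
#       rebuild_vocab=True, unk="UNK", threshold=2)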
def convert_index_to_token(pred_list: List[List], tgt_vocab: Vocab,
                           batch_size: int, end_token_index: int):
    # pred_list: List[List] (tgt_length, batch_size)
    # pred_line collects one token list per batch element: (batch_size, tgt_length)
    pred_line = []
    for j in range(batch_size):
        line = []
        for i in range(len(pred_list)):
            if pred_list[i][j] == end_token_index:
                break
            line.append(tgt_vocab.get_token(pred_list[i][j]))
        pred_line.append(line)
    return pred_line
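# A minimal sketch of how convert_index_to_token reshapes step-major decoding
# output into per-sentence token lists; the indices are made up for
# illustration (assume the end token has index 3).
#
#   pred_list = [[5, 7],   # step 0 for a batch of 2
#                [6, 3],   # step 1 (second sentence ends here)
#                [3, 3]]   # step 2
#   convert_index_to_token(pred_list, tgt_vocab, batch_size=2, end_token_index=3)
#   # -> [[token(5), token(6)], [token(7)]]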
def beam_search_transformer(s2s: transformer.S2S, data_tensor: torch.Tensor,
                            tgt_vocab: Vocab, beam_size: int,
                            device: torch.device):

    # src: (1, input_length)
    src = data_tensor
    src = src.expand(beam_size, -1)

    src_mask = s2s.make_src_mask(src)

    encoder_src = s2s.encoder(src, src_mask)

    max_length = src.size(1) * 3

    # tgt: (1, 1)
    tgt = torch.tensor([[tgt_vocab.get_index(tgt_vocab.start_token)]],
                       device=device)

    # tgt: (beam_size, 1)
    tgt = tgt.expand(beam_size, -1)

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:
        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (1 * beam_size, input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # output: (1 * beam_size, vocab_size)
        output = output[:, -1, :]

        # output: (1 * beam_size, vocab_size)
        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (1 * beam_size, vocab_size)
        sub_sentence_scores = output + scores.unsqueeze(1)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # next_tgt: (beam_size, input_length + 1)
        next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        if step == max_length:
            complete_seqs.extend(next_tgt.tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(next_tgt[complete_indices].tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            # beam_id: (beam_size, )
            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            # token_id: (beam_size, )
            token_id = pred_indices % len(tgt_vocab)
            # next_tgt: (beam_size, input_length + 1)
            next_tgt = torch.cat([tgt[beam_id], token_id.unsqueeze(1)], dim=1)

        step += 1
        tgt = next_tgt
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]
    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]
    return best_sentence
def beam_search_rnn(s2s: Union[S2S_attention.S2S, S2S_basic.S2S],
                    data_tensor: torch.Tensor, tgt_vocab: Vocab,
                    beam_size: int, device: torch.device):

    # batch_size == beam_size
    # inputs: (input_length, beam_size)
    inputs = data_tensor
    inputs = inputs.expand(-1, beam_size)

    encoder_output, encoder_hidden_state = s2s.encoder(inputs)

    # decoder_input: (1, beam_size)
    decoder_input = torch.tensor(
        [[tgt_vocab.get_index(tgt_vocab.start_token)]], device=device)
    decoder_input = decoder_input.expand(-1, beam_size)

    # decoder_hidden_state: (num_layers, beam_size, hidden_size)
    decoder_hidden_state = combine_bidir_hidden_state(s2s, encoder_hidden_state)

    max_length = inputs.size(0) * 3

    scores = torch.zeros(beam_size, device=device)

    complete_seqs = []
    complete_seqs_scores = []

    step = 1

    while True:
        # output: (1, beam_size, vocab_size)
        # decoder_hidden_state: (num_layers, beam_size, hidden_size)
        output, decoder_hidden_state = decode_batch(
            s2s, decoder_input[-1].unsqueeze(0), decoder_hidden_state,
            encoder_output)

        output = F.log_softmax(output, dim=-1)

        # sub_sentence_scores: (beam_size, vocab_size)
        sub_sentence_scores = scores.unsqueeze(1) + output.squeeze(0)

        if step == 1:
            pred_prob, pred_indices = sub_sentence_scores[0].topk(beam_size,
                                                                  dim=-1)
        else:
            # sub_sentence_scores: (beam_size * vocab_size)
            sub_sentence_scores = sub_sentence_scores.view(-1)
            pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                               dim=-1)

        # beam_id: (beam_size, )
        beam_id = pred_indices.floor_divide(len(tgt_vocab))
        # token_id: (beam_size, )
        token_id = pred_indices % len(tgt_vocab)

        # decoder_input[-1][beam_id]: (beam_size, )
        # next_decoder_input: (step + 1, beam_size)
        # decoder_input: (step, beam_size)
        next_decoder_input = torch.cat(
            [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)

        if step == max_length:
            complete_seqs.extend(next_decoder_input.t().tolist())
            complete_seqs_scores.extend(pred_prob.tolist())
            break

        complete_indices = []
        for i, indices in enumerate(token_id):
            if tgt_vocab.get_token(indices.item()) == tgt_vocab.end_token:
                complete_indices.append(i)

        if len(complete_indices) > 0:
            complete_seqs.extend(
                next_decoder_input[:, complete_indices].t().tolist())

            complete_pred_indices = beam_id[complete_indices] * len(
                tgt_vocab) + token_id[complete_indices]

            if step == 1:
                complete_seqs_scores.extend(
                    sub_sentence_scores[0][complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[0][complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores[0].topk(
                    beam_size, dim=-1)
            else:
                complete_seqs_scores.extend(
                    sub_sentence_scores[complete_pred_indices].tolist())

                if len(complete_pred_indices) == beam_size:
                    break

                sub_sentence_scores[complete_pred_indices] = -1e9
                pred_prob, pred_indices = sub_sentence_scores.topk(beam_size,
                                                                   dim=-1)

            beam_id = pred_indices.floor_divide(len(tgt_vocab))
            token_id = pred_indices % len(tgt_vocab)
            next_decoder_input = torch.cat(
                [decoder_input[:, beam_id], token_id.unsqueeze(0)], dim=0)

        step += 1

        if isinstance(decoder_hidden_state, tuple):
            h, c = decoder_hidden_state
            h = h[:, beam_id]
            c = c[:, beam_id]
            decoder_hidden_state = (h, c)
        else:
            decoder_hidden_state = decoder_hidden_state[:, beam_id]

        decoder_input = next_decoder_input
        scores = pred_prob

    best_sentence_id = 0
    for i in range(len(complete_seqs_scores)):
        if complete_seqs_scores[i] > complete_seqs_scores[best_sentence_id]:
            best_sentence_id = i

    best_sentence = complete_seqs[best_sentence_id]
    best_sentence = [
        tgt_vocab.get_token(index) for index in best_sentence[1:-1]
    ]
    return best_sentence
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-data_path", type=str, default="../data")
    parser.add_argument("-input_data", type=str, default="sections.pickle")
    parser.add_argument("-vocab", type=str, default="vocab.csv")
    parser.add_argument("-encoding", type=str, default="word2index")
    parser.add_argument("-tokenization", type=str, default="word_tokenize")
    parser.add_argument("-speaker_lables", type=bool, default=True)
    parser.add_argument("-lowercase", type=bool, default=True)
    parser.add_argument("-splitting", type=bool, default=True)
    parser.add_argument("-min_occ", type=int, default=3)

    args = parser.parse_args()

    vocab = Vocab.create(args.data_path, "train_" + args.input_data,
                         args.vocab, args.tokenization, args.lowercase,
                         args.splitting, args.min_occ)

    print("Creating new PhotoBook segment dataset...")

    chain_id = 0

    for set_name in ['dev', 'val', 'test', 'train']:
        with open(
                os.path.join(args.data_path,
                             set_name + "_" + args.input_data), 'rb') as f:
            dialogue_segments = pickle.load(f)

        segment_dataset, chain_dataset, chain_id = create_segment_datasets(
            dialogue_segments, vocab, chain_id, args.encoding,
            args.tokenization, args.speaker_lables, args.lowercase,
            args.splitting)

        with open(os.path.join(args.data_path, set_name + "_segments.json"),
def test_item_file(end_test_file, embedding_file_path, vocab_file_path,
                   use_gpu):
    embed = torch.Tensor(np.load(embedding_file_path)['arr_0'])
    with open(vocab_file_path) as f:
        word2id = json.load(f)
    vocab = Vocab(embed, word2id)

    # with open(end_test_file) as f:
    #     examples = [json.loads(line) for line in f]
    with open(end_test_file) as f:
        examples = list()
        for line in f:
            if line and not line.isspace():
                examples.append(json.loads(line))
    # print(examples[0])

    test_dataset = Dataset(examples)

    test_iter = DataLoader(dataset=test_dataset,
                           batch_size=args.batch_size,
                           shuffle=False)

    load_dir = os.path.join(args.input, 'model_files', 'CNN_RNN.pt')
    if use_gpu:
        checkpoint = torch.load(load_dir)
    else:
        checkpoint = torch.load(load_dir,
                                map_location=lambda storage, loc: storage)

    if not use_gpu:
        checkpoint['args'].device = None

    net = getattr(models, checkpoint['args'].model)(checkpoint['args'])
    net.load_state_dict(checkpoint['model'])

    if use_gpu:
        net.cuda()
    net.eval()

    doc_num = len(test_dataset)

    all_targets = []
    all_results = []
    all_probs = []
    all_acc = []
    all_p = []
    all_r = []
    all_f1 = []
    all_sum = []

    for batch in tqdm(test_iter):
        features, targets, summaries, doc_lens = vocab.make_features(batch)
        if use_gpu:
            probs = net(Variable(features).cuda(), doc_lens)
        else:
            probs = net(Variable(features), doc_lens)

        start = 0
        for doc_id, doc_len in enumerate(doc_lens):
            doc = batch['doc'][doc_id].split('\n')[:doc_len]
            stop = start + doc_len
            prob = probs[start:stop]
            hyp = []
            for _p, _d in zip(prob, doc):
                print(_p)
                print(_d)
                if _p > 0.5:
                    hyp.append(_d)
            if len(hyp) > 0:
                print(hyp)
                all_sum.append("###".join(hyp))
            else:
                all_sum.append('')
            all_targets.append(targets[start:stop])
            all_probs.append(prob)
            start = stop

    file_path_elems = end_test_file.split('/')
    file_name = 'TR-' + file_path_elems[len(file_path_elems) - 1]
    with open(os.path.join(args.output, file_name), mode='w',
              encoding='utf-8') as f:
        for text in all_sum:
            f.write(text.strip() + '\n')

    for item in all_probs:
        all_results.append([1 if tmp > 0.5 else 0 for tmp in item.tolist()])

    print(len(all_results))
    print(len(all_targets))
    print(len(all_probs))

    for _1, _2, _3 in zip(all_results, all_targets, all_probs):
        _2 = _2.tolist()
        _3 = _3.tolist()
        print("*" * 3)
        print('probs : ', _3)
        print('results : ', _1)
        print('targets : ', _2)
        tmp_acc = accuracy_score(_1, _2)
        tmp_p = precision_score(_1, _2)
        tmp_r = recall_score(_1, _2)
        tmp_f1 = f1_score(_1, _2)
        print('acc : ', tmp_acc)
        print('p : ', tmp_p)
        print('r : ', tmp_r)
        print('f1 : ', tmp_f1)
        all_acc.append(tmp_acc)
        all_p.append(tmp_p)
        all_r.append(tmp_r)
        all_f1.append(tmp_f1)

    print('all dataset acc : ', np.mean(all_acc))
    print('all dataset p : ', np.mean(all_p))
    print('all dataset r : ', np.mean(all_r))
    print('all dataset f1 : ', np.mean(all_f1))
    print('all results length : ', len(all_results))
from utils.ChainDataset import ChainDataset
from utils.Vocab import Vocab

# Tests the SegmentDataset and ChainDataset classes
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-data_path", type=str, default="../data")
    parser.add_argument("-segment_file", type=str, default="segments.json")
    parser.add_argument("-chains_file", type=str, default="val_chains.json")
    parser.add_argument("-vocab_file", type=str, default="vocab.csv")
    parser.add_argument("-vectors_file", type=str, default="vectors.json")
    parser.add_argument("-split", type=str, default="val")

    args = parser.parse_args()

    print("Loading the vocab...")
    vocab = Vocab(os.path.join(args.data_path, args.vocab_file), 3)

    print("Testing the SegmentDataset class initialization...")
    segment_val_set = SegmentDataset(data_dir=args.data_path,
                                     segment_file=args.segment_file,
                                     vectors_file=args.vectors_file,
                                     split=args.split)

    print("Testing the SegmentDataset class item getter...")
    print("Dataset contains {} segment samples".format(len(segment_val_set)))

    sample_id = 2
    sample = segment_val_set[sample_id]
    print("Segment {}:".format(sample_id))
    print("Image set: {}".format(sample["image_set"]))
    print("Target image index(es): {}".format(sample["targets"]))
def translate_transformer(line: str, line_number: int, s2s: transformer.S2S,
                          src_vocab: Vocab, tgt_vocab: Vocab, lang_vec: dict,
                          device: torch.device):

    line = " ".join([src_vocab.start_token, line, src_vocab.end_token])
    line = line.split()

    max_length = (len(line) - 2) * 3

    lang_token = line[1]

    # src: (input_length, )
    src = torch.tensor([src_vocab.get_index(token) for token in line],
                       device=device)

    # src: (1, input_length)
    src = src.view(1, -1)

    src_mask = s2s.make_src_mask(src)

    src = s2s.encoder.token_embedding(src) * s2s.encoder.scale

    # src: (1, input_length, d_model)
    src = s2s.encoder.pos_embedding(src)

    if lang_token.startswith("<") and lang_token.endswith(">"):
        # lang_encoding: (d_model, )
        lang_encoding = torch.tensor(lang_vec[lang_token], device=device)
        src = src + lang_encoding
    else:
        print("line {} does not add language embedding".format(line_number))

    for layer in s2s.encoder.layers:
        src, self_attention = layer(src, src_mask)
        del self_attention

    encoder_src = src

    tgt = None
    pred_line = [tgt_vocab.get_index(tgt_vocab.start_token)]

    for i in range(max_length):
        if tgt is None:
            tgt = torch.tensor([pred_line], device=device)

        tgt_mask = s2s.make_tgt_mask(tgt)

        # output: (1, tgt_input_length, vocab_size)
        output = s2s.decoder(tgt, encoder_src, tgt_mask, src_mask)

        # pred: argmax index at the last decoded position
        pred = torch.argmax(output, dim=-1)[0, -1]

        if tgt_vocab.get_token(pred.item()) == tgt_vocab.end_token:
            break

        tgt = torch.cat([tgt, pred.unsqueeze(0).unsqueeze(1)], dim=1)
        pred_line.append(pred.item())

    pred_line = [tgt_vocab.get_token(index) for index in pred_line[1:]]

    return pred_line
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", required=True)
    parser.add_argument("--load", required=True)
    parser.add_argument("--src_vocab_path", required=True)
    parser.add_argument("--tgt_vocab_path", required=True)
    parser.add_argument("--test_src_path", required=True)
    parser.add_argument("--test_tgt_path", required=True)
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--translation_output", required=True)
    parser.add_argument("--transformer", action="store_true")

    args, unknown = parser.parse_known_args()

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    src_data = read_data(args.test_src_path)

    # lang_vec: dict
    lang_vec = load_lang_vec(args.lang_vec_path)

    device = args.device

    print("load from {}".format(args.load))

    if args.transformer:
        max_src_length = max(len(line) for line in src_data) + 2
        max_tgt_length = max_src_length * 3
        padding_value = src_vocab.get_index(src_vocab.mask_token)
        assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)
        s2s = load_transformer(args.load, len(src_vocab), max_src_length,
                               len(tgt_vocab), max_tgt_length, padding_value,
                               device=device)
    else:
        s2s = load_model(args.load, device=device)

    s2s.eval()

    pred_data = []
    for i, line in enumerate(src_data):
        pred_data.append(
            translate(line, i, s2s, src_vocab, tgt_vocab, lang_vec, device))

    write_data(pred_data, args.translation_output)

    pred_data_tok_path = args.translation_output + ".tok"

    tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
        args.translation_output, pred_data_tok_path)
    call(tok_command, shell=True)

    bleu_calculation_command = "perl /data/rrjin/NMT/scripts/multi-bleu.perl {} < {}".format(
        args.test_tgt_path, pred_data_tok_path)
    call(bleu_calculation_command, shell=True)
def train():
    print("*" * 100)
    print("train begin")

    # use gpu
    use_gpu = args.device is not None
    if torch.cuda.is_available() and not use_gpu:
        print("WARNING: You have a CUDA device, should run with -device 0")
    if use_gpu:
        # set cuda device and seed
        torch.cuda.set_device(args.device)
        torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    numpy.random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    # prepare paths
    embedding_file_path = os.path.join(args.project, "embedding.npz")
    vocab_file_path = os.path.join(args.project, "word2id.json")
    end_train_file = os.path.join(args.input, "train_files", "train.txt")
    train_files_dir = os.path.join(args.input, "train_files")

    # merge text files with the same suffix
    merge_same_suf_text_file(train_files_dir, end_train_file, '.txt')

    print('Loading vocab, train and val dataset. Wait a second, please')

    embed = torch.Tensor(np.load(embedding_file_path)['arr_0'])
    # embed = torch.Tensor(list(np.load(args.embedding)))
    with open(vocab_file_path) as f:
        word2id = json.load(f)
    vocab = Vocab(embed, word2id)

    with open(end_train_file) as f:
        examples = list()
        for line in tqdm(f):
            if line and not line.isspace():
                examples.append(json.loads(line))
    train_dataset = Dataset(examples)
    print(train_dataset[:1])

    args.embed_num = embed.size(0)  # read dimensions from the embedding
    args.embed_dim = embed.size(1)
    # args.kernel_sizes = [int(ks) for ks in args.kernel_sizes.split(',')]

    net = getattr(models, args.model)(args, embed)
    if use_gpu:
        net.cuda()

    train_iter = DataLoader(dataset=train_dataset,
                            batch_size=args.batch_size,
                            shuffle=False)

    criterion = nn.BCELoss()

    params = sum(p.numel() for p in list(net.parameters())) / 1e6
    print('#Params: %.1fM' % (params))

    min_loss = float('inf')
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    net.train()

    t1 = time()
    for epoch in range(1, args.max_epoch + 1):
        print("*" * 10, 'epoch ', str(epoch), '*' * 50)
        for i, batch in enumerate(train_iter):
            print("*" * 10, 'batch', i, '*' * 10)
            features, targets, _, doc_lens = vocab.make_features(
                batch, args.seq_trunc)
            features, targets = Variable(features), Variable(targets.float())
            if use_gpu:
                features = features.cuda()
                targets = targets.cuda()
            probs = net(features, doc_lens)
            loss = criterion(probs, targets)
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(net.parameters(), args.max_norm)
            optimizer.step()
        net.save()
        print('Epoch: %2d Loss: %f' % (epoch, loss))
    t2 = time()
    print('Total Cost:%f h' % ((t2 - t1) / 3600))
    print("Model configuration files saved to the output folder")
def evaluation(local_rank, args):

    rank = args.nr * args.gpus + local_rank

    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3

    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)
    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):

        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:
            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)
        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s,
                                         data.to(device, non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spent: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:
            # replace '@@ ' with ''
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)
            call(tok_command, shell=True)
import json
from collections import defaultdict

from utils.Vocab import Vocab

with open('data/test_segments.json', 'r') as file:
    test_sg = json.load(file)

with open('data/test_chains.json', 'r') as file:
    test = json.load(file)

vocab = Vocab('data/vocab.csv', 3)

# given an img, provides the chains for which it was the target
target2chains = defaultdict(list)

for ch in test:
    target_id = ch['target']
    segment_list = ch['segments']
    target2chains[target_id].append(segment_list)

# segment ids, in the order in which they were encountered in the chains in the whole dataset
id_list = []

for c in test:
    segments = c['segments']
    for s in segments:
        if s not in id_list:
            id_list.append(s)
parser.add_argument("--val_tgt_path", required=True) parser.add_argument("--src_vocab_path", required=True) parser.add_argument("--tgt_vocab_path", required=True) parser.add_argument("--picture_path", required=True) args, unknown = parser.parse_known_args() device = args.device s2s = load_model(args.load, device=device) if not isinstance(s2s, S2S_attention.S2S): raise Exception("The model don't have attention mechanism") src_vocab = Vocab.load(args.src_vocab_path) tgt_vocab = Vocab.load(args.tgt_vocab_path) with open(args.val_tgt_path) as f: data = f.read().split("\n") tgt_data = [normalizeString(line, to_ascii=False) for line in data] with torch.no_grad(): s2s.eval() with open(args.val_src_path) as f: data = f.read().split("\n")
print(args)

# prepare datasets and obtain the arguments
t = datetime.datetime.now()
timestamp = str(t.date()) + '-' + str(t.hour) + '-' + str(
    t.minute) + '-' + str(t.second)

seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print("Loading the vocab...")
vocab = Vocab(os.path.join(args.data_path, args.vocab_file), 3)

trainset = HistoryDataset(data_dir=args.data_path,
                          segment_file='train_' + args.segment_file,
                          vectors_file=args.vectors_file,
                          chain_file='train_' + args.chains_file,
                          split=args.split)

testset = HistoryDataset(data_dir=args.data_path,
                         segment_file='test_' + args.segment_file,
                         vectors_file=args.vectors_file,
                         chain_file='test_' + args.chains_file,
                         split='test')

valset = HistoryDataset(data_dir=args.data_path,
                        segment_file='val_' + args.segment_file,
def __init__(self, cfg, summary):
    Vocab.__init__(self, cfg)
    self.file_path = cfg.OUTPUT_DIR
    self.summary = summary