예제 #1
0
def main(args):
    """Sample ``args.words`` tokens from a trained language model into a file.

    Builds either a ``TransformerModel`` or an ``RNNModel`` (selected by
    ``args.model``), restores its weights from ``args.checkpoint`` and writes
    the sampled words to ``args.outf``, 20 words per line.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``seed``, ``cuda``, ``data``, ``model``, ``emsize``, ``nhead``,
            ``nhid``, ``nlayers``, ``dropout``, ``tied``, ``checkpoint``,
            ``outf``, ``words``, ``temperature`` and ``log_interval``.
    """
    random_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    corpus = data.Corpus(args.data)
    ntokens = len(corpus.dictionary)
    print('loaded dictionary')
    if args.model == 'Transformer':
        model = TransformerModel(
            ntokens,
            args.emsize,
            args.nhead,
            args.nhid,
            args.nlayers,
            args.dropout).to(device)
    else:
        model = RNNModel(
            args.model,
            ntokens,
            args.emsize,
            args.nhid,
            args.nlayers,
            args.dropout,
            args.tied).to(device)

    # map_location makes a checkpoint that was saved on a GPU loadable on a
    # CPU-only run (and vice versa) instead of failing during deserialization.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print('loaded model')

    is_transformer_model = hasattr(
        model, 'model_type') and model.model_type == 'Transformer'
    if not is_transformer_model:
        hidden = model.init_hidden(1)
    # Random first token; named `tokens` so the builtin `input` is not shadowed.
    tokens = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
    with open(args.outf, 'w') as outf, torch.no_grad():  # no tracking history
        for i in range(args.words):
            if is_transformer_model:
                # The transformer is re-run on the whole sequence generated so
                # far; only the scores of the last position are sampled from.
                output = model(tokens, False)
                word_weights = output[-1].squeeze().div(
                    args.temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                word_tensor = torch.Tensor([[word_idx]]).long().to(device)
                tokens = torch.cat([tokens, word_tensor], 0)
            else:
                # The RNN carries its context in `hidden`, so the single
                # input slot is simply overwritten with the sampled word.
                output, hidden = model(tokens, hidden)
                word_weights = output.squeeze().div(
                    args.temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                tokens.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            # Break the output into lines of 20 words.
            outf.write(word + ('\n' if i % 20 == 19 else ' '))

            if i % args.log_interval == 0:
                print('| Generated {}/{} words'.format(i, args.words))
예제 #2
0
# Reverse vocabulary lookup taken from the `vocab` mapping defined earlier in
# the script (presumably index -> token; verify against where `vocab` is built).
int_to_vocab = vocab["int_to_vocab"]

# Model hyperparameters handed to TransformerModel below.
# NOTE(review): these presumably have to match the values used when the
# checkpoint was trained, otherwise load_state_dict will fail — confirm.
ntokens = len(vocab_to_int)
emsize = 512
nhid = 512
nlayers = 4
nhead = 4
dropout = 0.2

model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers,
                         dropout).to(device)

# The saved state dict is deserialized onto the CPU (map_location) and then
# copied into the model's parameters; the model is put in eval mode afterwards.
model_save_path = "./models/transformer/lm-siamzone-v4-space-342.pkl"
model.load_state_dict(
    torch.load(model_save_path, map_location=torch.device("cpu")))
model.eval()

print("Model initialized")


def top_k_top_p_filtering(logits,
                          top_k,
                          top_p,
                          temperature,
                          filter_value=-float("Inf")):
    # Hugging Face script to apply top k and nucleus sampling
    logits = logits / temperature

    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
예제 #3
0
def main(args):
    """Generate utterances whose word choices encode a secret bit stream.

    For every utterance a start word is drawn from the most frequent
    capitalised corpus tokens. At each following step the ``2**bit_num``
    highest-scoring candidate words are arranged in a Huffman tree, and the
    candidate whose code matches the next bits of the message is emitted.
    Generated sentences and the bits they carry are written to two files
    under ``args.save_path``.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``seed``, ``cuda``, ``data``, ``model``, ``emsize``, ``nhead``,
            ``nhid``, ``nlayers``, ``dropout``, ``tied``, ``checkpoint``,
            ``bit_stream_path``, ``save_path``, ``bit_num``,
            ``utterances_to_generate`` and ``len_of_generation``.
            ``args.vocab_size`` is set as a side effect.

    Raises:
        NotImplementedError: if the loaded model is a Transformer; only the
            recurrent generation path exists.
    """
    random_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda")
    device = torch.device("cuda" if args.cuda else "cpu")

    corpus = data.Corpus(args.data)
    ntokens = len(corpus.dictionary)
    word2idx = corpus.dictionary.word2idx
    idx2word = corpus.dictionary.idx2word
    args.vocab_size = len(word2idx)
    print('loaded dictionary')

    if args.model == 'Transformer':
        model = TransformerModel(
            ntokens,
            args.emsize,
            args.nhead,
            args.nhid,
            args.nlayers,
            args.dropout).to(device)
    else:
        model = RNNModel(
            args.model,
            ntokens,
            args.emsize,
            args.nhid,
            args.nlayers,
            args.dropout,
            args.tied).to(device)

    # map_location lets a GPU-saved checkpoint be restored on a CPU-only run.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    is_transformer_model = hasattr(
        model, 'model_type') and model.model_type == 'Transformer'
    print('loaded model')

    # The original `assert NotImplementedError` was always true, so the
    # Transformer path fell through to an unbound `hidden`. Fail explicitly,
    # and before any output file is created.
    if is_transformer_model:
        raise NotImplementedError(
            "bit-stream generation is only implemented for recurrent models")

    # Use only the most common capitalised tokens of the training corpus as
    # starting words (heuristic from the baseline). `[:1]` instead of `[0]`
    # keeps an empty vocabulary entry from raising IndexError.
    most_common_first_words_ids = [
        token_id for token_id, _ in Counter(corpus.train.tolist()).most_common()
        if idx2word[token_id][:1].isupper()][:200]

    # Private message (binary code): the first line of the bit-stream file.
    with open(args.bit_stream_path, 'r') as bit_stream_file:
        bit_stream = bit_stream_file.readline()

    # NOTE(review): the starting offset is drawn from [0, vocabulary size],
    # not from the length of the bit stream — confirm this is intended.
    bit_index = random.randint(0, len(word2idx))
    soft = torch.nn.Softmax(0)
    num_candidates = 2 ** int(args.bit_num)
    end_markers = ['\n', '', "<eos>"]

    generated_path = args.save_path + 'generated' + \
        str(args.bit_num) + '_bit.txt'
    bits_path = args.save_path + 'bitfile_' + str(args.bit_num) + '_bit.txt'
    with open(generated_path, 'w') as outfile, open(bits_path, 'w') as bitfile:
        for _ in tqdm.tqdm(range(args.utterances_to_generate)):
            # Inference only: do not record autograd history across steps.
            with torch.no_grad():
                input_ = torch.LongTensor([random.choice(
                    most_common_first_words_ids)]).unsqueeze(0).to(device)
                hidden = model.init_hidden(1)

                # First generated word: sampled from the softmax distribution.
                output, hidden = model(input_, hidden)
                first_probs = soft(output.reshape(-1)).tolist()
                gen = np.random.choice(
                    len(corpus.dictionary), 1,
                    p=np.array(first_probs) / sum(first_probs))[0]
                gen_res = [idx2word[gen]]
                bit = ""

                for _step in range(args.len_of_generation - 2):
                    # NOTE(review): the same start token `input_` is fed at
                    # every step and only `hidden` advances — confirm whether
                    # the last generated word should be fed back instead.
                    output, hidden = model(input_, hidden)
                    scores = output.reshape(-1)
                    sorted_scores, sorted_ids = torch.sort(
                        scores, descending=True)
                    candidate_scores = sorted_scores[:num_candidates].tolist()
                    candidate_ids = sorted_ids[:num_candidates].tolist()

                    # Huffman codes over the candidates, weighted by score.
                    nodes = createNodes(candidate_scores)
                    root = createHuffmanTree(nodes)
                    codes = huffmanEncoding(nodes, root)

                    # Find the shortest prefix of the remaining message that
                    # is a valid code and emit the matching candidate word.
                    for i in range(num_candidates):
                        prefix = bit_stream[bit_index:bit_index + i + 1]
                        if prefix not in codes:
                            continue
                        gen = candidate_ids[codes.index(prefix)]
                        gen_res.append(idx2word[gen])
                        # End markers carry no payload: the bits are not
                        # consumed, so they are re-used by the next word.
                        if idx2word[gen] not in end_markers:
                            bit += prefix
                            bit_index = bit_index + i + 1
                        break

            gen_sen = ' '.join(
                [word for word in gen_res if word not in end_markers])
            outfile.write(gen_sen + "\n")
            bitfile.write(bit)
예제 #4
0
    # One pass over the progress-bar-wrapped batches `pb`. train_one_iter
    # returns the loss and perplexity that are shown on the progress bar.
    for batch in pb:
        record_loss, perplexity = train_one_iter(batch, fp16=True)
        update_count += 1

        # Apply an optimizer update once every `num_gradients_accumulation`
        # iterations (gradient accumulation).
        if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
            # NOTE(review): scheduler.step() runs before optimizer.step();
            # PyTorch >= 1.1 documents the opposite order — confirm this
            # ordering is intentional.
            scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

            # speed measure: samples processed per second since the last update
            end = time.time()
            speed = batch_size * num_gradients_accumulation / (end - start)
            start = end

            pb.set_postfix(loss=record_loss,
                           perplexity=perplexity,
                           speed=speed)

    # The bare string below is a no-op expression statement used as a section
    # marker: validation and checkpointing follow.
    "Evaluation"
    encoder.eval()
    decoder.eval()
    ppl = validate(val_dataloader)
    # Every checkpoint is saved with is_best_so_far=True, i.e. no comparison
    # against earlier validation results is made here.
    checkpointer.save_checkpoint(str(ep), {
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    }, {"empty": None},
                                 is_best_so_far=True)

    logger.info(f"a={a} b={b} Epoch {ep} Validation perplexity: {ppl}")

logger.info(f"Finish training of alpha={a} beta={b}")
예제 #5
0
def _build_model(args):
    """Create a ``TransformerModel`` sized from ``args`` and move it to ``args.device``."""
    return TransformerModel(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_attention_heads=args.num_attention_heads,
        num_encoder_layers=args.num_encoder_layers,
        num_decoder_layers=args.num_decoder_layers,
        intermediate_size=args.intermediate_size,
        dropout=args.dropout,
    ).to(args.device)


def _average_checkpoints(model, tokenizer, args, eos_setting, sess, checkpoints):
    """Load every checkpoint into ``model`` and return the mean of its parameters.

    The previous inline version iterated a ``named_parameters()`` generator
    that had already been exhausted by ``dict(...)``, so nothing was ever
    accumulated; its accumulator also aliased the live parameters. Here each
    checkpoint's parameters are cloned right after loading and summed into a
    separate dict.

    NOTE(review): this relies on ``nsml.load`` restoring the weights into the
    bound ``model`` in place — confirm against ``bind_nsml``.
    """
    averaged = {}
    for checkpoint in checkpoints:
        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.load(checkpoint=checkpoint, session=sess)
        with torch.no_grad():
            for name, param in model.named_parameters():
                share = param.detach().clone() / len(checkpoints)
                if name in averaged:
                    averaged[name] += share
                else:
                    averaged[name] = share
    return averaged


def main():
    """Run one NSML job for the spelling-correction Transformer.

    Parses the command-line arguments, builds the tokenizer and model, binds
    them to NSML and then dispatches:

    * ``args.averaging`` set (mode != 'test'): average the parameters of a
      fixed list of checkpoints and save the result as ``'best'``.
    * ``args.mode == 'eval'``: load a fixed session and print beam-search
      corrections for the first 1000 sentences of ``train_corpus``.
    * ``args.resubmit`` set (mode != 'test'): reload the ``'best'`` checkpoint
      of the given session with a char tokenizer and save it again.
    * otherwise: for the modes ``train`` / ``pretrain`` / ``semi-train``,
      prepare the data and tokenizer and call ``train``.
    """
    args = get_args()
    args.n_gpu = 1

    set_seed(args)

    if args.tokenizer == 'char':
        tokenizer = CharTokenizer([])
    if args.tokenizer == 'kobert':
        print("koBERT tokenizer")
        tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')
        args.vocab_size = tokenizer.vocab_size
        print(args.vocab_size)

    if args.load_vocab != "":
        tokenizer.load(args.load_vocab)
        args.vocab_size = len(tokenizer)

    logger.info(f"args: {json.dumps(args.__dict__, indent=2, sort_keys=True)}")

    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _build_model(args)
    logger.info(
        f"# of model parameters: {sum(p.numel() for p in model.parameters()) * 1e-6:.2f}M"
    )

    eos_setting = args.eos_setting

    bind_nsml(model, tokenizer, args, eos=eos_setting)
    if args.pause:
        nsml.paused(scope=locals())

    training_modes = ("train", "pretrain", "semi-train")

    if args.mode != 'test' and args.averaging != "":
        sess = 't0005/rush1-3/37'
        checkpoints = ["4500", "6500", "7500", "8000"]

        # Load once so the tokenizer (and therefore the vocabulary size) of
        # the session is known before the model is rebuilt.
        nsml.load(checkpoint=checkpoints[0], session=sess)
        args.vocab_size = len(tokenizer)
        print(args.vocab_size)

        model = _build_model(args)

        averaged_params = _average_checkpoints(
            model, tokenizer, args, eos_setting, sess, checkpoints)
        # strict=False: only parameters are averaged, buffers keep the values
        # of the last loaded checkpoint.
        model.load_state_dict(averaged_params, strict=False)

        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.save('best')

    elif args.mode == 'eval':
        print("I'm in EVAL")

        checkpoint = 'best'
        sess = 't0005/rush1-3/507'
        nsml.load(checkpoint=checkpoint, session=sess)
        args.vocab_size = len(tokenizer)

        model = _build_model(args)

        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.load(checkpoint=checkpoint, session=sess)

        model.eval()
        noisy_sents = read_strings(
            os.path.join(args.data_dir, "train_data", "train_corpus"))
        valid_noisy = noisy_sents[:1000]

        prediction = correct_beam(model,
                                  tokenizer,
                                  valid_noisy,
                                  args,
                                  eos=True,
                                  length_limit=0.15)

        for i, pred in enumerate(prediction[:1000]):
            print("noisy_input: {}, pred: {}".format(valid_noisy[i], pred))

    # Only works with the char tokenizer.
    # TODO: support the kobert tokenizer / a different vocab size if needed.
    elif args.mode != 'test' and args.resubmit != "":
        checkpoint = 'best'
        sess = 't0005/rush1-3/' + args.resubmit
        print(sess)

        # First load with no model bound: only the tokenizer is needed to
        # learn the vocabulary size the model must be built with.
        model = None
        tokenizer = CharTokenizer([])
        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.load(checkpoint=checkpoint, session=sess)

        args.vocab_size = len(tokenizer)
        print(args.vocab_size)

        model = _build_model(args)

        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.load(checkpoint=checkpoint, session=sess)

        bind_nsml(model, tokenizer, args, eos=eos_setting)
        nsml.save("best")

    else:
        if args.mode in training_modes:
            if args.mode in ("train", "semi-train"):
                noisy_sents = read_strings(
                    os.path.join(args.data_dir, "train_data", "train_data"))
                sents_annotation = read_strings(
                    os.path.join(args.data_dir, "train_data",
                                 "train_annotation"))
                clean_sents = read_strings(
                    os.path.join(args.data_dir, "train_label"))

            if args.mode == "semi-train":
                # Extra machine-generated pairs stored in another session;
                # they carry no error annotation.
                checkpoint = 'generated_data'
                sess = 't0005/rush1-1/' + str(args.semi_dataset)
                semi_noisy_sents, semi_clean_sents = load_generated_data(
                    checkpoint=checkpoint, session=sess)
                semi_sents_annotation = ['None'] * len(semi_noisy_sents)

            if args.mode == "pretrain":
                print("PRETRAIN MODE ON!!")
                noisy_sents = read_strings(
                    os.path.join('sejong_corpus', args.noisy_file))
                clean_sents = read_strings(
                    os.path.join('sejong_corpus', args.clean_file))
                sents_annotation = ['None'] * len(noisy_sents)

            # Frequency of each comma-separated error type in the annotations.
            error_type_counter = Counter()
            for annotation in sents_annotation:
                error_type_counter.update(annotation.split(','))

            print(error_type_counter)

            # Original (un-preprocessed) noisy sentences are used as input.
            pairs = [{
                "noisy": noisy,
                "clean": clean,
                "annotation": annot
            } for noisy, clean, annot in zip(
                noisy_sents, clean_sents, sents_annotation)]

            if args.mode == "semi-train":
                semi_pairs = [{
                    "noisy": noisy,
                    "clean": clean,
                    "annotation": annot
                } for noisy, clean, annot in zip(
                    semi_noisy_sents, semi_clean_sents, semi_sents_annotation)]

                # The last `num_val_data` real pairs are held out; generated
                # pairs are only ever used for training.
                train_data = pairs[:-args.num_val_data] + semi_pairs
                valid_data = pairs[-args.num_val_data:]
                logger.info(f"# of train data: {len(train_data)}")
                logger.info(f"# of valid data: {len(valid_data)}")

                train_sents = [x['noisy'] for x in train_data
                               ] + [x['clean'] for x in train_data]
                tokenizer = CharTokenizer.from_strings(train_sents,
                                                       args.vocab_size)
                bind_nsml(model, tokenizer, args, eos=eos_setting)

            else:
                train_data, valid_data = train_test_split(
                    pairs, test_size=args.val_ratio,
                    random_state=args.seed)  # test: about 1000
                logger.info(f"# of train data: {len(train_data)}")
                logger.info(f"# of valid data: {len(valid_data)}")

                train_sents = [x['noisy'] for x in train_data
                               ] + [x['clean'] for x in train_data]

                if args.load_model != "" and args.mode == "train":  # Load pretrained model
                    print("load pretrained model")
                    model.load_state_dict(
                        torch.load(args.load_model, map_location=args.device))

                    if args.freeze:
                        model.token_embeddings.weight.requires_grad = False
                        model.decoder_embeddings.weight.requires_grad = False

                if args.tokenizer == 'char' and args.load_vocab == "":
                    tokenizer = CharTokenizer.from_strings(
                        train_sents, args.vocab_size)
                    print(
                        f'tokenizer loaded from strings. len={len(tokenizer)}.'
                    )

                bind_nsml(model, tokenizer, args, eos=eos_setting)

                if args.tokenizer == 'char' and tokenizer is not None:
                    tokenizer.save('vocab.txt')

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model, dim=1)

        if args.mode in training_modes:
            train(model,
                  tokenizer,
                  train_data,
                  valid_data,
                  args,
                  eos=eos_setting)
예제 #6
0
def predict(dn, rn):
    """Translate ``src-test.txt`` of dataset ``dn``/``rn`` into a prediction file.

    Nothing is done (apart from a printed notice) when the input file is
    missing or when the prediction file already exists. Otherwise the model is
    restored from ``checkpoint_path`` and the input is translated in batches
    of ``args.bs`` lines, one output line per input line.

    Args:
        dn: dataset name used in the input directory and output file name.
        rn: second name component used the same way.
    """
    source_dir = "../data/{dn}-{rn}-raw".format(dn=dn, rn=rn)
    input_path = os.path.join(source_dir, "src-test.txt")
    if not os.path.isfile(input_path):
        print(f"File: {input_path} not exist.")
        return

    output_path = os.path.join(outputDir, f"prediction-{dn}-{rn}.txt")
    if os.path.isfile(output_path):
        print(f"File {output_path} already exists.")
        return

    # Purpose: convert a source line into its index sequence.
    to_indexes = IndexedInputTargetTranslationDataset.preprocess(source_dictionary)

    # Purpose: convert predicted indexes back into a sentence, dropping the
    # start / end / padding tokens.
    def to_text(indexes):
        kept = []
        for token in target_dictionary.tokenize_indexes(indexes):
            if token == END_TOKEN or token == START_TOKEN or token == PAD_TOKEN:
                continue
            kept.append(token)
        return ''.join(kept)

    use_gpu = torch.cuda.is_available() and not args.no_cuda
    device = torch.device(f'cuda:{args.device}' if use_gpu else 'cpu')

    print('Building model...')
    model = TransformerModel(source_dictionary.vocabulary_size,
                             target_dictionary.vocabulary_size,
                             config['d_model'],
                             config['nhead'],
                             config['nhid'],
                             config['nlayers'])
    model.eval()
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state)
    translator = Translator(
        model=model,
        beam_size=args.beam_size,
        max_seq_len=args.max_seq_len,
        trg_bos_idx=target_dictionary.token_to_index(START_TOKEN),
        trg_eos_idx=target_dictionary.token_to_index(END_TOKEN)
    ).to(device)

    from utils.pipe import PAD_INDEX

    def to_padded_tensor(batch):
        # Right-pad every sequence with PAD_INDEX up to the longest one.
        longest = max(len(sequence) for sequence in batch)
        return torch.tensor(
            [sequence + [PAD_INDEX] * (longest - len(sequence))
             for sequence in batch])

    def mentions_proof(name):
        return "balance" in name or "one" in name

    def prepare_line(line):
        global is_proof_process
        line = line.strip()
        # The suffix is only added when the model's data set is a "proof" one
        # and the data set being predicted is not.
        if not mentions_proof(data_name) or mentions_proof(dn):
            return line
        line += ",$,1"
        if is_proof_process:
            # Announce the rewriting once per process.
            print("processing")
            is_proof_process = False
        return line

    def translate(batch):
        return translator.translate_sentence(to_padded_tensor(batch).to(device))

    batch_size = args.bs
    print(f"Output to {output_path}:")
    with open(output_path, 'w', encoding='utf-8') as outFile, \
            open(input_path, 'r', encoding='utf-8') as inFile:
        pending = []
        for raw_line in tqdm(inFile):
            pending.append(to_indexes(prepare_line(raw_line)))
            if len(pending) < batch_size:
                continue
            texts = [to_text(pred) for pred in translate(pending)]
            outFile.writelines([text.strip() + '\n' for text in texts])
            pending.clear()

        if pending:    # last, possibly shorter, batch
            texts = [
                to_text(pred).replace(START_TOKEN, '').replace(END_TOKEN, '')
                for pred in translate(pending)
            ]
            outFile.writelines([text.strip() + '\n' for text in texts])
            pending.clear()
    print(f'[Info] {input_path} Finished.')
예제 #7
0
class TrainLoop_Transformer():
    def __init__(self, opt):
        """Load the vocabulary and data, build the model and its optimizer.

        Args:
            opt: option mapping. The keys read here are ``batch_size``,
                ``epoch``, ``use_cuda`` and ``gpu``; ``opt['device']`` is
                written back with the selected device string.
        """
        self.opt = opt

        # NOTE(review): the vocabulary path comes from the module-level
        # `args`, not from `opt` — confirm this is intended.
        # The file is opened in a `with` block so the handle is closed.
        with open(args.bpe2index, encoding='utf-8') as vocab_file:
            self.dict = json.load(vocab_file)
        # Reverse mapping: vocabulary value -> key.
        self.index2word = {index: token for token, index in self.dict.items()}

        self.batch_size = self.opt['batch_size']
        self.epoch = self.opt['epoch']
        self.use_cuda = opt['use_cuda']
        print('self.use_cuda:', self.use_cuda)

        self.device = 'cuda:{}'.format(
            self.opt['gpu']) if self.use_cuda else 'cpu'
        self.opt['device'] = self.device

        with open("data/movie_ids.pkl", "rb") as movie_ids_file:
            self.movie_ids = pkl.load(movie_ids_file)

        self.build_data()
        self.build_model()

        # Only parameters that require gradients are handed to the optimizer.
        self.init_optim(
            [p for p in self.model.parameters() if p.requires_grad])

    def build_data(self):
        if self.opt['process_data']:
            self.train_dataset = dataset(
                "../../data/data1030/output/train_cut.pkl", self.opt, 'train')
            self.valid_dataset = dataset(
                "../../data/data1030/output/valid_cut.pkl", self.opt, 'valid')
            self.test_dataset = dataset(
                "../../data/data1030/output/test_cut.pkl", self.opt, 'test')

            self.train_processed_set = self.train_dataset.data_process(True)
            self.valid_processed_set = self.valid_dataset.data_process(True)
            self.test_processed_set = self.test_dataset.data_process(True)

            pickle.dump(self.train_processed_set,
                        open('data/train_processed_set.pkl', 'wb'))
            pickle.dump(self.valid_processed_set,
                        open('data/valid_processed_set.pkl', 'wb'))
            pickle.dump(self.test_processed_set,
                        open('data/test_processed_set.pkl', 'wb'))
            logger.info("[Save processed data]")
        else:
            try:
                self.train_processed_set = pickle.load(
                    open('data/train_processed_set.pkl', 'rb'))
                self.valid_processed_set = pickle.load(
                    open('data/valid_processed_set.pkl', 'rb'))
                self.test_processed_set = pickle.load(
                    open('data/test_processed_set.pkl', 'rb'))
            except:
                assert 1 == 0, "No processed data"
            logger.info("[Load processed data]")

    def build_model(self):
        """Instantiate the model, optionally restore weights, and place it on the device."""
        self.model = TransformerModel(self.opt, self.dict)

        # TODO: handling of non-random embedding types is not implemented yet.
        if self.opt['embedding_type'] != 'random':
            pass

        saved_params = self.opt['load_dict']
        if saved_params is not None:
            logger.info(
                '[ Loading existing model params from {} ]'.format(saved_params))
            self.model.load_model(saved_params)

        # The model is only moved when CUDA use was requested.
        if self.use_cuda:
            self.model.to(self.device)

    def train(self):
        losses = []
        best_val_gen = 1000
        gen_stop = False
        patience = 0
        max_patience = 5
        num = 0

        # file_temp = open('temp.txt', 'w')
        # train_output_file = open(f"output_train_tf.txt", 'w', encoding='utf-8')

        for i in range(self.epoch):
            train_set = CRSdataset(self.train_processed_set,
                                   self.opt['n_entity'], self.opt['n_concept'])
            train_dataset_loader = torch.utils.data.DataLoader(
                dataset=train_set, batch_size=self.batch_size,
                shuffle=True)  # shuffle

            for context,c_lengths,response,r_length,mask_response, \
                    mask_r_length,entity,entity_vector,movie,\
                    concept_mask,dbpedia_mask,concept_vec, \
                    db_vec,rec in tqdm(train_dataset_loader):
                ####################################### 检验输入输出ok
                # file_temp.writelines("[Context] ", self.vector2sentence(context))
                # file_temp.writelines("[Response] ", self.vector2sentence(response))
                # file_temp.writelines("\n")

                seed_sets = []
                batch_size = context.shape[0]
                for b in range(batch_size):
                    seed_set = entity[b].nonzero().view(-1).tolist()
                    seed_sets.append(seed_set)

                self.model.train()
                self.zero_grad()

                scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss= \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, seed_sets, movie, \
                        concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                ##########################################
                # train_output_file.writelines(
                #     ["Loss per batch = %f\n" % gen_loss.item()])
                # train_output_file.writelines(['[GroundTruth] ' + ' '.join(sen_gt)+'\n' \
                #     + '[Generated] ' + ' '.join(sen_gen)+'\n\n' \
                #     for sen_gt, sen_gen in zip(self.vector2sentence(response.cpu()), self.vector2sentence(preds.cpu()))])

                losses.append([gen_loss])
                self.backward(gen_loss)
                self.update_params()

                if num % 50 == 0:
                    loss = sum([l[0] for l in losses]) / len(losses)
                    ppl = exp(loss)
                    logger.info('gen loss is %f, ppl is %f' % (loss, ppl))
                    losses = []

                num += 1

            output_metrics_gen = self.val(epoch=i)
            _ = self.val(True, epoch=i)

            if best_val_gen < output_metrics_gen["ppl"]:
                patience += 1
                logger.info('Patience = ', patience)
                if patience >= 5:
                    gen_stop = True
            else:
                patience = 0
                best_val_gen = output_metrics_gen["ppl"]
                self.model.save_model(self.opt['model_save_path'])
                logger.info(
                    f"[generator model saved in {self.opt['model_save_path']}"
                    "------------------------------------------------]")

            if gen_stop:
                break

        # train_output_file.close()
        # _ = self.val(is_test=True)

    def val(self, is_test=False, epoch=-1):
        """
        Evaluate the generator on the validation or test split.

        Every batch is run through the model twice under ``torch.no_grad()``:
        once with ``test=False`` (the teacher-forced pass according to the
        original comments), whose generation loss feeds the perplexity, and
        once with ``test=True, maxlen=20`` (the greedy pass), whose
        predictions are decoded with ``vector2sentence`` and written to disk.

        Three text files are written under ``output/`` (generated replies,
        ground truth, and both interleaved); tokens matching ``@<digits>`` are
        replaced by ``__UNK__`` in all of them. ``self.save_embedding()`` is
        called afterwards.

        :param is_test:
            if True evaluate ``self.test_processed_set``, otherwise
            ``self.valid_processed_set``

        :param epoch:
            epoch index; only used in the output file names

        :return:
            dict with the single key ``'ppl'``, the exponential of the mean
            per-batch generation loss

        :raises ZeroDivisionError:
            if the data loader yields no batches
        """
        self.model.eval()
        if is_test:
            valid_processed_set = self.test_processed_set
        else:
            valid_processed_set = self.valid_processed_set
        subset = 'valid' if not is_test else 'test'

        val_set = CRSdataset(valid_processed_set, self.opt['n_entity'],
                             self.opt['n_concept'])
        val_dataset_loader = torch.utils.data.DataLoader(
            dataset=val_set, batch_size=self.batch_size, shuffle=False)

        # Raw string: '\d' in a plain literal is an invalid escape sequence.
        # Compiled once because it is applied to every dumped sentence.
        entity_pattern = re.compile(r'@\d+')

        inference_sum = []
        golden_sum = []
        losses = []

        for context, c_lengths, response, r_length, mask_response, mask_r_length, \
                entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec \
                in tqdm(val_dataset_loader):
            with torch.no_grad():
                batch_size = context.shape[0]
                # Per sample: indices of the non-zero entries of its entity row.
                seed_sets = [
                    entity[b].nonzero().view(-1).tolist()
                    for b in range(batch_size)
                ]

                # Teacher-forced pass: only the generation loss is kept.
                _, _, _, _, gen_loss, _, _, _ = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=False)

                # Greedy pass: only the predictions are kept.
                # TODO: confirm that capping generation at maxlen=20 is intended.
                _, preds, _, _, _, _, _, _ = \
                    self.model(context.to(self.device), response.to(self.device), mask_response.to(self.device), concept_mask, dbpedia_mask, \
                        seed_sets, movie, concept_vec, db_vec, entity_vector.to(self.device), rec, test=True, maxlen=20, bsz=batch_size)

            golden_sum.extend(self.vector2sentence(response.cpu()))
            inference_sum.extend(self.vector2sentence(preds.cpu()))
            losses.append(torch.mean(gen_loss))

        # Original design: the loss comes from the teacher-forced pass while
        # the dumped sentences come from the greedy pass.
        ppl = exp(sum(losses) / len(losses))
        output_dict_gen = {'ppl': ppl}
        logger.info(f"{subset} set metrics = {output_dict_gen}")

        # Generated replies only.
        with open(f"output/output_{subset}_gen_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines([
                '[Generated] ' + entity_pattern.sub('__UNK__', ' '.join(sen)) +
                '\n' for sen in inference_sum
            ])

        # Ground truth only; the space before ! , . ? is removed here.
        with open(f"output/output_{subset}_gt_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            for sen in golden_sum:
                mask_sen = entity_pattern.sub('__UNK__', ' '.join(sen))
                mask_sen = re.sub(r' ([!,.?])', r'\1', mask_sen)
                f.write('[GT] ' + mask_sen + '\n')

        # Ground truth and generated reply side by side.
        with open(f"output/output_{subset}_both_epoch_{epoch}.txt",
                  'w',
                  encoding='utf-8') as f:
            f.writelines([
                '[GroundTruth] ' + entity_pattern.sub('__UNK__', ' '.join(sen_gt)) + '\n'
                + '[Generated] ' + entity_pattern.sub('__UNK__', ' '.join(sen_gen)) + '\n\n'
                for sen_gt, sen_gen in zip(golden_sum, inference_sum)
            ])

        self.save_embedding()

        return output_dict_gen

    def save_embedding(self):
        """
        Dump ``loop.dict`` as JSON to ``output/tf_bpe2index.json``.

        NOTE(review): ``loop`` is not defined in the visible part of this
        file; presumably it is a module-level object -- confirm it exists at
        call time.
        """
        # Context manager so the file is flushed and closed deterministically
        # instead of relying on garbage collection of an anonymous handle.
        with open('output/tf_bpe2index.json', 'w', encoding='utf-8') as f:
            json.dump(loop.dict, f)

    def vector2sentence(self, batch_sen):
        """
        Turn a batch of token-id rows into lists of token strings.

        Ids greater than 3 are looked up in ``self.index2word``, id 3 becomes
        the literal ``'_UNK_'``, and ids below 3 are dropped.

        :param batch_sen:
            object exposing ``.numpy().tolist()`` that yields one list of ids
            per sentence

        :return:
            list of token lists, one per row of ``batch_sen``
        """
        return [
            ['_UNK_' if token_id == 3 else self.index2word[token_id]
             for token_id in row if token_id >= 3]
            for row in batch_sen.numpy().tolist()
        ]

    @classmethod
    def optim_opts(cls):
        """
        Fetch optimizer selection.

        Collects every attribute of the ``optim`` module (torch.optim
        according to the original documentation) whose name starts with an
        upper-case letter, keyed by its lower-cased name, and additionally
        registers, when the corresponding package can be imported:
        - ``fused_adam`` from apex
        - ``qhm`` / ``qhadam`` from github.com/facebookresearch/qhoptim

        Override this (and probably call super()) to add your own optimizers.

        :return:
            dict mapping lower-case optimizer name to optimizer class
        """
        # first pull the optimizer classes of the optim module in
        optims = {
            name.lower(): candidate
            for name, candidate in vars(optim).items()
            if not name.startswith('__') and name[0].isupper()
        }

        # optional extras: silently skipped when the package is missing
        try:
            import apex.optimizers.fused_adam as fused_adam
            optims['fused_adam'] = fused_adam.FusedAdam
        except ImportError:
            pass

        try:
            # https://openreview.net/pdf?id=S1fUpoR5FQ
            from qhoptim.pyt import QHM, QHAdam
            optims['qhm'] = QHM
            optims['qhadam'] = QHAdam
        except ImportError:
            # no QHM installed
            pass

        logger.info(optims)
        return optims

    def init_optim(self, params, optim_states=None, saved_optim_type=None):
        """
        Create ``self.optimizer`` for the given model parameters.

        The optimizer class is the entry named ``self.opt['optimizer']`` in
        the mapping returned by ``self.optim_opts()``; it is constructed with
        ``lr=self.opt['learningrate']`` as its only keyword argument.

        :param params:
            parameters from the model

        :param optim_states:
            accepted for interface compatibility; not used here

        :param saved_optim_type:
            accepted for interface compatibility; not used here
        """
        settings = self.opt

        # the learning rate is the only optimizer argument that is configured
        optimizer_kwargs = dict(lr=settings['learningrate'])

        available = self.optim_opts()
        optim_class = available[settings['optimizer']]
        logger.info(f'optim_class = {optim_class}')
        self.optimizer = optim_class(params, **optimizer_kwargs)

    def backward(self, loss):
        """
        Run the backward pass for ``loss``.

        Thin wrapper around ``loss.backward()`` so that training code has a
        single place to hook into.

        :param loss:
            object exposing ``backward()``
        """
        loss.backward()

    def update_params(self):
        """
        Perform one optimization step, clipping gradients first if enabled.

        When ``self.opt['gradient_clip']`` is positive, the gradient norm of
        ``self.model.parameters()`` is clipped to that value with
        ``torch.nn.utils.clip_grad_norm_``; then ``self.optimizer.step()`` is
        called. Every call steps the optimizer: this implementation does no
        gradient accumulation and no learning-rate scheduling.

        It is recommended (but not forced) that you call this in train_step.
        """
        max_norm = self.opt['gradient_clip']
        # NOTE(review): an earlier comment asked whether a clip value of 0.1
        # is too small and noted that the original version uses it.
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)

        self.optimizer.step()

    def zero_grad(self):
        """
        Reset the gradients tracked by ``self.optimizer``.

        Delegates to ``self.optimizer.zero_grad()``. It is recommended you
        call this in train_step.
        """
        optimizer = self.optimizer
        optimizer.zero_grad()