Example #1
class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, 'eval',
                               config.batch_size, single_pass=True)
        time.sleep(5)
        eval_dir = os.path.join(config.log_root, 'eval_%d'%(int(time.time())))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1 = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            # The updated decoder state and context are carried into the next step
            final_dist, s_t_1, c_t_1, _, _ = self.model.decoder(y_t_1, s_t_1,
                                                                encoder_outputs, enc_padding_mask, c_t_1,
                                                                extra_zeros, enc_batch_extend_vocab)

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()

    def run_eval(self):
        start = time.time()
        running_avg_loss, iter = 0, 0
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval(batch)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()

            batch = self.batcher.next_batch()
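
A minimal driver for this Evaluate class might look like the sketch below; the checkpoint path is a placeholder and config/use_cuda are assumed to be importable from the surrounding project.

# Hypothetical usage sketch (not part of the original example):
if __name__ == '__main__':
    evaluator = Evaluate('log_root/train_123/model/model_50000')  # placeholder checkpoint path
    evaluator.run_eval()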
Example #2
class Encode(object):
    def __init__(self, model_file_path, destination_dir):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.encode_data_path,
                               self.vocab,
                               mode='encode',
                               batch_size=config.batch_size,
                               single_pass=True)
        time.sleep(5)

        self.output = {}
        self.destination_dir = destination_dir
        self.model = Model(model_file_path, is_eval=True)

    def save_output(self, output, destination_dir):
        if destination_dir is None:
            torch.save(output, "output")
        else:
            torch.save(output, destination_dir)

    def encode_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)

        h, c = self.model.reduce_state(encoder_hidden)
        h, c = h.squeeze(0), c.squeeze(0)
        encodes = torch.cat((h, c), 1)

        for abstract, encoding in zip(batch.original_abstracts, encodes):
            print(encoding)
            self.output[abstract] = encoding

    def run_encode(self):
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            self.encode_one_batch(batch)
            batch = self.batcher.next_batch()
        self.save_output(self.output, self.destination_dir)
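
A possible way to drive this Encode class and read back its output; the paths below are placeholders, and torch is assumed to be imported as in the example above.

# Hypothetical usage sketch (not part of the original example):
encoder_job = Encode('log_root/train_123/model/model_50000', 'encodes.pt')  # placeholder paths
encoder_job.run_encode()
encodings = torch.load('encodes.pt')  # dict mapping each original abstract to its [h; c] vector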
Example #3
def load_batches_decode():

    vocab   = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.decode_data_path, vocab, mode='decode',
                           batch_size=config.beam_size, single_pass=True)

    batches = [None for _ in range(TEST_DATA_SIZE)]
    for i in range(TEST_DATA_SIZE):
        batch = batcher.next_batch()
        batches[i] = batch

    with open("lib/data/batches_test.vocab{}.beam{}.pk.bin".format(vocab.size(), config.beam_size), "wb") as f:
        pickle.dump(batches, f)
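
The pickled batches can later be reloaded without re-running the Batcher; a sketch under the assumption that the same config values and file naming are used:

# Hypothetical reload of the cached decode batches written above:
vocab = Vocab(config.vocab_path, config.vocab_size)
path = "lib/data/batches_test.vocab{}.beam{}.pk.bin".format(vocab.size(), config.beam_size)
with open(path, "rb") as f:
    batches = pickle.load(f)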
Example #4
def load_batches_train():

    vocab   = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.decode_data_path, vocab, mode='train',
                           batch_size=config.batch_size, single_pass=False)

    TRAIN_DATA_SIZE = 287226
    num_batches = int(TRAIN_DATA_SIZE / config.batch_size)
    batches = [None for _ in range(num_batches)]
    for i in tqdm(range(num_batches)):
        batch = batcher.next_batch()
        batches[i] = batch

    with open("lib/data/batches_train.vocab{}.batch{}.pk.bin".format(vocab.size(), config.batch_size), "wb") as f:
        pickle.dump(batches, f)
class Train(object):
    def __init__(self):
        #config("print.vocab_path ",config.vocab_path)
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        #print("params : ",params)
        #print("params collection is completed....")
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        #### Load the state saved when training stopped earlier and use it to resume training for future epochs ####
        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                ###### Move the optimizer state tensors onto the GPU so they are accessible there #####
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()
        return start_iter, start_loss

    def train_one_batch(self, batch):

        ########### The two statements below just unpack the batch into encoder and decoder tensors (ids, masks, lengths, etc.): ######
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()
        #print("train_one_batch function ......")
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(
            encoder_hidden
        )  ### The reduced final encoder hidden state initializes the decoder state at timestep 0
        #print("s_t_1 : ",len(s_t_1),s_t_1[0].shape,s_t_1[1].shape)

        #print("steps.....")
        #print("max_dec_len = ",max_dec_len)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            ############ Training with teacher forcing ###########
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            #print("y_t_1 : ",len(y_t_1))
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            #print("attn_dist : ",len(attn_dist),attn_dist[0].shape)
            #print("final_dist : ",len(final_dist),final_dist[0].shape) ############## vocab_Size
            target = target_batch[:, di]
            #print("target = ",len(target))

            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(
                gold_probs + config.eps
            )  # Eq. 6
            if config.is_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(attn_dist, coverage),
                    1)  # Eq. 13a
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss  # Eq. 13b
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        print("trainIters__Started___model_file_path is : ", model_file_path)
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        print("Max iteration : n_iters = ", n_iters)
        print("going to start running iter NO : ", iter)
        print("\n******************************\n")
        while iter < n_iters:
            print("\n###################################\n")
            print("iter : ", iter)
            batch = self.batcher.next_batch()
            print("batch data loading : ", batch)
            loss = self.train_one_batch(batch)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            print("running_avg_loss : ", running_avg_loss)
            iter += 1
            if iter % 100 == 0:  ##100
                self.summary_writer.flush()
            print_interval = 100  #1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 500 == 0:  ##5000
                self.save_model(running_avg_loss, iter)
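
An entry point for this trainer would typically look like the following; config.max_iterations is an assumed setting and the optional checkpoint path is a placeholder.

# Hypothetical usage sketch (not part of the original example):
if __name__ == '__main__':
    trainer = Train()
    trainer.trainIters(config.max_iterations)  # pass model_file_path='...' to resume from a checkpoint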
Example #6
class Train(object):
    def __init__(self, model_file_path=None):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        if not model_file_path:
            train_dir = os.path.join(config.log_root,
                                     'train_%d' % (int(time.time())))
            if not os.path.exists(train_dir):
                os.mkdir(train_dir)
        else:
            train_dir = re.sub('/model/model.*', '', model_file_path)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.create_file_writer(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)
        # self.optimizer = Adam(params)
        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def f(self, x, alpha):
        # # 1 - x ** alpha
        # k = utils.EPOCH / (utils.MAX_EPOCH / 2) - 1
        # return k * x + (1 - k)/2
        return 1 - x**alpha

    def get_loss_mask(self, src, tgt, absts, alpha=config.alpha):
        loss_mask = []
        for i in range(len(src)):

            # debug('src[i]',src[i])
            # debug('tgt[i]',src[i])
            # cnt = 0
            # tgt_i = [t for t in tgt[i] if t != 1]
            # src_i = set([s for s in src[i] if s != 1])
            # debug('src_i',src_i)
            # m = [t for t in tgt_i if t not in src_i ]
            # # for token in tgt_i:
            # #     if token not in src_i:
            # #         cnt += 1
            # cnt = len(m)
            # abst = round(cnt / len(tgt_i),4)
            abst = absts[i]
            loss_factor = self.f(abst, alpha)
            loss_mask.append(loss_factor)
        return torch.Tensor(loss_mask).cuda()
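
    # Note added for clarity: f(x, alpha) = 1 - x**alpha maps an abstractness
    # score in [0, 1] to a loss weight in [0, 1], so highly abstractive targets
    # are down-weighted when config.loss_mask is enabled. For example, with
    # alpha = 2: f(0.1) = 0.99, f(0.5) = 0.75, f(0.9) = 0.19.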

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        # debug(batch.original_articles[0])
        # debug(batch.original_abstracts[0])
        loss_mask = self.get_loss_mask(enc_batch, dec_batch, batch.absts)
        # debug('loss_mask',loss_mask)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage, tau = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)

            # debug('enc_batch',enc_batch.size())
            # debug('dec_batch',dec_batch.size())
            # debug('final_dist', final_dist.size())
            # debug('target',target)
            # debug('gold_probs',gold_probs)

            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            # debug('step_loss_before',step_loss)
            # debug('config.loss_mask',config.loss_mask)
            if config.loss_mask:
                step_loss = step_loss * loss_mask
                # pass
            # debug('step_loss_after',step_loss)
            step_losses.append(step_loss)

            if config.DEBUG:
                # break
                pass

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        if not config.DEBUG:
            loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item(), tau

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()

        start_iter = iter
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss, tau = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if config.DEBUG:
                debug('iter', iter)
                if iter - start_iter > config.BREAK_POINT:
                    break

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 100
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                if config.adaptive_sparsemax:
                    print('tau + eps', [
                        round(e[0], 4)
                        for e in (tau +
                                  config.eps).detach().cpu().numpy().tolist()
                    ])
                start = time.time()
            if iter % 5000 == 0:
                self.save_model(running_avg_loss, iter)
Example #7
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)

        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        # print("MODE MUST BE train")
        # time.sleep(15)
        self.print_interval = config.print_interval

        train_dir = config.train_dir
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = train_dir
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        # self.summary_writer = tf.compat.v1.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'iter{}.pt'.format(iter))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        if not config.is_hierarchical:
            encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            s_t_1 = self.model.reduce_state.forward1(encoder_hidden)

        else:
            stop_id = self.vocab.word2id('.')
            pad_id  = self.vocab.word2id('[PAD]')
            enc_sent_pos = get_sent_position(enc_batch, stop_id, pad_id)
            dec_sent_pos = get_sent_position(dec_batch, stop_id, pad_id)

            encoder_outputs, encoder_feature, encoder_hidden, sent_enc_outputs, sent_enc_feature, sent_enc_hidden, sent_enc_padding_mask, sent_lens, seq_lens2 = \
                                                                    self.model.encoder(enc_batch, enc_lens, enc_sent_pos)

            s_t_1, sent_s_t_1 = self.model.reduce_state(encoder_hidden, sent_enc_hidden)
        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if not config.is_hierarchical:
                # start = datetime.now()

                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder.forward1(y_t_1, s_t_1,
                                                            encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                            extra_zeros, enc_batch_extend_vocab,
                                                                               coverage, di)
                # print('NO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()
            else:
                # start = datetime.now()
                max_doc_len = enc_batch.size(1)
                final_dist, sent_s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, sent_s_t_1,
                                                            encoder_outputs, encoder_feature, enc_padding_mask, seq_lens2,
                                                            sent_s_t_1, sent_enc_outputs, sent_enc_feature, sent_enc_padding_mask,
                                                            sent_lens, max_doc_len,
                                                            c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
                # print('DO HIER Time: ',datetime.now() - start)
                # import pdb; pdb.set_trace()


            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        # start = datetime.now()
        loss.backward()
        # print('{} HIER Time: {}'.format(config.is_hierarchical ,datetime.now() - start))
        # import pdb; pdb.set_trace()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        sys.stdout.flush()

        # data_path = "lib/data/batches_train.vocab50000.batch16.pk.bin"
        # with open(data_path, 'rb') as f:
        #     stored_batches = pickle.load(f, encoding="bytes")
        # print("loaded data: {}".format(data_path))
        # num_batches = len(stored_batches)

        while iter < n_iters:
            batch = self.batcher.next_batch()
            # batch_id = iter%num_batches
            # batch = stored_batches[batch_id]

            loss = self.train_one_batch(batch)

            # running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, iter)

            iter += 1

            # if iter % 100 == 0:
            #     self.summary_writer.flush()

            if iter % self.print_interval == 0:
                print("[{}] iter {}, loss: {:.5f}".format(str(datetime.now()), iter, loss))
                sys.stdout.flush()

            if iter % config.save_every == 0:
                self.save_model(running_avg_loss, iter)

        print("Finished training!")
Example #8
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(
            os.path.join(config.save_model_path, self.opt.load_model))
        self.model.load_state_dict(checkpoint["model_dict"])
        # save the light version of picked model
        '''
        print(config.save_model_path,' ','light_'+self.opt.load_model)
        save_path = config.save_model_path + '/light_'+self.opt.load_model 
        print(save_path)
        T.save({
            "model_dict": self.model.state_dict(),
        }, save_path)      
        exit()
        ''' #-- end --

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join("data", filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                       ct_e, extra_zeros,
                                       enc_batch_extend_vocab, self.model,
                                       start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents,
                                          article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
        if self.opt.task == "test":
            print(load_file, "scores:", scores)
        else:
            rouge_l = scores["rouge-l"]["f"]
            print(load_file, "rouge_l:", "%.4f" % rouge_l)
            with open("test_rg.txt", "a") as f:
                f.write("\n" + load_file + " - rouge_l: " + str(rouge_l))
                f.close()
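
A hypothetical driver for this evaluator; opt is whatever argparse-style namespace the caller builds, and the attribute names used below (load_model, task) are the ones referenced above. The model file name and data path are placeholders.

# Hypothetical usage sketch (not part of the original example):
from argparse import Namespace
opt = Namespace(load_model='model_50000.tar', task='validate')
Evaluate(config.eval_data_path, opt).evaluate_batch(print_sents=False)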
Example #9
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.train_batcher = Batcher(config.train_data_path,
                                     self.vocab,
                                     hps=config.hps,
                                     single_pass=False)
        self.val_batcher = Batcher(config.eval_data_path,
                                   self.vocab,
                                   hps=config.hps,
                                   single_pass=False)

    def setup_train_generator(self, model_file_path=None):
        generator = Generator(num_embeddings=config.vocab_size,  # 4999
                              embedding_dim=config.emb_dim,  # 128
                              n_labels=config.vocab_size,  # 4999
                              pad_length=config.padding,  # 20
                              encoder_units=config.hidden_dim,  # 256
                              decoder_units=config.hidden_dim,  # 256
                              )
        model = generator.model()
        model.summary()
        model.compile(optimizer='adagrad',
                      lr=config.lr,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        print('Generator Compiled.')

        try:
            model.fit_generator(generator=self.train_batcher.next_batch(),
                                samples_per_epoch=5,
                                validation_data=self.val_batcher.next_batch(),
                                callbacks=[cp],
                                verbose=1,
                                nb_val_samples=1,
                                nb_epoch=config.max_iterations)

        except KeyboardInterrupt as e:
            print('Generator training stopped early.')
        print('Generator training complete.')

    def setup_train_discriminator(self):
        model = Discriminator().model()

        model.summary()
        model.compile(optimizer=Adam(lr=0.0001, beta_1=0.5, beta_2=0.9),
                      lr=config.lr,
                      loss='binary_crossentropy',
                      )
        print('Discriminator Compiled.')

        try:
            model.fit_generator(generator=self.train_batcher.next_batch_discriminator(),
                                samples_per_epoch=5,
                                validation_data=self.val_batcher.next_batch_discriminator(),
                                callbacks=[cp],
                                verbose=1,
                                nb_val_samples=1,
                                nb_epoch=config.max_iterations)

        except KeyboardInterrupt as e:
            print('Discriminator training stopped early.')
        print('Discriminator training complete.')

    def setup_train_wgan_model(self):
        generator = Generator(num_embeddings=config.vocab_size,  # 4999
                              embedding_dim=config.emb_dim,  # 128
                              n_labels=config.vocab_size,  # 4999
                              pad_length=config.padding,  # 20
                              encoder_units=config.hidden_dim,  # 256
                              decoder_units=config.hidden_dim,  # 256
                              ).model()
        reconstructor = Reconstructor(num_embeddings=config.vocab_size,  # 4999
                                      embedding_dim=config.emb_dim,  # 128
                                      n_labels=config.vocab_size,  # 4999
                                      pad_length=config.padding,  # 20
                                      encoder_units=config.hidden_dim,  # 256
                                      decoder_units=config.hidden_dim,  # 256
                                      ).model()
        discriminator = Discriminator().model()
        wgan = WGAN(generator=generator,
                    reconstructor=reconstructor,
                    discriminator=discriminator,
                    )
        try:
            wgan.train(self.train_batcher.next_batch())
        except KeyboardInterrupt as e:
            print('WGAN training stopped early.')
        print('WGAN training complete.')
Example #10
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        map_location = T.device('cuda') if T.cuda.is_available() else T.device('cpu')
        checkpoint = T.load(
            os.path.join(config.save_model_path, self.opt.load_model),
            map_location)
        self.model.load_state_dict(checkpoint["model_dict"])


#        mlflow.pytorch.save_model(self.model,config.save_model_path+'_2')
#        mlflow.pytorch.load_model(config.save_model_path+'_2')

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):
        filename = "test_" + loadfile.split(".")[0] + ".txt"

        with open(os.path.join("data", filename), "w") as f:
            for i in range(len(decoded_sents)):
                f.write("article: " + article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask,
                                       ct_e, extra_zeros,
                                       enc_batch_extend_vocab, self.model,
                                       start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)
                article_art_oovs = batch.art_oovs[i]

            #batch = self.batcher.next_batch()
            break

        load_file = self.opt.load_model  # just a model name

        #if print_sents:
        #    self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
        Batcher.article_summary = decoded_sents[0]
        Batcher.oovs = " ".join(article_art_oovs)

        #        print('Article: ',article_sents[0], '\n==> Summary: [',decoded_sents[0],']\nOut of vocabulary: ', " ".join(article_art_oovs),'\nModel used: ', load_file)
        scores = 0  #rouge.get_scores(decoded_sents, ref_sents, avg = True)
        if self.opt.task == "test":
            print('Done.')
            #print(load_file, "scores:", scores)
        else:
            rouge_l = scores["rouge-l"]["f"]
Example #11
class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path, self.vocab, mode='eval',
                               batch_size=config.batch_size, single_pass=True)
        time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1,attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
                                                                encoder_outputs, enc_padding_mask, c_t_1,
                                                                extra_zeros, enc_batch_extend_vocab, coverage)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()
class Train(object):
    def __init__(self):
        self.vocab = Vocab(args.vocab_path, args.vocab_size)
        sys.stdout.flush()
        self.batcher = Batcher(args.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=args.batch_size,
                               single_pass=False)
        time.sleep(15)
        vocab_size = self.vocab.size()
        self.model = BertLSTMModel(args.hidden_size, self.vocab.size(),
                                   args.max_dec_steps)
        # self.model = Seq2SeqLSTM(args.hidden_size, self.vocab.size(), args.max_dec_steps)
        if use_cuda:
            self.model = self.model.cuda()

        self.model_optimizer = torch.optim.Adam(self.model.parameters(),
                                                lr=args.lr)

        train_logs = os.path.join(args.logs, "train_logs")
        eval_logs = os.path.join(args.logs, "eval_logs")
        self.train_summary_writer = tf.summary.FileWriter(train_logs)
        self.eval_summary_writer = tf.summary.FileWriter(eval_logs)

    def trainOneBatch(self, batch):
        self.model_optimizer.zero_grad()
        loss = self.model(batch)
        loss.backward()

        clip_grad_norm_(self.model.parameters(), 5)
        # clip_grad_norm_(self.model.decoder.parameters(), 5)
        self.model_optimizer.step()
        return loss.item() / args.max_dec_steps

    def trainIters(self):
        running_avg_loss = 0
        s_time = time.time()
        for t in range(args.max_iteration):
            batch = self.batcher.next_batch()
            loss = self.trainOneBatch(batch)
            running_avg_loss = calc_running_avg_loss(loss,
                                                     running_avg_loss,
                                                     decay=0.999)
            save_running_avg_loss(running_avg_loss, t,
                                  self.train_summary_writer)

            # Print progress (the interval here is 1, so this prints every step)
            if (t + 1) % 1 == 0:
                time_run = time.time() - s_time
                s_time = time.time()
                print("timestep: {}, loss: {}, time: {}s".format(
                    t, running_avg_loss, time_run))
                sys.stdout.flush()

            # Save the model and run evaluation every args.save_every_itr steps
            if (t + 1) % args.save_every_itr == 0:
                with torch.no_grad():
                    self.save_checkpoint(t, running_avg_loss)
                    self.model.eval()
                    self.evaluate(t)
                    self.model.train()

    def save_checkpoint(self, step, loss):
        checkpoint_file = "checkpoint_{}".format(step)
        checkpoint_path = os.path.join(args.logs, checkpoint_file)
        torch.save(
            {
                'timestep': step,
                'model_state_dict': self.model.decoder.state_dict(),
                'optimizer_state_dict': self.model_optimizer.state_dict(),
                'loss': loss
            }, checkpoint_path)
        # torch.save({
        #     'timestep': step,
        #     'encoder_state_dict': self.model.encoder.state_dict(),
        #     'decoder_state_dict': self.model.decoder.state_dict(),
        #     'optimizer_state_dict': self.model_optimizer.state_dict(),
        #     'loss': loss
        # }, checkpoint_path)

    def evaluate(self, timestep):
        self.eval_batcher = Batcher(args.eval_data_path,
                                    self.vocab,
                                    mode='train',
                                    batch_size=args.batch_size,
                                    single_pass=True)
        time.sleep(15)
        t1 = time.time()
        batch = self.eval_batcher.next_batch()
        running_avg_loss = 0
        while batch is not None:
            loss = self.model(batch)
            loss = loss / args.max_dec_steps
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss)
            batch = self.eval_batcher.next_batch()

        # Save the evaluation score
        time_spent = time.time() - t1
        print("Evaluation Loss: {}, Time: {}s".format(running_avg_loss,
                                                      time_spent))
        save_running_avg_loss(running_avg_loss, timestep,
                              self.eval_summary_writer)
        sys.stdout.flush()
Example #13
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'transformer_state_dict': self.model.state_dict(),
            'optimizer': self.optimizer._optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

        return model_save_path

    def setup_train(self, n_src_vocab, n_tgt_vocab, model_file_path=None):
        self.model = Model(n_src_vocab, n_tgt_vocab, config.max_article_len)

        params = list(self.model.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = ScheduledOptim(Adam(
            filter(lambda x: x.requires_grad, self.model.parameters()),
            betas=(0.9, 0.98), eps=1e-09),
            config.d_model, config.n_warmup_steps)
        #self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
        self.loss_func = torch.nn.CrossEntropyLoss(ignore_index = 1)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            self.model.load_state_dict(state['transformer_state_dict'])

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def get_pos_data(self, padding_masks):
        batch_size, seq_len = padding_masks.shape

        pos_data = [[ j + 1 if padding_masks[i][j] == 1 else 0 for j in range(seq_len)] for i in range(batch_size)]

        pos_data = torch.tensor(pos_data, dtype=torch.long)

        if use_cuda:
            pos_data = pos_data.cuda()

        return pos_data
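
    # Example added for clarity: for a padding-mask row [1, 1, 1, 0, 0] the
    # returned positions are [1, 2, 3, 0, 0]; index 0 is reserved for padded
    # slots so the positional embedding can ignore them.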

    def train_one_batch(self, batch, iter):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        # print(target_batch[:, 1:].contiguous().view(-1)[-10:])
        # print(dec_batch[:, 1:].contiguous().view(-1)[-10:])

        in_seq = enc_batch
        in_pos = self.get_pos_data(enc_padding_mask)
        tgt_seq = dec_batch
        tgt_pos = self.get_pos_data(dec_padding_mask)
        
        # padding is already done in previous function (see batcher.py - init_decoder_seq & init_decoder_seq - Batch class)
        self.optimizer.zero_grad()
        #logits = self.model.forward(in_seq, in_pos, tgt_seq, tgt_pos)
        logits = self.model.forward(in_seq, in_pos, tgt_seq, tgt_pos, extra_zeros, enc_batch_extend_vocab)

        # compute loss from logits
        loss = self.loss_func(logits, target_batch.contiguous().view(-1))

    

        # target_batch[torch.gather(logits, 2, target_batch.unsqueeze(2)).squeeze(2) == 0] = 1
        # target_batch = target_batch.contiguous().view(-1)
        # logits = logits.reshape(-1, logits.size()[2])
        # print(target_batch)
        # print('\n')
        # print(logits.size(), target_batch.size())
        # print('\n')
        #loss = self.loss_func(logits, target_batch)

        #print(loss)
        #sum_losses = torch.mean(torch.stack(losses, 1), 1)

        if iter % 50 == 0 and False:
            print(iter, loss)
            print('\n')
            # print(logits.max(1)[1][:20])
            # print('\n')
            # print(target_batch.contiguous().view(-1)[:20])
            # print('\n')
            #print(target_batch.contiguous().view(-1)[-10:])

        loss.backward()

        #print(logits.max(1)[1])
        #print('\n')
        #print(tgt_seq[:, 1:].contiguous().view(-1)[:10])
        #print(tgt_seq[:, 1:].contiguous().view(-1)[-10:])
        
        self.norm = clip_grad_norm_(self.model.parameters(), config.max_grad_norm)

        #self.optimizer.step()
        self.optimizer.step_and_update_lr()

        return loss.item()

    def trainIters(self, n_src_vocab, n_tgt_vocab, n_iters, model_file_path=None):

        print("Setting up the model...")

        iter, running_avg_loss = self.setup_train(n_src_vocab, n_tgt_vocab, model_file_path)

        print("Starting training...")
        print("Data for this model will be stored in", self.model_dir)
        
        start = time.time()

        #only_batch = None
        losses = []
        iters = []
        save_name = os.path.join(self.model_dir, "loss_lists")

        while iter < n_iters:
            batch = self.batcher.next_batch()
            
            # if iter == 0:
            #     only_batch = batch
            # else:
            #     batch = only_batch

            loss = self.train_one_batch(batch, iter)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1
            
            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 50
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (iter, print_interval,
                                                                           time.time() - start, loss))
                start = time.time()
                
                iters.append(iter)
                losses.append(loss)

                with open(save_name, 'wb') as f:
                    pickle.dump((losses, iters), f)
            
            if iter % 5000 == 0:
                path = self.save_model(running_avg_loss, iter)

                print("Saving Checkpoint at {}".format(path))
Example #14
class Evaluate(object):
    def __init__(self, data_path, opt, batch_size = config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval',
                               batch_size=batch_size, single_pass=True)
        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(os.path.join(config.save_model_path, self.opt.load_model))
        self.model.load_state_dict(checkpoint["model_dict"])

    # self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)
    def print_original_predicted(self, decoded_sents, ref_sents, article_sents, loadfile):
        # filename = "test_"+loadfile.split(".")[0]+".txt"
        filename = "E:\\4160\\TextSummarization\\data\\test_output.txt"
        print(f"file name {filename}")
        # with open(os.path.join("data",filename), "w") as f:
        with open(filename, "w+", encoding="utf-8") as f:
            for i in range(len(decoded_sents)):
                f.write("article: "+article_sents[i] + "\n")
                f.write("ref: " + ref_sents[i] + "\n")
                f.write("dec: " + decoded_sents[i] + "\n\n")

    def evaluate_batch(self, print_sents = True):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()

        # start_time = time.time()
        # test_batch = 0
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = get_enc_data(batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            #-----------------------Summarization----------------------------------------------------
            with T.autograd.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)

            # test_batch_time = time.time() - start_time
            # start_time = time.time()
            # print("Current test batch", test_batch)
            # print("Testing time", test_batch_time)
            # test_batch += 1

            batch = self.batcher.next_batch()

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg = True)
        if self.opt.task == "test":
            print(load_file, "scores:", scores)
        else:
            rouge_l = scores["rouge-l"]["f"]
            print(load_file, "rouge_l:", "%.4f" % rouge_l)
Example #15
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.ouput_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)

        self.checkpoint_dir = os.path.join(train_dir, 'checkpoints')
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.train_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'train'))
        self.eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(train_dir, 'log', 'eval'))

    def save_model(self, model_path, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        torch.save(state, model_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(device, model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(device)

        return start_iter, start_loss

    def train_one_batch(self, batch, forcing_ratio=1):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, device)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, device)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        y_t_1_hat = None
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            # decide the next input
            if di == 0 or random.random() < forcing_ratio:
                x_t = y_t_1  # teacher forcing, use label from last time step as input
            else:
                # use the embedding of UNK for all out-of-vocabulary ids
                y_t_1_hat[y_t_1_hat >= self.vocab.size()] = self.vocab.word2id(UNKNOWN_TOKEN)
                x_t = y_t_1_hat.flatten()  # use the prediction from the last time step as input
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                x_t, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di)
            _, y_t_1_hat = final_dist.data.topk(1)
            target = target_batch[:, di].unsqueeze(1)
            step_loss = cal_NLLLoss(target, final_dist)
            if config.is_coverage:  # if not using coverage, keep coverage=None
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]  # padding in the target should not count toward the loss
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def train(self, n_iters, init_model_path=None):
        iter, avg_loss = self.setup_train(init_model_path)
        start = time.time()
        cnt = 0
        best_model_path = None
        min_eval_loss = float('inf')
        while iter < n_iters:
            s = config.forcing_ratio        # initial teacher-forcing ratio
            k = config.decay_to_0_iter      # iteration at which the ratio reaches ~0
            x = iter
            near_zero = 0.0001
            if config.forcing_decay_type:
                if x >= config.decay_to_0_iter:
                    forcing_ratio = 0
                elif config.forcing_decay_type == 'linear':
                    forcing_ratio = s * (k - x) / k
                elif config.forcing_decay_type == 'exp':
                    p = pow(near_zero, 1 / k)
                    forcing_ratio = s * (p ** x)
                elif config.forcing_decay_type == 'sig':
                    r = math.log((1 / near_zero) - 1) / k
                    forcing_ratio = s / (1 + pow(math.e, r * (x - k / 2)))
                else:
                    raise ValueError('Unrecognized forcing_decay_type: ' +
                                     config.forcing_decay_type)
            else:
                forcing_ratio = config.forcing_ratio
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, forcing_ratio=forcing_ratio)
            model_path = os.path.join(self.checkpoint_dir,
                                      'model_step_%d' % (iter + 1))
            avg_loss = calc_avg_loss(loss, avg_loss)

            if (iter + 1) % config.print_interval == 0:
                with self.train_summary_writer.as_default():
                    tf.summary.scalar(name='loss', data=loss, step=iter)
                self.train_summary_writer.flush()
                logger.info('steps %d, took %.2f seconds, train avg loss: %f' %
                            (iter + 1, time.time() - start, avg_loss))
                start = time.time()
            if config.eval_interval is not None and (
                    iter + 1) % config.eval_interval == 0:
                start = time.time()
                logger.info("Start Evaluation on model %s" % model_path)
                eval_processor = Evaluate(self.model, self.vocab)
                eval_loss = eval_processor.run_eval()
                logger.info(
                    "Evaluation finished, took %.2f seconds, eval loss: %f" %
                    (time.time() - start, eval_loss))
                with self.eval_summary_writer.as_default():
                    tf.summary.scalar(name='eval_loss',
                                      data=eval_loss,
                                      step=iter)
                self.eval_summary_writer.flush()
                if eval_loss < min_eval_loss:
                    logger.info(
                        "This is the best model so far, saving it to disk.")
                    min_eval_loss = eval_loss
                    best_model_path = model_path
                    self.save_model(model_path, eval_loss, iter)
                    cnt = 0
                else:
                    cnt += 1
                    if cnt > config.patience:
                        logger.info(
                            "Eval loss doesn't drop for %d straight times, early stopping.\n"
                            "Best model: %s (Eval loss %f: )" %
                            (config.patience, best_model_path, min_eval_loss))
                        break
                start = time.time()
            elif (iter + 1) % config.save_interval == 0:
                self.save_model(model_path, avg_loss, iter)
            iter += 1
        else:
            logger.info(
                "Training finished, best model: %s, with eval loss %f" %
                (best_model_path, min_eval_loss))
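# The linear, exponential, and sigmoid teacher-forcing decay branches above can be
# checked in isolation. A small sketch with illustrative values for the initial ratio
# and the decay horizon (these numbers are not taken from the config used above):
import math

def forcing_ratio_at(step, s=1.0, k=10000, decay_type='exp', near_zero=0.0001):
    # s: initial forcing ratio, k: iteration at which the ratio reaches ~0
    if step >= k:
        return 0.0
    if decay_type == 'linear':
        return s * (k - step) / k
    if decay_type == 'exp':
        p = pow(near_zero, 1 / k)   # per-step multiplicative decay
        return s * (p ** step)
    if decay_type == 'sig':
        r = math.log((1 / near_zero) - 1) / k
        return s / (1 + pow(math.e, r * (step - k / 2)))
    raise ValueError('Unrecognized decay_type: ' + decay_type)

for step in (0, 2500, 5000, 7500, 9999):
    print(step, round(forcing_ratio_at(step, decay_type='sig'), 4))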
Example #16
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root, 'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path, self.vocab, mode='decode',
                               batch_size=config.beam_size, single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)


    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(output_ids, self.vocab,
                                                 (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec'%(counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)


    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        if config.use_maxpool_init_ctx:
            c_t_0 = max_encoder_output

        dec_h, dec_c = s_t_0 # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                      log_probs=[0.0],
                      state=(dec_h[0], dec_c[0]),
                      context = c_t_0[0],
                      coverage=(coverage_t_0[0] if config.is_coverage else None))
                 for _ in range(config.beam_size)]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h =[]
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0), torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1)

            topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size * 2):  # for each of the top 2*beam_size hypotheses
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
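# The beam search above assumes a lightweight Beam hypothesis object exposing extend,
# latest_token and avg_log_prob. A plausible sketch of that interface (not necessarily
# the exact class used with this code):
class BeamSketch(object):
    def __init__(self, tokens, log_probs, state, context, coverage):
        self.tokens = tokens          # token ids decoded so far, starting with START
        self.log_probs = log_probs    # per-token log probabilities
        self.state = state            # (h, c) decoder state after the last token
        self.context = context        # last attention context vector
        self.coverage = coverage      # coverage vector, or None

    def extend(self, token, log_prob, state, context, coverage):
        # Return a new hypothesis extended by one decoded token.
        return BeamSketch(tokens=self.tokens + [token],
                          log_probs=self.log_probs + [log_prob],
                          state=state, context=context, coverage=coverage)

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        # Length-normalised score used by sort_beams.
        return sum(self.log_probs) / len(self.tokens)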
Example #17
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def save_best_so_far(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        best_dir = os.path.join(self.model_dir, 'best')
        model_save_path = os.path.join(best_dir, 'model_%d_%d' % (iter, int(time.time())))
        try:
            os.makedirs(best_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        torch.save(state, model_save_path)

    def _change_lr(self, lr, it):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        lr_sum = tf.Summary()
        lr_sum.value.add(tag='lr', simple_value=lr)
        self.summary_writer.add_summary(lr_sum, it)

    def change_lr(self, it):
        it = it + 1
        new_lr = config.base * min(it ** -0.5, it * (config.warmup ** -1.5))
        self._change_lr(new_lr, it-1)

    def change_lr_lin(self, it):
        it_a = (config.anneal_steps - it) / config.anneal_steps
        lr_diff = config.start_lr - config.end_lr
        new_lr = config.end_lr  + lr_diff * max(0, it_a)
        self._change_lr(new_lr, it)

    def init_model(self):
        pass
#        for param in self.model.parameters():
#            init.uniform_(param, -0.2, 0.2)
#            init.normal_(param, 0, 0.15)
#            init.xavier_uniform


    def setup_optimizer(self):
        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr



#        self.optimizer = optim.Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
#        self.optimizer = optim.RMSprop(params, momentum=0.9)
        self.optimizer = optim.Adam(params) #, betas=(0.9, 0.98)) #, betas = (-0.5, 0.999))
#        self.optimizer = optim.Adamax(params)
#        self.optimizer = optim.Adam(params, lr=0.001)
#        self.optimizer = optim.SGD(params, lr=0.1, momentum=0.9, nesterov=True)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)
        print(self.model)

        self.init_model()
        self.setup_optimizer()

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if (not (config.is_coverage or config.scratchpad)) or config.load_optimizer_override:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch, it):
#        self.change_lr(it)

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)


        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            if config.scratchpad:
                final_dist, s_t_1, _, attn_dist, p_gen, encoder_outputs = \
                    self.model.decoder(
                        y_t_1, s_t_1, encoder_outputs, encoder_feature, \
                        enc_padding_mask, c_t_1, extra_zeros, \
                        enc_batch_extend_vocab, coverage, di \
                    )
            else:
                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = \
                    self.model.decoder(
                        y_t_1, s_t_1, encoder_outputs, encoder_feature, \
                        enc_padding_mask, c_t_1, extra_zeros, \
                        enc_batch_extend_vocab, coverage, di \
                    )

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        if it % config.update_every == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        best_loss = 20
        best_iter = 0
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, iter)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

#            is_new_best = (running_avg_loss < best_loss) and (iter - best_iter >= 100)
#            best_loss = min(running_avg_loss, best_loss)

            if iter % 20 == 0:
                self.summary_writer.flush()
            print_interval = 100
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (iter, print_interval,
                                                                           time.time() - start, loss))
                start = time.time()
            if iter % 2500 == 0:
                self.save_model(running_avg_loss, iter)
            if loss < 2.0:
                self.save_best_so_far(running_avg_loss, iter)
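# train_one_batch above only steps the optimizer every config.update_every batches, so
# gradients accumulate in between. The generic pattern, shown on a toy model (the names
# and numbers here are illustrative, not part of this codebase):
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters())
update_every = 4   # effective batch size = update_every * per-batch size

for it in range(1, 101):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean()
    loss.backward()                  # gradients add up across iterations
    if it % update_every == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()             # one update per accumulated group
        optimizer.zero_grad()        # start accumulating the next group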
Example #18
0
class Summarizer(object):
    ''' Load with trained model and handle the beam search '''
    def __init__(self, opt):
        '''
        opt needs to contain:
            - model_file_path
            - n_best
            - max_token_seq_len
        '''
        self.opt = opt
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        print("Max article len", config.max_article_len)
        model = Model(config.vocab_size, config.vocab_size,
                      config.max_article_len)

        checkpoint = torch.load(opt["model_file_path"],
                                map_location=lambda storage, location: storage)

        # model saved as:
        # state = {
        #     'iter': iter,
        #     'transformer_state_dict': self.model.state_dict(),
        #     'optimizer': self.optimizer.state_dict(),
        #     'current_loss': running_avg_loss
        # }

        model.load_state_dict(checkpoint['transformer_state_dict'])

        print('[Info] Trained model state loaded.')

        #model.word_prob_prj = nn.LogSoftmax(dim=1)

        self.model = model.to(self.device)

        self.model.eval()

        self._decode_dir = os.path.join(
            config.log_root,
            'decode_%s' % (opt["model_file_path"].split("/")[-1]))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.batch_size,
                               single_pass=True)

        time.sleep(15)

        print('[Info] Summarizer object created.')

    def summarize_batch(self, src_seq, src_pos):
        ''' Translation work in one batch '''
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                dec_partial_pos = torch.arange(1,
                                               len_dec_seq + 1,
                                               dtype=torch.long,
                                               device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                    n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                             n_active_inst, n_bm):
                dec_output, *_ = self.model.transformer.decoder(
                    dec_seq, dec_pos, src_seq, enc_output)

                # print("dec_output (line 136)", dec_output.size())
                # print("batch size", config.batch_size)
                # print("decoder output", dec_output[:, -1, :10])

                dec_output = dec_output[:, -1, :]  # pick the last step: (batch * beam) x d_h
                logits = self.model.transformer.tgt_word_prj(dec_output)

                # print("logits size", logits.size())
                # print("logits", logits[:, :10])

                word_prob = logits  #F.softmax(logits, dim=1)

                # print(word_prob[:, :10])
                # print("word_prob", torch.max(word_prob, 1))

                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)

            # print("first dec_seq", dec_seq)
            # print("first dec_pos", dec_pos)

            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                     n_active_inst, n_bm)
            print(word_prob)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            # print("src_seq", src_seq.size())
            # print("src_pos", src_pos.size())

            #print("src_seq", src_seq[:, :20])
            #print("src_pos", src_pos[:, :20])

            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.model.transformer.encoder(src_seq, src_pos)

            #-- Repeat data for beam search
            #n_bm = config.beam_size
            n_bm = 1
            n_inst, len_s, d_h = src_enc.size()

            # print("src_enc_shape", src_enc.size())

            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s,
                                                      d_h)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=self.device) for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.opt["max_token_seq_len"] + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc,
                    inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map,
                    active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(
            inst_dec_beams, self.opt["n_best"])

        print("-" * 50)

        return batch_hyp, batch_scores

    def get_pos_data(self, padding_masks):
        batch_size, seq_len = padding_masks.shape

        pos_data = [[
            j + 1 if padding_masks[i][j] == 1 else 0 for j in range(seq_len)
        ] for i in range(batch_size)]

        pos_data = torch.tensor(pos_data, dtype=torch.long)

        if use_cuda:
            pos_data = pos_data.cuda()

        return pos_data

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        #print(batch.enc_batch)

        while batch is not None:

            # Run beam search to get best Hypothesis
            enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = get_input_from_batch(
                batch, use_cuda)

            enc_batch = enc_batch[0:1, :]
            enc_padding_mask = enc_padding_mask[0:1, :]

            in_seq = enc_batch
            in_pos = self.get_pos_data(enc_padding_mask)
            #print("enc_padding_mask", enc_padding_mask)

            #print("Summarizing one batch...")

            batch_hyp, batch_scores = self.summarize_batch(in_seq, in_pos)

            # Extract the output ids from the hypothesis and convert back to words
            output_words = np.array(batch_hyp)
            output_words = output_words[:, 0, 1:]

            for i, out_sent in enumerate(output_words):

                decoded_words = data.outputids2words(
                    out_sent, self.vocab,
                    (batch.art_oovs[0] if config.pointer_gen else None))

                original_abstract_sents = batch.original_abstracts_sents[i]

                write_for_rouge(original_abstract_sents, decoded_words,
                                counter, self._rouge_ref_dir,
                                self._rouge_dec_dir)
                counter += 1

            if counter % 1 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
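# get_pos_data above builds 1-based position indices with a nested Python comprehension;
# a vectorised equivalent (a sketch, assuming padding_masks is a 0/1 tensor of shape
# [batch_size, seq_len]) is:
import torch

def get_pos_data_vectorised(padding_masks):
    batch_size, seq_len = padding_masks.shape
    positions = torch.arange(1, seq_len + 1, dtype=torch.long,
                             device=padding_masks.device)
    # real tokens get positions 1..seq_len, padding positions get 0
    return positions.unsqueeze(0).expand(batch_size, seq_len) * padding_masks.long()

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]])
print(get_pos_data_vectorised(mask))
# tensor([[1, 2, 3, 0, 0],
#         [1, 2, 0, 0, 0]])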
Example #19
0
class Evaluate(object):
    def __init__(self, model_file_path):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=True)
        # time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch, target_ids_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            # target = target_batch[:, di]
            target_ids = target_ids_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target_ids.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 10 == 0:
                self.summary_writer.flush()
            print_interval = 1
            if iter % print_interval == 0:
                print('iters = %d, time = %s , loss: %f' %
                      (iter, time_since(start), running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()
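# The coverage penalty used above sums min(attention, coverage) over source positions,
# so re-attending to already-covered positions is penalised. A tiny worked example:
import torch

attn_dist = torch.tensor([[0.7, 0.2, 0.1],
                          [0.1, 0.1, 0.8]])
coverage = torch.tensor([[0.5, 0.9, 0.0],
                         [0.0, 0.2, 1.5]])

step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
print(step_coverage_loss)   # tensor([0.7000, 0.9000])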
Example #20
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path, self.vocab, mode='train',
                               batch_size=config.batch_size, single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        loss = 0
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        # print("Encoding lengths", enc_lens)  #(1,8)
        # print("Encoding batch", enc_batch.size()) #(8, 400)
        # print("c_t_1 is ", c_t_1.size()) # (8, 512)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        # print("Max Decoding lengths", max_dec_len)
        # print("Decoding lengths", dec_lens_var)
        # print("Decoding vectors", dec_batch[0])
        self.optimizer.zero_grad()

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        # print("encoder_outputs", encoder_outputs.size()) # (8, 400, 512)
        # print("encoder_feature", encoder_feature.size()) # (3200, 512)
        # print("encoder_hidden", encoder_hidden[1].size()) # (2, 8, 256)


        s_t_1 = self.model.reduce_state(encoder_hidden) # (1, 8, 256)
        # print("After reduce_state, the hidden state s_t_1 is", s_t_1[0].size()) 

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)): 
            # print("Decoder step = ", di)
            y_t_1 = dec_batch[:, di]  # Teacher forcing  #  the dith word of all the examples/targets in a batch of shape (8,)
            final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
                                                        encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                                                        extra_zeros, enc_batch_extend_vocab,
                                                                           coverage, di)

            # print("for di=", di, " final_dist", final_dist.size()) # (8, 50009)
            # print("for di=", di, " s_t_1", encoder_feature.size()) # (3200, 512)
            # print("for di=", di, " c_t_1", c_t_1.size()) # (8, 512)
            # print("for di=", di, " attn_dist", attn_dist.size()) # (8, 400)
            # print("for di=", di, " p_gen", p_gen.size()) # (8, 1)
            # print("for di=", di, " next_coverage", next_coverage.size())

            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage
                
            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses/dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()
        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm) # gradient clipping
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        print("Initial running average loss is", running_avg_loss)
        start = time.time()
        while iter < n_iters:
            print("Iteration: ", iter)
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1 #1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' % (iter, print_interval,
                                                                           time.time() - start, loss))
                start = time.time()
            if iter % 1 == 0:#5000 == 0:
                self.save_model(running_avg_loss, iter)
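# The per-step loss above gathers the gold token's probability from final_dist, takes a
# negative log, and masks padded positions. On toy tensors (a tiny 4-word vocabulary,
# values made up for illustration) the computation looks like this:
import torch

eps = 1e-12
final_dist = torch.tensor([[0.10, 0.60, 0.20, 0.10],
                           [0.25, 0.25, 0.25, 0.25]])   # already softmax-normalised
target = torch.tensor([1, 3])           # gold token ids at this decoder step
step_mask = torch.tensor([1.0, 0.0])    # the second example is padding at this step

gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
step_loss = -torch.log(gold_probs + eps) * step_mask
print(gold_probs)   # tensor([0.6000, 0.2500])
print(step_loss)    # tensor([0.5108, 0.0000])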
Example #21
0
class Evaluate(object):

    def __init__(self, data_path, opt, batch_size=config.batch_size):

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path, self.vocab, mode='eval', batch_size=batch_size,
                               single_pass=True)

        self.opt = opt
        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = torch.load(os.path.join(config.save_model_path,
                                             self.opt.load_model))
        # load the model weights saved during training
        self.model.load_state_dict(checkpoint['model_dict'])

    def print_original_predicted(self, decoded_sents, ref_sents, article_sents,
                                 loadfile):

        # build the output filename from the checkpoint name
        filename = 'test_' + loadfile.split('.')[0] + '.txt'

        with open(os.path.join('data', filename), 'w') as f:
            for i in range(len(decoded_sents)):
                f.write('article' + article_sents[i] + '\n')
                f.write('reference:' + ref_sents[i] + '\n')
                f.write('decoder:' + decoded_sents[i] + '\n')

    def evaluate_batch(self, print_sents=False):

        self.setup_valid()
        batch = self.batcher.next_batch()

        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)

        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()

        batch_number = 0

        while batch is not None:

            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, \
                extra_zeros, ct_e = get_enc_data(batch)

            with torch.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with torch.no_grad():
                pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e,
                                       extra_zeros, enc_batch_extend_vocab, self.model, start_id, end_id, unk_id)

            for i in range(len(pred_ids)):
                # outputids2words returns a list of words
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])

                if len(decoded_words) < 2:
                    decoded_words = 'xxx'
                else:
                    decoded_words = ' '.join(decoded_words)

                decoded_sents.append(decoded_words)
                summary = batch.original_summarys[i]
                article = batch.original_articles[i]
                ref_sents.append(summary)
                article_sents.append(article)

            batch = self.batcher.next_batch()
            batch_number += 1

            # only evaluate on the first 100 batches
            if batch_number < 100:
                continue
            else:
                break

        load_file = self.opt.load_model

        if print_sents:
            self.print_original_predicted(decoded_sents, ref_sents, article_sents, load_file)

        scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)

        if self.opt.task == 'test':
            print(load_file, 'scores:', scores)
            sys.stdout.flush()
        else:
            rouge_l = scores['rouge-l']['f']
            print(load_file, 'rouge-l:', '%.4f' % rouge_l)
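# Both evaluators above rely on data.outputids2words to map ids back to words, where ids
# at or above the vocabulary size index into the article's OOV list produced by the
# pointer mechanism. A sketch of that convention (assuming a vocab object with size()
# and id2word(); this is an illustration, not the actual helper):
def outputids2words_sketch(id_list, vocab, article_oovs):
    words = []
    for idx in id_list:
        if idx < vocab.size():
            words.append(vocab.id2word(idx))
        else:
            oov_idx = idx - vocab.size()
            if article_oovs is not None and oov_idx < len(article_oovs):
                words.append(article_oovs[oov_idx])
            else:
                words.append('[UNK]')
    return words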
class Main(object):
    def __init__(self):
        self.vocab = Vocab(VOCAB_PATH, VOCAB_SIZE)
        self.batcher = Batcher(TRAIN_DATA_PATH, self.vocab, mode = 'train',batch_size = BATCH_SIZE, single_pass = False)
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        self.model = MyModel().to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LR)


    def beamSearch(self,enc_hid, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab):
        batch_size = len(enc_hid[0])
        beam_idx = torch.LongTensor(list(range(batch_size)))
        beams = [Beam(self.start_id, self.end_id, self.unk_id, (enc_hid[0][i], enc_hid[1][i]), ct_e[i]) for i in range(batch_size)]
        n_rem = batch_size
        sum_exp = None
        prev_s = None

        for t in range(MAX_DEC_STEPS):
            x_t = torch.stack(
                [beam.getTokens() for beam in beams if beam.done == False]  # remaining(rem),beam
            ).contiguous().view(-1)
            x_t = self.model.embeds(x_t)

            dec_h = torch.stack(
                [beam.hid_h for beam in beams if beam.done == False]
            ).contiguous().view(-1, HIDDEN_DIM)
            dec_c = torch.stack(
                [beam.hid_c for beam in beams if beam.done == False]
            ).contiguous().view(-1, HIDDEN_DIM)

            ct_e = torch.stack(
                [beam.context for beam in beams if beam.done == False]
            ).contiguous().view(-1, 2 * HIDDEN_DIM)

            if sum_exp is not None:
                sum_exp = torch.stack(
                    [beam.sum_exp for beam in beams if beam.done == False]
                ).contiguous().view(-1, enc_out.size(1))

            if prev_s is not None:
                prev_s = torch.stack(
                    [beam.prev_s for beam in beams if beam.done == False]
                )
                # try:
                prev_s = prev_s.contiguous().view(-1, t + 1, HIDDEN_DIM)
                # print(prev_s.shape)
                # except:
                #     continue
            s_t = (dec_h, dec_c)
            enc_out_beam = enc_out[beam_idx].view(n_rem, -1).repeat(1, BEAM_SIZE).view(
                -1, enc_out.size(1), enc_out.size(2))
            enc_pad_mask_beam = enc_padding_mask[beam_idx].repeat(1, BEAM_SIZE).view(
                -1, enc_padding_mask.size(1))

            extra_zeros_beam = None
            if extra_zeros is not None:
                extra_zeros_beam = extra_zeros[beam_idx].repeat(1, BEAM_SIZE).view(-1, extra_zeros.size(1))

            enc_extend_vocab_beam = enc_batch_extend_vocab[beam_idx].repeat(1, BEAM_SIZE).view(
                -1, enc_batch_extend_vocab.size(1))
            # print(enc_out_beam.shape)
            # try:
            final_dist, (dec_h, dec_c), ct_e, sum_exp, prev_s = self.model.decoder(
                x_t, s_t, enc_out_beam, enc_pad_mask_beam, ct_e,
                extra_zeros_beam, enc_extend_vocab_beam, sum_exp, prev_s)

            # except:
            #     continue
            # print(prev_s.shape)
            final_dist = final_dist.view(n_rem, BEAM_SIZE, -1)
            dec_h = dec_h.view(n_rem, BEAM_SIZE, -1)
            dec_c = dec_c.view(n_rem, BEAM_SIZE, -1)
            ct_e = ct_e.view(n_rem, BEAM_SIZE, -1)

            if sum_exp is not None:
                sum_exp = sum_exp.view(n_rem, BEAM_SIZE, -1)  # rem, beam, n_seq

            if prev_s is not None:
                prev_s = prev_s.view(n_rem, BEAM_SIZE, -1, HIDDEN_DIM)  # rem, beam, t
                # print("prev_s",prev_s.shape)

            active = []

            for i in range(n_rem):
                b = beam_idx[i].item()
                beam = beams[b]
                if beam.done:
                    continue

                sum_exp_i = prev_s_i = None
                if sum_exp is not None:
                    sum_exp_i = sum_exp[i]
                if prev_s is not None:
                    prev_s_i = prev_s[i]
                beam.advance(final_dist[i], (dec_h[i], dec_c[i]), ct_e[i], sum_exp_i, prev_s_i)
                if beam.done == False:
                    active.append(b)

            # print(len(active))
            if len(active) == 0:
                break

            beam_idx = torch.LongTensor(active)
            n_rem = len(beam_idx)

        predicted_words = []
        for beam in beams:
            predicted_words.append(beam.getBest())
        # print(len(predicted_words[0]))
        return predicted_words

    # build the encoder-side tensors for a batch
    def getEncData(self,batch):
        batch_size = len(batch.enc_lens)
        enc_batch = torch.from_numpy(batch.enc_batch).long()
        enc_padding_mask = torch.from_numpy(batch.enc_padding_mask).float()

        enc_lens = batch.enc_lens

        ct_e = torch.zeros(batch_size, 2 * HIDDEN_DIM)

        enc_batch = enc_batch.to(DEVICE)
        enc_padding_mask = enc_padding_mask.to(DEVICE)

        ct_e = ct_e.to(DEVICE)

        enc_batch_extend_vocab = None
        if batch.enc_batch_extend_vocab is not None:
            enc_batch_extend_vocab = torch.from_numpy(batch.enc_batch_extend_vocab).long()
            enc_batch_extend_vocab = enc_batch_extend_vocab.to(DEVICE)

        extra_zeros = None
        if batch.max_art_oovs > 0:
            extra_zeros = torch.zeros(batch_size, batch.max_art_oovs)
            extra_zeros = extra_zeros.to(DEVICE)

        return enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e

    # build the decoder-side tensors for a batch
    def getDecData(self,batch):
        dec_batch = torch.from_numpy(batch.dec_batch).long()
        dec_lens = batch.dec_lens
        max_dec_len = np.max(dec_lens)
        dec_lens = torch.from_numpy(batch.dec_lens).float()

        target_batch = torch.from_numpy(batch.target_batch).long()

        dec_batch = dec_batch.to(DEVICE)
        dec_lens = dec_lens.to(DEVICE)
        target_batch = target_batch.to(DEVICE)

        return dec_batch, max_dec_len, dec_lens, target_batch

    # Maximum-likelihood (teacher-forcing) training step
    def trainMLEStep(self, batch):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = self.getEncData(batch)
        # print(enc_lens)
        enc_batch = self.model.embeds(enc_batch)
        # Feed the embedded input straight into the encoder
        enc_output,enc_hidden = self.model.encoder(enc_batch,enc_lens)
        # print(enc_output.shape)
        dec_batch, max_dec_len, dec_lens, target_batch = self.getDecData(batch)
        total_loss = 0
        s_t = (enc_hidden[0], enc_hidden[1])
        x_t = torch.LongTensor(len(enc_output)).fill_(self.start_id).to(DEVICE)
        # print(x_t.shape)
        prev_s = None
        sum_exp = None
        # Decode one token per step, up to the maximum number of decoding steps
        for t in range(min(max_dec_len, MAX_DEC_STEPS)):
            # print(max_dec_len)
            choice = (torch.rand(len(enc_output)) > 0.25).long().to(DEVICE)
            # With probability 0.75 feed the ground-truth token, otherwise reuse the previously generated token (scheduled sampling)
            x_t = choice * dec_batch[:, t] + (1 - choice) * x_t
            x_t = self.model.embeds(x_t)
            # print(x_t.shape)
            final_dist, s_t, ct_e, sum_exp, prev_s = self.model.decoder(x_t, s_t, enc_output, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_exp, prev_s)
            # print(prev_s.shape)
            target = target_batch[:, t]
            # print(target_batch.shape)
            log_probs = torch.log(final_dist + EPS)
            step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=self.pad_id)
            total_loss = total_loss + step_loss
            x_t = torch.multinomial(final_dist, 1).squeeze()
            is_oov = (x_t >= VOCAB_SIZE).long()
            x_t = (1 - is_oov) * x_t.detach() + (is_oov) * self.unk_id

        batch_avg_loss = total_loss / dec_lens
        loss = torch.mean(batch_avg_loss)
        return loss
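
# Editor's sketch (not part of the original snippet): the line
# `x_t = choice * dec_batch[:, t] + (1 - choice) * x_t` in trainMLEStep mixes teacher forcing
# with the model's own previous sample (scheduled sampling). A tiny standalone demo of that
# masking trick with made-up token ids:
import torch

torch.manual_seed(0)
ground_truth = torch.tensor([11, 12, 13, 14])      # would be dec_batch[:, t]
prev_generated = torch.tensor([21, 22, 23, 24])    # would be the x_t sampled at the previous step
choice = (torch.rand(4) > 0.25).long()             # 1 -> use ground truth (~75%), 0 -> reuse own sample
mixed = choice * ground_truth + (1 - choice) * prev_generated
print(choice.tolist(), mixed.tolist())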

    # Reinforcement-learning (self-critical) rollout
    def trainRLStep(self,enc_output, enc_hidden, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, article_oovs, type):


        s_t = enc_hidden
        x_t = torch.LongTensor(len(enc_output)).fill_(self.start_id).to(DEVICE)
        prev_s = None
        sum_exp = None
        inds = []
        decoder_padding_mask = []
        log_probs = []
        mask = torch.LongTensor(len(enc_output)).fill_(1).to(DEVICE)

        for t in range(MAX_DEC_STEPS):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_exp, prev_s = self.model.decoder(x_t, s_t, enc_output, enc_padding_mask, ct_e,
                                                                             extra_zeros, enc_batch_extend_vocab,
                                                                             sum_exp, prev_s)
            if type == "sample":
                # Sample a token from the output distribution
                multi_dist = Categorical(probs)
                # print(multi_dist)
                x_t = multi_dist.sample()
                # print(x_t.shape)
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                #greedy sample
                _, x_t = torch.max(probs, dim=1)
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = torch.zeros(len(enc_output)).to(DEVICE)
            mask_t[mask == 1] = 1
            mask[(mask == 1) + (x_t == self.end_id) == 2] = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= VOCAB_SIZE).long()
            # If the sampled id falls outside the fixed vocabulary (an OOV), replace it with UNK
            x_t = (1 - is_oov) * x_t + (is_oov) * self.unk_id

        inds = torch.stack(inds, dim=1)
        decoder_padding_mask = torch.stack(decoder_padding_mask, dim=1)
        if type == "sample":
            log_probs = torch.stack(log_probs, dim=1)
            # Zero out log-probs at padded positions
            log_probs = log_probs * decoder_padding_mask
            lens = torch.sum(decoder_padding_mask, dim=1)
            # log p term of eq. 15: length-normalized sentence log-probability
            log_probs = torch.sum(log_probs,dim=1) / lens
            # print(log_prob.shape)
        decoded_strs = []
        # Convert output ids back into words
        for i in range(len(enc_output)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = data.outputids2words(id_list, self.vocab, oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:
                pass  # No STOP token found; keep the full sequence
            if len(S) < 2:
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)

        return decoded_strs, log_probs
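
# Editor's sketch (not part of the original snippet): the update
# `mask[(mask == 1) + (x_t == self.end_id) == 2] = 0` in trainRLStep adds two masks and tests
# for 2, i.e. a logical AND written for the old uint8 comparison results. The standalone demo
# below uses an explicit AND for the same bookkeeping; end_id and the sampled ids are made up.
# mask_t copies the mask *before* the update, so the step that emits END is still counted.
import torch

end_id = 3
mask = torch.ones(2, dtype=torch.long)             # 1 = END not emitted yet for this example
step_masks = []
for x_t in [torch.tensor([5, 3]), torch.tensor([3, 7]), torch.tensor([9, 9])]:
    mask_t = torch.zeros(2)
    mask_t[mask == 1] = 1                          # positions still generating get weight 1
    mask[(mask == 1) & (x_t == end_id)] = 0        # END emitted at this step -> stop counting afterwards
    step_masks.append(mask_t)
print(torch.stack(step_masks, dim=1))              # tensor([[1., 1., 0.], [1., 0., 0.]])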


    def rewardFunction(self, decoded_sents, original_sents):
        rouge = Rouge()
        scores = rouge.get_scores(decoded_sents, original_sents)
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        rouge_l_f1 = (torch.FloatTensor(rouge_l_f1)).to(DEVICE)
        return rouge_l_f1
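
# Editor's sketch (not part of the original snippet): rewardFunction relies on the `rouge`
# package used above; Rouge().get_scores(hyps, refs) returns one dict per pair with
# "rouge-1"/"rouge-2"/"rouge-l" entries, each holding f/p/r values. Minimal usage with made-up
# sentences, producing the same kind of per-example reward tensor:
from rouge import Rouge
import torch

hyps = ["the cat sat on the mat", "a quick brown fox"]
refs = ["the cat is on the mat", "the quick brown fox jumps"]
scores = Rouge().get_scores(hyps, refs)
reward = torch.tensor([s["rouge-l"]["f"] for s in scores])
print(reward)                                      # shape (2,), one ROUGE-L F1 per example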



    # Evaluate with beam search
    def test(self):
        # time.sleep(5)

        batcher = Batcher(TEST_DATA_PATH, self.vocab, mode='test',batch_size=BATCH_SIZE, single_pass=True)
        batch = batcher.next_batch()
        decoded_sents = []
        ref_sents = []
        article_sents = []
        rouge = Rouge()
        count = 0
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, ct_e = self.getEncData(batch)
            with torch.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with torch.autograd.no_grad():
                pred_ids = self.beamSearch(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab)
                # print(len(pred_ids[0]))
            for i in range(len(pred_ids)):
                # print('t',pred_ids[i])
                decoded_words = data.outputids2words(pred_ids[i], self.vocab, batch.art_oovs[i])
                # print(decoded_words)
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                abstract = batch.original_abstracts[i]
                article = batch.original_articles[i]
                ref_sents.append(abstract)
                article_sents.append(article)
            # print(decoded_sents)
            batch = batcher.next_batch()
            scores = rouge.get_scores(decoded_sents, ref_sents, avg=True)
            # Accumulate ROUGE results
            if count == 1:
                k0_sum = scores[KEYS[0]]
                k1_sum = scores[KEYS[1]]
                k2_sum = scores[KEYS[2]]

            if count > 1:
                k0_sum = dict(Counter(Counter(k0_sum) + Counter(scores[KEYS[0]])))
                k1_sum = dict(Counter(Counter(k1_sum) + Counter(scores[KEYS[1]])))
                k2_sum = dict(Counter(Counter(k2_sum) + Counter(scores[KEYS[2]])))
            if count == 10:
                break

            count += 1


        # print(scores)
        print(KEYS[0], end=' ')
        for k in k0_sum:
            print(k,k0_sum[k] / count,end = ' ')
        print('\n')
        print(KEYS[1],end = ' ')
        for k in k1_sum:
            print(k,k1_sum[k] / count,end = ' ')
        print('\n')
        print(KEYS[2], end=' ')
        for k in k2_sum:
            print(k,k2_sum[k] / count,end = ' ')
        print('\n')
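
# Editor's sketch (not part of the original snippet): in test() above, the accumulation
# `dict(Counter(Counter(a) + Counter(b)))` is a compact way to add two ROUGE score dicts key by
# key (Counter.__add__ sums values per key). A standalone demo with made-up scores:
from collections import Counter

batch1 = {"f": 0.30, "p": 0.28, "r": 0.33}
batch2 = {"f": 0.36, "p": 0.31, "r": 0.40}
running = dict(Counter(batch1) + Counter(batch2))  # per-key sums
# Caveat: Counter addition silently drops keys whose summed value is not positive
print(running)
print({k: v / 2 for k, v in running.items()})      # divide by the number of batches for the average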






    def train(self):
        iter = 1
        count = 0
        total_loss = 0
        total_reward = 0
        while iter <= MAX_ITERATIONS:
            batch = self.batcher.next_batch()

            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = self.getEncData(batch)

            enc_batch = self.model.embeds(enc_batch)  # Get embeddings for encoder input
            enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)
            # print(enc_out.shape)
            # Run the MLE step on the batch (trainMLEStep re-encodes it internally)
            mle_loss = self.trainMLEStep(batch)
            sample_sents, RL_log_probs = self.trainRLStep(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs,"sample")
            with torch.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.trainRLStep(enc_out, enc_hidden, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, batch.art_oovs, "greedy")

            sample_reward = self.rewardFunction(sample_sents, batch.original_abstracts)
            baseline_reward = self.rewardFunction(greedy_sents, batch.original_abstracts)
            # Eq. 15: self-critical policy-gradient loss
            rl_loss = -(sample_reward - baseline_reward) * RL_log_probs
            rl_loss = torch.mean(rl_loss)

            batch_reward = torch.mean(sample_reward).item()
            loss = LAMBDA * mle_loss + (1 - LAMBDA) * rl_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Accumulate batch loss
            total_loss += loss.item()
            # Accumulate batch reward
            total_reward += batch_reward
            count += 1
            iter += 1

            if iter % PRINT_PER_ITER == 0:


                loss_avg = total_loss / count
                reward_avg = total_reward / count
                total_loss = 0
                print("iter:", iter, "loss:", "%.3f" % loss_avg ,"reward:", "%.3f" % reward_avg)
                count = 0


            if iter % TEST_PER_ITER == 0:
                self.test()
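
# Editor's sketch (not part of the original snippet): train() above mixes the two objectives as
# LAMBDA * mle_loss + (1 - LAMBDA) * rl_loss, where
# rl_loss = -(sample_reward - baseline_reward) * log p(sample) is the self-critical policy
# gradient of eq. 15 in https://arxiv.org/pdf/1705.04304.pdf. Minimizing it raises the
# log-probability of samples that beat the greedy baseline and lowers it otherwise. A small
# numeric sketch with made-up values (`lam` stands in for the LAMBDA constant defined elsewhere
# in the original code):
import torch

lam = 0.75
mle_loss = torch.tensor(2.1)
sample_reward = torch.tensor([0.42, 0.18])         # ROUGE-L of multinomial samples
baseline_reward = torch.tensor([0.35, 0.25])       # ROUGE-L of greedy decodes
log_probs = torch.tensor([-1.2, -0.9])             # length-normalized sentence log-probabilities
rl_loss = torch.mean(-(sample_reward - baseline_reward) * log_probs)
total = lam * mle_loss + (1 - lam) * rl_loss
print(rl_loss.item(), total.item())
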
Пример #23
0
class Train(object):
    def __init__(self, opt, vocab, logger, writer, train_num):
        self.vocab = vocab
        self.train_batcher = Batcher(config.train_data_path,
                                     self.vocab,
                                     mode='train',
                                     batch_size=config.batch_size,
                                     single_pass=False)
        self.test_batcher = Batcher(config.test_data_path,
                                    self.vocab,
                                    mode='eval',
                                    batch_size=config.batch_size,
                                    single_pass=True)
        self.opt = opt
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        self.logger = logger
        self.writer = writer
        self.train_num = train_num
        time.sleep(5)

    def save_model(self, iter, loss, r_loss):
        if not os.path.exists(config.save_model_path):
            os.makedirs(config.save_model_path)
        file_path = "/%07d_%.2f_%.2f.tar" % (iter, loss, r_loss)
        save_path = config.save_model_path + '/%s' % (self.opt.word_emb_type)
        if not os.path.isdir(save_path): os.mkdir(save_path)
        save_path = save_path + file_path
        T.save(
            {
                "iter": iter + 1,
                "model_dict": self.model.state_dict(),
                "trainer_dict": self.trainer.state_dict()
            }, save_path)
        return file_path

    def setup_train(self):
        self.model = Model(self.opt.pre_train_emb, self.opt.word_emb_type,
                           self.vocab)
        self.logger.info(str(self.model))
        self.model = get_cuda(self.model)
        device = T.device(
            "cuda" if T.cuda.is_available() else "cpu")  # PyTorch v0.4.0
        if self.opt.multi_device:
            if T.cuda.device_count() > 1:
                #                 print("Let's use", T.cuda.device_count(), "GPUs!")
                self.logger.info("Let's use " + str(T.cuda.device_count()) +
                                 " GPUs!")
                self.model = nn.DataParallel(
                    self.model, list(range(T.cuda.device_count()))).cuda()

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        self.model.to(device)
        #         self.model.eval()

        self.trainer = T.optim.Adam(self.model.parameters(), lr=config.lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = config.save_model_path + self.opt.load_model
            print(load_model_path)
            #             print('xxxx')
            checkpoint = T.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            #             print("Loaded model at " + load_model_path)
            self.logger.info("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(),
                                        lr=self.opt.new_lr)
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                        extra_zeros, enc_batch_extend_vocab, enc_key_batch,
                        enc_key_lens, batch):
        ''' Calculate Negative Log Likelihood Loss for the given batch. In order to reduce exposure bias,
                pass the previously generated token as input with a probability of 0.25 instead of the ground truth label
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        '''
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(
            batch)  # Get input and target batches for training the decoder
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])  # Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  # Input to the decoder
        prev_s = None  # Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        try:
            #             print('-----------------')
            for t in range(min(max_dec_len, config.max_dec_steps)):
                #                 print('decoder time step %s'%t)
                use_gound_truth = get_cuda(
                    (T.rand(len(enc_out)) > config.gound_truth_prob)
                ).long(
                )  # Probabilities indicating whether to use ground truth labels instead of previous decoded tokens
                x_t = use_gound_truth * dec_batch[:, t] + (
                    1 - use_gound_truth
                ) * x_t  # Select decoder input based on use_ground_truth probabilities
                #                 print('prev dec word',x_t)
                #                 words = [self.vocab.id2word(id) for id in x_t]
                #                 print('words',words)
                x_t = self.model.embeds(x_t)
                final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                    x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                    enc_batch_extend_vocab, sum_temporal_srcs, prev_s,
                    enc_key_batch, enc_key_lens)
                # print('final_dist',final_dist,final_dist.shape);
                #                 print('final_dist',final_dist.shape);
                target = target_batch[:, t]
                log_probs = T.log(final_dist + config.eps)
                step_loss = F.nll_loss(log_probs,
                                       target,
                                       reduction="none",
                                       ignore_index=self.pad_id)
                step_losses.append(step_loss)
                x_t = T.multinomial(final_dist, 1).squeeze(
                )  # Sample words from final distribution which can be used as input in next time step
                #                 print(config.vocab_size)
                #                 print('x_t',x_t.shape);
                is_oov = (x_t >= config.vocab_size).long(
                )  # Mask indicating whether sampled word is OOV
                x_t = (1 - is_oov) * x_t.detach() + (
                    is_oov) * self.unk_id  # Replace OOVs with [UNK] token
                #                 print('x_t_22',x_t.shape)
                #                 print('finish inner loop'); print('-------------------------------------------------\n')

        except KeyboardInterrupt as e:
            self.logger.error('xxxxxxxxxxx')
            traceback = sys.exc_info()[2]
            self.logger.error(sys.exc_info())
            self.logger.error(traceback.tb_lineno)
            self.logger.error(e)
            #             self.logger.error(final_dist)
            self.logger.error('xxxxxxxxxxx')
            #             print(step_loss)

        losses = T.sum(
            T.stack(step_losses, 1), 1
        )  # unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens  # Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)  # Average batch loss
        return mle_loss
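
# Editor's sketch (not part of the original snippet): train_batch_MLE computes a per-token loss
# with F.nll_loss(..., reduction="none", ignore_index=pad_id), so padded target positions
# contribute exactly zero before the sum is normalized by the true decoder lengths. A minimal
# standalone demo (pad_id, the 5-word vocabulary and the targets are made up):
import torch
import torch.nn.functional as F

pad_id = 0
log_probs = torch.log_softmax(torch.randn(3, 5), dim=1)   # one decoder step, (batch, vocab)
target = torch.tensor([2, 4, pad_id])                     # third example is already padding
step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=pad_id)
print(step_loss)                                          # per-example losses; the third entry is 0.0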

    def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                       extra_zeros, enc_batch_extend_vocab, review_oovs,
                       enc_key_batch, enc_key_lens, greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param review_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Log probabilities of sampled words
        '''
        s_t = enc_hidden  # Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  # Input to the decoder
        prev_s = None  # Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []  # Stores sampled indices for each time step
        decoder_padding_mask = []  # Stores padding masks of generated samples
        log_probs = []  # Stores log probabilities of generated samples
        mask = get_cuda(
            T.LongTensor(len(enc_out)).fill_(1)
        )  # Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise

        for t in range(config.max_dec_steps):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s,
                enc_key_batch, enc_key_lens)

            if greedy is False:
                multi_dist = Categorical(probs)
                x_t = multi_dist.sample()  # perform multinomial sampling
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                _, x_t = T.max(probs, dim=1)  # perform greedy sampling
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = get_cuda(T.zeros(
                len(enc_out)))  # Padding mask of batch for current time step
            mask_t[
                mask ==
                1] = 1  # If [STOP] is not encountered till previous time step, mask_t = 1 else mask_t = 0
            mask[
                (mask == 1) + (x_t == self.end_id) ==
                2] = 0  # If [STOP] is not encountered till previous time step and current word is [STOP], make mask = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= config.vocab_size
                      ).long()  # Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t + (
                is_oov) * self.unk_id  # Replace OOVs with [UNK] token

        inds = T.stack(inds, dim=1)
        decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
        if greedy is False:  # If multinomial sampling, compute log probabilities of sampled words
            log_probs = T.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask  # Not considering sampled words with padding mask = 0
            lens = T.sum(decoder_padding_mask,
                         dim=1)  # Length of sampled sentence
            log_probs = T.sum(log_probs, dim=1) / lens  # (bs,) length-normalized log probability of each sentence
        decoded_strs = []
        for i in range(len(enc_out)):
            id_list = inds[i].cpu().numpy()
            oovs = review_oovs[i]
            S = data.outputids2words(
                id_list, self.vocab,
                oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:
                pass  # No STOP token found; keep the full sequence
            if len(S) < 2:  # If the sentence has fewer than 2 words, replace it with "xxx"; avoids sentences like "." which throw an error while calculating ROUGE
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)
            #         print(log_probs)
        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            #             print("Rouge failed for multi sentence evaluation.. Finding exact pair")
            self.logger.info(
                "Rouge failed for multi sentence evaluation.. Finding exact pair"
            )
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    #                     print("Error occured at:")
                    #                     print("decoded_sents:", decoded_sents[i])
                    #                     print("original_sents:", original_sents[i])
                    self.logger.info("Error occured at:")
                    self.logger.info("decoded_sents:", decoded_sents[i])
                    self.logger.info("original_sents:", original_sents[i])
                    score = [{"rouge-l": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        avg_rouge_l_f1 = sum(rouge_l_f1) / len(rouge_l_f1)
        rouge_l_f1 = get_cuda(T.FloatTensor(rouge_l_f1))
        return rouge_l_f1, scores, avg_rouge_l_f1

    # def write_to_file(self, decoded, max, original, sample_r, baseline_r, iter):
    #     with open("temp.txt", "w") as f:
    #         f.write("iter:"+str(iter)+"\n")
    #         for i in range(len(original)):
    #             f.write("dec: "+decoded[i]+"\n")
    #             f.write("max: "+max[i]+"\n")
    #             f.write("org: "+original[i]+"\n")
    #             f.write("Sample_R: %.4f, Baseline_R: %.4f\n\n"%(sample_r[i].item(), baseline_r[i].item()))

    def train_one_batch(self, batch, test_batch, iter):
        ans_list, batch_scores = None, None
        # Train
        #         enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)

        enc_batch, enc_lens, enc_padding_mask, \
        enc_key_batch, enc_key_lens, _,\
        enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)
        #         print('enc_lens',enc_lens.shape)
        enc_batch = self.model.embeds(
            enc_batch)  # Get embeddings for encoder input
        enc_key_batch = self.model.embeds(
            enc_key_batch)  # Get key embeddings for encoder input
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        # Test
        #         enc_batch2, enc_lens2, enc_padding_mask2, enc_batch_extend_vocab2, extra_zeros2, context2 = get_enc_data(test_batch)
        enc_batch2, enc_lens2, enc_padding_mask2, \
        enc_key_batch2, enc_key_lens2, _,\
        enc_batch_extend_vocab2, extra_zeros2, context2 = get_enc_data(test_batch)

        # Disable gradient tracking so these tensors are not recorded in the computation graph
        with T.autograd.no_grad():
            enc_batch2 = self.model.embeds(enc_batch2)
            enc_key_batch2 = self.model.embeds(enc_key_batch2)
            enc_out2, enc_hidden2 = self.model.encoder(enc_batch2, enc_lens2)
        # -------------------------------Summarization-----------------------
        if self.opt.train_mle == True:  # perform MLE training
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden,
                                            enc_padding_mask, context,
                                            extra_zeros,
                                            enc_batch_extend_vocab,
                                            enc_key_batch, enc_key_lens, batch)
            mle_loss_2 = self.train_batch_MLE(enc_out2, enc_hidden2,
                                              enc_padding_mask2, context2,
                                              extra_zeros2,
                                              enc_batch_extend_vocab2,
                                              enc_key_batch2, enc_key_lens2,
                                              test_batch)
        else:
            mle_loss = get_cuda(T.FloatTensor([0]))
            mle_loss_2 = get_cuda(T.FloatTensor([0]))
        # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == True:  # perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(
                enc_out,
                enc_hidden,
                enc_padding_mask,
                context,
                extra_zeros,
                enc_batch_extend_vocab,
                batch.rev_oovs,
                enc_key_batch,
                enc_key_lens,
                greedy=False)
            sample_sents2, RL_log_probs2 = self.train_batch_RL(
                enc_out2,
                enc_hidden2,
                enc_padding_mask2,
                context2,
                extra_zeros2,
                enc_batch_extend_vocab2,
                test_batch.rev_oovs,
                enc_key_batch2,
                enc_key_lens2,
                greedy=False)
            with T.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.train_batch_RL(enc_out,
                                                      enc_hidden,
                                                      enc_padding_mask,
                                                      context,
                                                      extra_zeros,
                                                      enc_batch_extend_vocab,
                                                      batch.rev_oovs,
                                                      enc_key_batch,
                                                      enc_key_lens,
                                                      greedy=True)

            sample_reward, _, _ = self.reward_function(sample_sents,
                                                       batch.original_summarys)
            baseline_reward, _, _ = self.reward_function(
                greedy_sents, batch.original_summarys)
            # if iter%200 == 0:
            #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
            rl_loss = -(
                sample_reward - baseline_reward
            ) * RL_log_probs  # Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = T.mean(rl_loss)

            batch_reward = T.mean(sample_reward).item()
            self.writer.add_scalar('RL_Train/rl_loss', rl_loss, iter)
        else:
            rl_loss = get_cuda(T.FloatTensor([0]))
            batch_reward = 0
        # ------------------------------------------------------------------------------------
        #         if opt.train_mle == True:
        self.trainer.zero_grad()
        (self.opt.mle_weight * mle_loss +
         self.opt.rl_weight * rl_loss).backward()
        self.trainer.step()
        #-----------------------Summarization----------------------------------------------------
        if iter % 1000 == 0:
            with T.autograd.no_grad():
                train_rouge_l_f = self.calc_avg_rouge_result(
                    iter, batch, 'Train', enc_hidden, enc_out,
                    enc_padding_mask, context, extra_zeros,
                    enc_batch_extend_vocab, enc_key_batch, enc_key_lens)
                test_rouge_l_f = self.calc_avg_rouge_result(
                    iter, test_batch, 'Test', enc_hidden2, enc_out2,
                    enc_padding_mask2, context2, extra_zeros2,
                    enc_batch_extend_vocab2, enc_key_batch2, enc_key_lens2)
                self.writer.add_scalars(
                    'Compare/rouge-l-f', {
                        'train_rouge_l_f': train_rouge_l_f,
                        'test_rouge_l_f': test_rouge_l_f
                    }, iter)
                self.logger.info(
                    'iter: %s train_rouge_l_f: %.3f test_rouge_l_f: %.3f \n' %
                    (iter, train_rouge_l_f, test_rouge_l_f))

        return mle_loss.item(), mle_loss_2.item(), batch_reward

    def calc_avg_rouge_result(self, iter, batch, mode, enc_hidden, enc_out,
                              enc_padding_mask, context, extra_zeros,
                              enc_batch_extend_vocab, enc_key_batch,
                              enc_key_lens):
        pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, context,
                               extra_zeros, enc_batch_extend_vocab,
                               enc_key_batch, enc_key_lens, self.model,
                               self.start_id, self.end_id, self.unk_id)

        article_sents, decoded_sents, keywords_list, \
        ref_sents, long_seq_index = prepare_result(data, self.vocab, batch, mode, pred_ids)

        rouge_l = write_rouge(self.writer,iter,mode,article_sents, decoded_sents, \
                    keywords_list, ref_sents, long_seq_index)

        write_bleu(self.writer,iter, mode, article_sents, decoded_sents, \
                   keywords_list, ref_sents, long_seq_index)

        write_group(self.writer,iter,mode,article_sents, decoded_sents,\
                    keywords_list, ref_sents, long_seq_index)

        return rouge_l

    # def get_best_res_score(self, results, scores):
    #     max_score = float(0)
    #     _id = 0
    #     for idx in range(len(results)):
    #         re_matchData = re.compile(r'\-?\d{1,10}\.?\d{1,10}')
    #         data = re.findall(re_matchData, str(scores[idx]))
    #         score = sum([float(d) for d in data])
    #         if score > max_score:
    #             _id = idx
    #     return results[_id], scores[_id]

    def get_lr(self):
        for param_group in self.trainer.param_groups:
            return param_group['lr']

    def get_weight_decay(self):
        for param_group in self.trainer.param_groups:
            #             print(param_group)
            return param_group['weight_decay']

    def trainIters(self):
        final_file_path = None
        iter = self.setup_train()
        epoch = 0
        count = test_mle_total = train_mle_total = r_total = 0
        self.logger.info(u'------Training START--------')
        test_batch = self.test_batcher.next_batch()
        #         while iter <= config.max_iterations:
        while epoch <= config.max_epochs:
            train_batch = self.train_batcher.next_batch()
            try:
                train_mle_loss, test_mle_loss, r = self.train_one_batch(
                    train_batch, test_batch, iter)

                self.writer.add_scalars(
                    'Compare/mle_loss', {
                        'train_mle_loss': train_mle_loss,
                        'test_mle_loss': test_mle_loss
                    }, iter)

            # break
            except KeyboardInterrupt:
                self.logger.info(
                    "-------------------Keyboard Interrupt------------------")
                exit(0)
            except KeyError as e:
                self.logger.info(
                    "-------------------Ignore error------------------\n%s\n" %
                    e)
                print("Please load final_file_path : %s" % final_file_path)
                traceback = sys.exc_info()[2]
                print(sys.exc_info())
                print(traceback.tb_lineno)
                print(e)
                break
            # if opt.train_mle == False: break
            train_mle_total += train_mle_loss
            r_total += r
            test_mle_total += test_mle_loss
            count += 1
            iter += 1

            if iter % 1000 == 0:
                train_mle_avg = train_mle_total / count
                r_avg = r_total / count
                test_mle_avg = test_mle_total / count
                epoch = int((iter * config.batch_size) / self.train_num) + 1
                # self.logger.info('epoch: %s iter: %s train_mle_loss: %.3f test_mle_loss: %.3f reward: %.3f \n' % (epoch, iter, train_mle_avg, test_mle_avg, r_avg))

                count = test_mle_total = train_mle_total = r_total = 0
                self.writer.add_scalar('RL_Train/r_avg', r_avg, iter)

                self.writer.add_scalars('Compare/mle_avg_loss', {
                    'train_mle_avg': train_mle_avg,
                    'test_mle_avg': test_mle_avg
                }, iter)
            # break
            if iter % 5000 == 0:
                final_file_path = self.save_model(iter, test_mle_avg, r_avg)
Пример #25
0
class Train(object):
    def __init__(self, opt):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        self.opt = opt
        self.start_id = self.vocab.word2id(data.START_DECODING)
        self.end_id = self.vocab.word2id(data.STOP_DECODING)
        self.pad_id = self.vocab.word2id(data.PAD_TOKEN)
        self.unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        time.sleep(5)

    def save_model(self, iter):
        save_path = config.save_model_path + "/%07d.tar" % iter
        T.save(
            {
                "iter": iter + 1,
                "model_dict": self.model.state_dict(),
                "trainer_dict": self.trainer.state_dict()
            }, save_path)

    def setup_train(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        self.trainer = T.optim.Adam(self.model.parameters(), lr=config.lr)
        start_iter = 0
        if self.opt.load_model is not None:
            load_model_path = os.path.join(config.save_model_path,
                                           self.opt.load_model)
            checkpoint = T.load(load_model_path)
            start_iter = checkpoint["iter"]
            self.model.load_state_dict(checkpoint["model_dict"])
            self.trainer.load_state_dict(checkpoint["trainer_dict"])
            print("Loaded model at " + load_model_path)
        if self.opt.new_lr is not None:
            self.trainer = T.optim.Adam(self.model.parameters(),
                                        lr=self.opt.new_lr)
        return start_iter

    def train_batch_MLE(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                        extra_zeros, enc_batch_extend_vocab, batch):
        ''' Calculate Negative Log Likelihood Loss for the given batch. In order to reduce exposure bias,
                pass the previously generated token as input with a probability of 0.25 instead of the ground truth label
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        '''
        dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(
            batch)  #Get input and target batches for training the decoder
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])  #Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  #Input to the decoder
        prev_s = None  #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        for t in range(min(max_dec_len, config.max_dec_steps)):
            use_gound_truth = get_cuda((T.rand(len(enc_out)) > 0.25)).long(
            )  #Probabilities indicating whether to use ground truth labels instead of previous decoded tokens
            x_t = use_gound_truth * dec_batch[:, t] + (
                1 - use_gound_truth
            ) * x_t  #Select decoder input based on use_ground_truth probabilities
            x_t = self.model.embeds(x_t)
            final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            target = target_batch[:, t]
            log_probs = T.log(final_dist + config.eps)
            step_loss = F.nll_loss(log_probs,
                                   target,
                                   reduction="none",
                                   ignore_index=self.pad_id)
            step_losses.append(step_loss)
            x_t = T.multinomial(final_dist, 1).squeeze(
            )  #Sample words from final distribution which can be used as input in next time step
            is_oov = (x_t >= config.vocab_size
                      ).long()  #Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t.detach() + (
                is_oov) * self.unk_id  #Replace OOVs with [UNK] token

        losses = T.sum(
            T.stack(step_losses, 1), 1
        )  #unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens  #Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)  #Average batch loss
        return mle_loss

    def train_batch_RL(self, enc_out, enc_hidden, enc_padding_mask, ct_e,
                       extra_zeros, enc_batch_extend_vocab, article_oovs,
                       greedy):
        '''Generate sentences from decoder entirely using sampled tokens as input. These sentences are used for ROUGE evaluation
        Args
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param article_oovs: Batch containing list of OOVs in each example
        :param greedy: If true, performs greedy based sampling, else performs multinomial sampling
        Returns:
        :decoded_strs: List of decoded sentences
        :log_probs: Log probabilities of sampled words
        '''
        s_t = enc_hidden  #Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(
            self.start_id))  #Input to the decoder
        prev_s = None  #Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  #Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        inds = []  #Stores sampled indices for each time step
        decoder_padding_mask = []  #Stores padding masks of generated samples
        log_probs = []  #Stores log probabilities of generated samples
        mask = get_cuda(
            T.LongTensor(len(enc_out)).fill_(1)
        )  #Values that indicate whether [STOP] token has already been encountered; 1 => Not encountered, 0 otherwise

        for t in range(config.max_dec_steps):
            x_t = self.model.embeds(x_t)
            probs, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
                x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros,
                enc_batch_extend_vocab, sum_temporal_srcs, prev_s)
            if greedy is False:
                multi_dist = Categorical(probs)
                x_t = multi_dist.sample()  #perform multinomial sampling
                log_prob = multi_dist.log_prob(x_t)
                log_probs.append(log_prob)
            else:
                _, x_t = T.max(probs, dim=1)  #perform greedy sampling
            x_t = x_t.detach()
            inds.append(x_t)
            mask_t = get_cuda(T.zeros(
                len(enc_out)))  #Padding mask of batch for current time step
            mask_t[
                mask ==
                1] = 1  #If [STOP] is not encountered till previous time step, mask_t = 1 else mask_t = 0
            mask[
                (mask == 1) + (x_t == self.end_id) ==
                2] = 0  #If [STOP] is not encountered till previous time step and current word is [STOP], make mask = 0
            decoder_padding_mask.append(mask_t)
            is_oov = (x_t >= config.vocab_size
                      ).long()  #Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t + (
                is_oov) * self.unk_id  #Replace OOVs with [UNK] token

        inds = T.stack(inds, dim=1)
        decoder_padding_mask = T.stack(decoder_padding_mask, dim=1)
        if greedy is False:  #If multinomial sampling, compute log probabilities of sampled words
            log_probs = T.stack(log_probs, dim=1)
            log_probs = log_probs * decoder_padding_mask  #Not considering sampled words with padding mask = 0
            lens = T.sum(decoder_padding_mask,
                         dim=1)  #Length of sampled sentence
            log_probs = T.sum(log_probs, dim=1) / lens  # (bs,) length-normalized log probability of each sentence
        decoded_strs = []
        for i in range(len(enc_out)):
            id_list = inds[i].cpu().numpy()
            oovs = article_oovs[i]
            S = data.outputids2words(
                id_list, self.vocab,
                oovs)  # Generate sentence corresponding to sampled words
            try:
                end_idx = S.index(data.STOP_DECODING)
                S = S[:end_idx]
            except ValueError:
                pass  #No STOP token found; keep the full sequence
            if len(S) < 2:  #If the sentence has fewer than 2 words, replace it with "xxx"; avoids sentences like "." which throw an error while calculating ROUGE
                S = ["xxx"]
            S = " ".join(S)
            decoded_strs.append(S)

        return decoded_strs, log_probs

    def reward_function(self, decoded_sents, original_sents):
        rouge = Rouge()
        try:
            scores = rouge.get_scores(decoded_sents, original_sents)
        except Exception:
            print(
                "Rouge failed for multi sentence evaluation.. Finding exact pair"
            )
            scores = []
            for i in range(len(decoded_sents)):
                try:
                    score = rouge.get_scores(decoded_sents[i],
                                             original_sents[i])
                except Exception:
                    print("Error occured at:")
                    print("decoded_sents:", decoded_sents[i])
                    print("original_sents:", original_sents[i])
                    score = [{"rouge-l": {"f": 0.0}}]
                scores.append(score[0])
        rouge_l_f1 = [score["rouge-l"]["f"] for score in scores]
        rouge_l_f1 = get_cuda(T.FloatTensor(rouge_l_f1))
        return rouge_l_f1

    # def write_to_file(self, decoded, max, original, sample_r, baseline_r, iter):
    #     with open("temp.txt", "w") as f:
    #         f.write("iter:"+str(iter)+"\n")
    #         for i in range(len(original)):
    #             f.write("dec: "+decoded[i]+"\n")
    #             f.write("max: "+max[i]+"\n")
    #             f.write("org: "+original[i]+"\n")
    #             f.write("Sample_R: %.4f, Baseline_R: %.4f\n\n"%(sample_r[i].item(), baseline_r[i].item()))

    def train_one_batch(self, batch, iter):
        enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(
            batch)

        enc_batch = self.model.embeds(
            enc_batch)  #Get embeddings for encoder input
        enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

        # -------------------------------Summarization-----------------------
        if self.opt.train_mle == "yes":  #perform MLE training
            mle_loss = self.train_batch_MLE(enc_out, enc_hidden,
                                            enc_padding_mask, context,
                                            extra_zeros,
                                            enc_batch_extend_vocab, batch)
        else:
            mle_loss = get_cuda(T.FloatTensor([0]))
        # --------------RL training-----------------------------------------------------
        if self.opt.train_rl == "yes":  #perform reinforcement learning training
            # multinomial sampling
            sample_sents, RL_log_probs = self.train_batch_RL(
                enc_out,
                enc_hidden,
                enc_padding_mask,
                context,
                extra_zeros,
                enc_batch_extend_vocab,
                batch.art_oovs,
                greedy=False)
            with T.autograd.no_grad():
                # greedy sampling
                greedy_sents, _ = self.train_batch_RL(enc_out,
                                                      enc_hidden,
                                                      enc_padding_mask,
                                                      context,
                                                      extra_zeros,
                                                      enc_batch_extend_vocab,
                                                      batch.art_oovs,
                                                      greedy=True)

            sample_reward = self.reward_function(sample_sents,
                                                 batch.original_abstracts)
            baseline_reward = self.reward_function(greedy_sents,
                                                   batch.original_abstracts)
            # if iter%200 == 0:
            #     self.write_to_file(sample_sents, greedy_sents, batch.original_abstracts, sample_reward, baseline_reward, iter)
            rl_loss = -(
                sample_reward - baseline_reward
            ) * RL_log_probs  #Self-critic policy gradient training (eq 15 in https://arxiv.org/pdf/1705.04304.pdf)
            rl_loss = T.mean(rl_loss)

            batch_reward = T.mean(sample_reward).item()
        else:
            rl_loss = get_cuda(T.FloatTensor([0]))
            batch_reward = 0

    # ------------------------------------------------------------------------------------
        self.trainer.zero_grad()
        (self.opt.mle_weight * mle_loss +
         self.opt.rl_weight * rl_loss).backward()
        self.trainer.step()
        return mle_loss.item(), batch_reward

    def trainIters(self):
        iter = self.setup_train()
        count = mle_total = r_total = 0
        while iter <= config.max_iterations:
            batch = self.batcher.next_batch()
            try:
                mle_loss, r = self.train_one_batch(batch, iter)
            except KeyboardInterrupt:
                print(
                    "-------------------Keyboard Interrupt------------------")
                exit(0)

            mle_total += mle_loss
            r_total += r
            count += 1
            iter += 1

            if iter % 50 == 0:
                mle_avg = mle_total / count
                r_avg = r_total / count
                logger.info("iter:" + str(iter) + "  mle_loss:" +
                            "%.3f" % mle_avg + "  reward:" + "%.4f" % r_avg)
                count = mle_total = r_total = 0

            if iter % 5000 == 0:
                self.save_model(iter)
Пример #26
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.concept_vocab = Concept_vocab(config.concept_vocab_path,
                                           config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               self.concept_vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        self.ds_batcher = Batcher(config.train_ds_data_path,
                                  self.vocab,
                                  self.concept_vocab,
                                  mode='train',
                                  batch_size=500,
                                  single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = AdagradCustom(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def calc_Rouge_1(self, sub, string):
        new_sub = [str(x) for x in sub]
        new_sub.insert(0, '"')
        new_sub.append('"')
        token_c = ' '.join(new_sub)
        summary = [[token_c]]
        new_string = [str(x) for x in string]
        new_string.insert(0, '"')
        new_string.append('"')
        token_r = ' '.join(new_string)
        reference = [[[token_r]]]

        rouge = Pythonrouge(summary_file_exist=False,
                            summary=summary,
                            reference=reference,
                            n_gram=2,
                            ROUGE_SU4=False,
                            ROUGE_L=False,
                            recall_only=True,
                            stemming=True,
                            stopwords=True,
                            word_level=True,
                            length_limit=True,
                            length=30,
                            use_cf=False,
                            cf=95,
                            scoring_formula='average',
                            resampling=False,
                            samples=10,
                            favor=True,
                            p=0.5)
        score = rouge.calc_score()
        return score['ROUGE-1']

    def calc_Rouge_2_recall(self, sub, string):
        token_c = sub
        token_r = string
        model = []
        ref = []
        if len(string) == 1 or len(sub) == 1:
            score = 0.0
        else:
            i = 1
            while i < len(string):
                ref.append(str(token_r[i - 1]) + str(token_r[i]))
                i += 1
            i = 1
            while i < len(sub):
                model.append(str(token_c[i - 1]) + str(token_c[i]))
                i += 1
            sam = 0
            i = 0
            for i in range(len(ref)):
                for j in range(len(model)):
                    if ref[i] == model[j]:
                        sam += 1
                        model[j] = '-1'
                        break

            score = sam / float(len(ref))

        return score

    def calc_Rouge_L(self, sub, string):
        beta = 1.
        token_c = sub
        token_r = string
        if (len(string) < len(sub)):
            sub, string = string, sub
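        # Dynamic-programming table for the longest common subsequence (LCS)
        # between candidate and reference; ROUGE-L precision and recall are the
        # LCS length normalised by the candidate and reference lengths.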
        lengths = [[0 for i in range(0,
                                     len(sub) + 1)]
                   for j in range(0,
                                  len(string) + 1)]
        for j in range(1, len(sub) + 1):
            for i in range(1, len(string) + 1):
                if (string[i - 1] == sub[j - 1]):
                    lengths[i][j] = lengths[i - 1][j - 1] + 1
                else:
                    lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
        lcs = lengths[len(string)][len(sub)]

        prec = lcs / float(len(token_c))
        rec = lcs / float(len(token_r))

        if (prec != 0 and rec != 0):
            score = ((1 + beta**2) * prec * rec) / float(rec + beta**2 * prec)
        else:
            score = 0.0
        # Recall (rather than the F-score computed above) is returned, matching
        # the other recall-based ROUGE helpers used as rewards.
        return rec

    def calc_kl(self, dec, enc):
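        # Normalise both pooled vectors with a softmax (exp, then divide by the
        # sum) and accumulate KL(enc || dec) element by element.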
        kl = 0.
        dec = np.exp(dec)
        enc = np.exp(enc)
        all_dec = np.sum(dec)
        all_enc = np.sum(enc)
        for d, c in zip(dec, enc):
            d = d / all_dec
            c = c / all_enc
            kl = kl + c * np.log(c / d)
        return kl

    def calc_euc(self, dec, enc):
        # Euclidean distance between the two pooled embedding vectors.
        euc = 0.
        for d, c in zip(dec, enc):
            euc = euc + np.square(d - c)
        return np.sqrt(euc)

    def ds_loss(self, enc_batch_ds_emb, enc_padding_mask_ds, dec_batch_emb,
                dec_padding_mask):
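        # Masked sum-pooling of the article (encoder) and title (decoder)
        # embeddings, then the average pairwise KL divergence between every
        # title vector and every article vector in the two batches.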
        b1, t_k1, emb1 = list(enc_batch_ds_emb.size())
        b2, t_k2, emb2 = list(dec_batch_emb.size())
        enc_padding_mask_ds = enc_padding_mask_ds.unsqueeze(2).expand(
            b1, t_k1, emb1).contiguous()
        dec_padding_mask = dec_padding_mask.unsqueeze(2).expand(
            b2, t_k2, emb2).contiguous()
        enc_batch_ds_emb = enc_batch_ds_emb * enc_padding_mask_ds
        dec_batch_emb = dec_batch_emb * dec_padding_mask
        enc_batch_ds_emb = torch.sum(enc_batch_ds_emb, 1)
        dec_batch_emb = torch.sum(dec_batch_emb, 1)
        dec_title = dec_batch_emb.tolist()
        enc_article = enc_batch_ds_emb.tolist()
        dec_title_len = len(dec_title)
        enc_article_len = len(enc_article)
        dsloss = 0.
        for dec in dec_title:
            for enc in enc_article:
                dsloss = dsloss + self.calc_kl(dec, enc)
        dsloss = dsloss / float(dec_title_len * enc_article_len)
        print(dsloss)
        return dsloss

    def train_one_batch(self, batch, steps, batch_ds):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, enc_batch_concept_extend_vocab, concept_p, position, concept_mask, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)
        enc_batch_ds, enc_padding_mask_ds, enc_lens_ds, _, _, _, _, _, _, _, _ = \
            get_input_from_batch(batch_ds, use_cuda)

        self.optimizer.zero_grad()
        encoder_outputs, encoder_hidden, max_encoder_output, enc_batch_ds_emb, dec_batch_emb = self.model.encoder(
            enc_batch, enc_lens, enc_batch_ds, dec_batch)
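        # Optional DS term (config.DS_train): compares the pooled title and
        # article embeddings via KL divergence; the resulting value later
        # scales the MLE loss as (config.pi - ds_final_loss).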
        if config.DS_train:
            ds_final_loss = self.ds_loss(enc_batch_ds_emb, enc_padding_mask_ds,
                                         dec_batch_emb, dec_padding_mask)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        s_t_0 = s_t_1
        c_t_0 = c_t_1
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output
            c_t_0 = c_t_1

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                c_t_1, extra_zeros, enc_batch_extend_vocab,
                enc_batch_concept_extend_vocab, concept_p, position,
                concept_mask, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        if config.DS_train:
            ds_final_loss = Variable(torch.FloatTensor([ds_final_loss]),
                                     requires_grad=False)
            ds_final_loss = ds_final_loss.cuda()
            loss = (config.pi - ds_final_loss) * torch.mean(batch_avg_loss)
        else:
            loss = torch.mean(batch_avg_loss)
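        # After config.traintimes warm-up steps, add a self-critical
        # (REINFORCE-style) term: draw a sampled summary, decode a greedy
        # baseline, and weight the sample's log-likelihood by the ROUGE-L
        # difference between baseline and sample.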
        if steps > config.traintimes:
            scores = []
            sample_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    sample_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    sample_latest_tokens = sample_y[-1]
                    sample_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in sample_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(sample_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
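                # Sample the next token from the predicted distribution and
                # keep its masked negative log-probability as the step score.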
                sample_select = torch.multinomial(final_dist, 1).view(-1)
                sample_log_probs = torch.gather(
                    final_dist, 1, sample_select.unsqueeze(1)).squeeze()
                sample_y.append(sample_select.cpu().numpy().tolist())
                sample_step_loss = -torch.log(sample_log_probs + config.eps)
                sample_step_mask = dec_padding_mask[:, di]
                sample_step_loss = sample_step_loss * sample_step_mask
                scores.append(sample_step_loss)
            sample_sum_losses = torch.sum(torch.stack(scores, 1), 1)
            sample_batch_avg_loss = sample_sum_losses / dec_lens_var

            sample_y = np.transpose(sample_y).tolist()

            base_y = []
            s_t_1 = s_t_0
            c_t_1 = c_t_0
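            # Greedy (argmax) decode of the same batch, restarted from the
            # initial decoder state; this serves as the self-critical baseline.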
            for di in range(min(max_dec_len, config.max_dec_steps)):
                if di == 0:
                    y_t_1 = dec_batch[:, di]
                    base_y.append(y_t_1.cpu().numpy().tolist())
                else:
                    base_latest_tokens = base_y[-1]
                    base_latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                                            for t in base_latest_tokens]

                    y_t_1 = Variable(torch.LongTensor(base_latest_tokens))
                    y_t_1 = y_t_1.cuda()

                final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                    'train', y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
                    c_t_1, extra_zeros, enc_batch_extend_vocab,
                    enc_batch_concept_extend_vocab, concept_p, position,
                    concept_mask, coverage, di)
                base_log_probs, base_ids = torch.topk(final_dist, 1)
                base_y.append(base_ids[:, 0].cpu().numpy().tolist())

            base_y = np.transpose(base_y).tolist()

            refs = dec_batch.cpu().numpy().tolist()
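            # Score sampled and greedy sequences against the reference (all
            # truncated to the reference length) with ROUGE-L; the difference
            # acts as the reward for the sampled sequence.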
            # list() so the lengths can be iterated twice (sample and baseline)
            sample_dec_lens_var = list(
                map(int, dec_lens_var.cpu().numpy().tolist()))
            sample_rougeL = [
                self.calc_Rouge_L(sample[:reflen],
                                  ref[:reflen]) for sample, ref, reflen in zip(
                                      sample_y, refs, sample_dec_lens_var)
            ]
            base_rougeL = [
                self.calc_Rouge_L(base[:reflen], ref[:reflen])
                for base, ref, reflen in zip(base_y, refs, sample_dec_lens_var)
            ]
            sample_rougeL = Variable(torch.FloatTensor(sample_rougeL),
                                     requires_grad=False)
            base_rougeL = Variable(torch.FloatTensor(base_rougeL),
                                   requires_grad=False)
            sample_rougeL = sample_rougeL.cuda()
            base_rougeL = base_rougeL.cuda()
            word_loss = -sample_batch_avg_loss * (base_rougeL - sample_rougeL)
            reinforce_loss = torch.mean(word_loss)
            loss = (1 - config.rein) * loss + config.rein * reinforce_loss

        loss.backward()

        clip_grad_norm(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.reduce_state.parameters(),
                       config.max_grad_norm)

        self.optimizer.step()

        return loss.data[0]

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            batch_ds = self.ds_batcher.next_batch()
            loss = self.train_one_batch(batch, iter, batch_ds)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 5
            if iter % print_interval == 0:
                print('steps %d , loss: %f' % (iter, loss))
                start = time.time()
            if iter % 50000 == 0:
                self.save_model(running_avg_loss, iter)
Пример #27
0
class BeamSearch(object):
    def __init__(self, model_file_path):
        model_name = os.path.basename(model_file_path)
        self._decode_dir = os.path.join(config.log_root,
                                        'decode_%s' % (model_name))
        self._rouge_ref_dir = os.path.join(self._decode_dir, 'rouge_ref')
        self._rouge_dec_dir = os.path.join(self._decode_dir, 'rouge_dec_dir')
        for p in [self._decode_dir, self._rouge_ref_dir, self._rouge_dec_dir]:
            if not os.path.exists(p):
                os.mkdir(p)

        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.decode_data_path,
                               self.vocab,
                               mode='decode',
                               batch_size=config.beam_size,
                               single_pass=True)
        time.sleep(15)

        self.model = Model(model_file_path, is_eval=True)

    def sort_beams(self, beams):
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def decode(self):
        start = time.time()
        counter = 0
        batch = self.batcher.next_batch()
        while batch is not None:
            # Run beam search to get best Hypothesis
            best_summary = self.beam_search(batch)

            # Extract the output ids from the hypothesis and convert back to words
            output_ids = [int(t) for t in best_summary.tokens[1:]]
            decoded_words = data.outputids2words(
                output_ids, self.vocab,
                (batch.art_oovs[0] if config.pointer_gen else None))

            # Remove the [STOP] token from decoded_words, if necessary
            try:
                fst_stop_idx = decoded_words.index(data.STOP_DECODING)
                decoded_words = decoded_words[:fst_stop_idx]
            except ValueError:
                decoded_words = decoded_words

            original_abstract_sents = batch.original_abstracts_sents[0]

            write_for_rouge(original_abstract_sents, decoded_words, counter,
                            self._rouge_ref_dir, self._rouge_dec_dir)
            counter += 1
            if counter % 1000 == 0:
                print('%d example in %d sec' % (counter, time.time() - start))
                start = time.time()

            batch = self.batcher.next_batch()

        print("Decoder has finished reading dataset for single_pass.")
        print("Now starting ROUGE eval...")
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)

    def beam_search(self, batch):
        #batch should have only one example
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_0, coverage_t_0 = \
            get_input_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_0 = self.model.reduce_state(encoder_hidden)

        dec_h, dec_c = s_t_0  # 1 x 2*hidden_size
        dec_h = dec_h.squeeze()
        dec_c = dec_c.squeeze()

        #decoder batch preparation, it has beam_size example initially everything is repeated
        beams = [
            Beam(tokens=[self.vocab.word2id(data.START_DECODING)],
                 log_probs=[0.0],
                 state=(dec_h[0], dec_c[0]),
                 context=c_t_0[0],
                 coverage=(coverage_t_0[0] if config.is_coverage else None))
            for _ in range(config.beam_size)
        ]
        results = []
        steps = 0
        while steps < config.max_dec_steps and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in beams]
            latest_tokens = [t if t < self.vocab.size() else self.vocab.word2id(data.UNKNOWN_TOKEN) \
                             for t in latest_tokens]
            y_t_1 = Variable(torch.LongTensor(latest_tokens))
            if use_cuda:
                y_t_1 = y_t_1.cuda()
            all_state_h = []
            all_state_c = []

            all_context = []

            for h in beams:
                state_h, state_c = h.state
                all_state_h.append(state_h)
                all_state_c.append(state_c)

                all_context.append(h.context)

            s_t_1 = (torch.stack(all_state_h, 0).unsqueeze(0),
                     torch.stack(all_state_c, 0).unsqueeze(0))
            c_t_1 = torch.stack(all_context, 0)

            coverage_t_1 = None
            if config.is_coverage:
                all_coverage = []
                for h in beams:
                    all_coverage.append(h.coverage)
                coverage_t_1 = torch.stack(all_coverage, 0)

            final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage_t_1, steps)
            log_probs = torch.log(final_dist)
            topk_log_probs, topk_ids = torch.topk(log_probs,
                                                  config.beam_size * 2)

            dec_h, dec_c = s_t
            dec_h = dec_h.squeeze()
            dec_c = dec_c.squeeze()

            all_beams = []
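            # On the first step every beam holds the same START token, so only
            # beam 0 is expanded; afterwards each surviving beam contributes
            # its top 2*beam_size extensions.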
            num_orig_beams = 1 if steps == 0 else len(beams)
            for i in range(num_orig_beams):
                h = beams[i]
                state_i = (dec_h[i], dec_c[i])
                context_i = c_t[i]
                coverage_i = (coverage_t[i] if config.is_coverage else None)

                for j in range(config.beam_size *
                               2):  # for each of the top 2*beam_size hyps:
                    new_beam = h.extend(token=topk_ids[i, j].item(),
                                        log_prob=topk_log_probs[i, j].item(),
                                        state=state_i,
                                        context=context_i,
                                        coverage=coverage_i)
                    all_beams.append(new_beam)

            beams = []
            for h in self.sort_beams(all_beams):
                if h.latest_token == self.vocab.word2id(data.STOP_DECODING):
                    if steps >= config.min_dec_steps:
                        results.append(h)
                else:
                    beams.append(h)
                if len(beams) == config.beam_size or len(
                        results) == config.beam_size:
                    break

            steps += 1

        if len(results) == 0:
            results = beams

        beams_sorted = self.sort_beams(results)

        return beams_sorted[0]
Пример #28
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adam(
            params, lr=initial_lr
        )  #Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch_list, enc_padding_mask_list, enc_lens_list, enc_batch_extend_vocab_list, extra_zeros_list, c_t_1_list, coverage_list = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs_list = []
        encoder_feature_list = []
        s_t_1 = None
        s_t_1_0 = None
        s_t_1_1 = None
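        # Each source in the list is sorted by length before encoding
        # (presumably so the encoder can pack the padded batch), and the
        # outputs are scattered back to the original order; the reduced
        # encoder states of all sources are summed to initialise the decoder.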
        for enc_batch, enc_lens in zip(enc_batch_list, enc_lens_list):
            sorted_indices = sorted(range(len(enc_lens)),
                                    key=enc_lens.__getitem__)
            sorted_indices.reverse()
            inverse_sorted_indices = [-1 for _ in range(len(sorted_indices))]
            for index, position in enumerate(sorted_indices):
                inverse_sorted_indices[position] = index
            sorted_enc_batch = torch.index_select(
                enc_batch, 0,
                torch.LongTensor(sorted_indices)
                if not use_cuda else torch.LongTensor(sorted_indices).cuda())
            sorted_enc_lens = enc_lens[sorted_indices]
            sorted_encoder_outputs, sorted_encoder_feature, sorted_encoder_hidden = self.model.encoder(
                sorted_enc_batch, sorted_enc_lens)
            encoder_outputs = torch.index_select(
                sorted_encoder_outputs, 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda())
            encoder_feature = torch.index_select(
                sorted_encoder_feature.view(encoder_outputs.shape), 0,
                torch.LongTensor(inverse_sorted_indices) if not use_cuda else
                torch.LongTensor(inverse_sorted_indices).cuda()).view(
                    sorted_encoder_feature.shape)
            encoder_hidden = tuple([
                torch.index_select(
                    sorted_encoder_hidden[0], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda()),
                torch.index_select(
                    sorted_encoder_hidden[1], 1,
                    torch.LongTensor(inverse_sorted_indices) if not use_cuda
                    else torch.LongTensor(inverse_sorted_indices).cuda())
            ])
            #encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
            encoder_outputs_list.append(encoder_outputs)
            encoder_feature_list.append(encoder_feature)
            if s_t_1 is None:
                s_t_1 = self.model.reduce_state(encoder_hidden)
                s_t_1_0, s_t_1_1 = s_t_1
            else:
                s_t_1_new = self.model.reduce_state(encoder_hidden)
                s_t_1_0 = s_t_1_0 + s_t_1_new[0]
                s_t_1_1 = s_t_1_1 + s_t_1_new[1]
            s_t_1 = tuple([s_t_1_0, s_t_1_1])

        #c_t_1_list = [c_t_1]
        #coverage_list = [coverage]

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1_list, attn_dist_list, p_gen, next_coverage_list = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs_list, encoder_feature_list,
                enc_padding_mask_list, c_t_1_list, extra_zeros_list,
                enc_batch_extend_vocab_list, coverage_list, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = 0.0
                for ind in range(len(coverage_list)):
                    step_coverage_loss += torch.sum(
                        torch.min(attn_dist_list[ind], coverage_list[ind]), 1)
                    coverage_list[ind] = next_coverage_list[ind]
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 500
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 500 == 0:
                self.save_model(running_avg_loss, iter)
Пример #29
0
class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iter):
        state = {
            'iter': iter,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iter, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = AdagradCustom(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_hidden, max_encoder_output = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)
        if config.use_maxpool_init_ctx:
            c_t_1 = max_encoder_output

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
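            # Coverage penalty: the overlap between the current attention and
            # the accumulated coverage vector discourages attending to the same
            # source positions repeatedly.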
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm(self.model.reduce_state.parameters(),
                       config.max_grad_norm)

        self.optimizer.step()

        return loss.data[0]

    def trainIters(self, n_iters, model_file_path=None):
        iter, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iter < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1

            if iter % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start, loss))
                start = time.time()
            if iter % 5000 == 0:
                self.save_model(running_avg_loss, iter)
Пример #30
0
class Evaluate(object):
    def __init__(self, model_file_path, vocab=None):
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.eval_data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=config.batch_size,
                               single_pass=False)
        time.sleep(15)
        model_name = os.path.basename(model_file_path)

        eval_dir = os.path.join(config.log_root, 'eval_%s' % (model_name))
        if not os.path.exists(eval_dir):
            os.mkdir(eval_dir)
        self.summary_writer = tf.compat.v1.summary.FileWriter(eval_dir)

        self.model = Model(model_file_path, is_eval=True)

    def eval_one_batch(self, batch):
        cls_ids,sep_ids,enc_batch_ids,enc_batch_segs, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch_ids, enc_batch_segs, cls_ids, sep_ids, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_step_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_step_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        return loss.item()

    def run_eval(self):
        running_avg_loss, iter = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        print(
            "-----------------------------------------STARTING EVALATION---------------------------------------"
        )
        with open(config.eval_log, 'a+', encoding='utf-8') as f:
            f.write(
                "-----------------------------------------STARTING EVALATION---------------------------------------"
                + "\n")
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, iter)
            iter += 1
            if iter % 20 == 0:
                self.summary_writer.flush()
            print_interval = 100
            if iter % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (iter, print_interval, time.time() - start,
                       running_avg_loss))
                start = time.time()
                with open(config.eval_log, 'a+', encoding='utf-8') as f:
                    f.write("Steps: " + str(iter) + "   loss: " +
                            str(running_avg_loss) + "\n")
            if (iter + 1) % config.max_iterations_eval == 0:
                break
            batch = self.batcher.next_batch()
        return running_avg_loss
Пример #31
0
class Decoder(object):
    def __init__(self):
        self.vocab = Vocab(args.vocab_path, args.vocab_size)
        self.batcher = Batcher(
            args.decode_data_path,
            self.vocab,
            mode='decode',
            batch_size=1,
            single_pass=True)  # support only 1 item at a time
        time.sleep(15)
        vocab_size = self.vocab.size()
        self.beam_size = args.beam_size
        # self.bertClient = BertClient()
        self.encoder = EncoderLSTM(args.hidden_size, self.vocab.size())
        self.decoder = DecoderLSTM(args.hidden_size, self.vocab.size())
        if use_cuda:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()

        # Prepare the output folder and files
        output_dir = os.path.join(args.logs, "outputs")
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        output_file = os.path.join(output_dir,
                                   "decoder_{}.txt".format(args.output_name))
        self.file = open(output_file, "w+")

    def load_model(self, checkpoint_file):
        print("Loading Checkpoint: ", checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        encoder_state_dict = checkpoint['encoder_state_dict']
        decoder_state_dict = checkpoint['decoder_state_dict']
        self.encoder.load_state_dict(encoder_state_dict)
        self.decoder.load_state_dict(decoder_state_dict)
        self.encoder.eval()
        self.decoder.eval()
        print("Weights Loaded")

    def decode(self):
        batch = self.batcher.next_batch()
        count = 0

        while batch is not None:
            # Preparing
            enc_batch, enc_padding_mask, enc_lens = get_input_from_batch(
                batch, use_cuda=use_cuda)
            dec_batch, target_batch, dec_padding_mask, max_dec_len = get_output_from_batch(
                batch, use_cuda=use_cuda)
            batch_size = len(enc_batch)
            proba_list = [0] * batch_size
            generated_list = [[]]

            # Encoding sentences
            outputs, hidden_state = self.encoder(enc_batch)

            x = torch.LongTensor([2])  # start sequence
            if use_cuda:
                x = x.cuda()
            """ Normal Approach """
            #answers = torch.ones((batch_size, args.max_dec_steps), dtype=torch.long)
            #for t in range(max(args.max_dec_steps, max_dec_len)):
            #    output, hidden_state = self.decoder(x, hidden_state)  # Output: batch * vocab_size (prob.)
            #    idx = torch.argmax(output, dim=1)
            #    answers[:, t] = idx.detach()
            #    x = idx
            """ Beam Approach """
            for t in range(max(args.max_dec_steps, max_dec_len) - 1):
                output, hidden_state = self.decoder(x, hidden_state)

                # For each sentence, find b best answers (beam search)
                states = []  # (probab, generated, hidden_state_index)
                for i, each_decode in enumerate(output):
                    prev_proba = proba_list[i]
                    prev_generated = generated_list[i]
                    arr = each_decode.detach().cpu().numpy()  # log-prob of each word
                    indices = arr.argsort()[-self.beam_size:][::-1]  # top beam_size token ids
                    for idx in indices:
                        proba = prev_proba + arr[idx]  # cumulative log-prob of the extended hypothesis
                        generated = prev_generated.copy()
                        generated.append(idx)
                        states.append((proba, generated, i))

                # Sort for the best generated sequence among all
                states.sort(key=lambda x: x[0], reverse=True)

                # Variables
                new_proba_list = []
                new_generated = []
                new_hidden = torch.Tensor()
                new_cell = torch.Tensor()
                new_x = torch.LongTensor()

                if use_cuda:
                    new_hidden = new_hidden.cuda()
                    new_cell = new_cell.cuda()
                    new_x = new_x.cuda()

                # Select top b sequences
                for state in states[:self.beam_size]:
                    new_proba_list.append(state[0])
                    new_generated.append(state[1])
                    idx = state[2]

                    h_0 = hidden_state[0].squeeze(0)[idx].unsqueeze(0)
                    c_0 = hidden_state[1].squeeze(0)[idx].unsqueeze(0)
                    new_hidden = torch.cat((new_hidden, h_0), dim=0)
                    new_cell = torch.cat((new_cell, c_0), dim=0)
                    generated_idx = torch.LongTensor([state[1][-1]])
                    if use_cuda:
                        generated_idx = generated_idx.cuda()

                    new_x = torch.cat((new_x, generated_idx))

                # Save the list
                proba_list = new_proba_list
                generated_list = new_generated
                hidden_state = (new_hidden.unsqueeze(0), new_cell.unsqueeze(0))
                x = new_x

            # Convert from id to word
            # answer = answers[0].numpy()
            answer = new_generated[0]
            sentence = ids2words(answer, self.vocab)
            self.file.write("{}\n".format(sentence))
            print("Writing line #{} to file ...".format(count + 1))
            self.file.flush()
            sys.stdout.flush()

            count += 1
            batch = self.batcher.next_batch()
Пример #32
0
class Validate(object):
    def __init__(self, data_path, batch_size=config.batch_size):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(data_path,
                               self.vocab,
                               mode='eval',
                               batch_size=batch_size,
                               single_pass=True)

        time.sleep(5)

    def setup_valid(self):
        self.model = Model()
        self.model = get_cuda(self.model)
        checkpoint = T.load(config.load_model_path)
        self.model.load_state_dict(checkpoint["model_dict"])

    def validate_batch(self):

        self.setup_valid()
        batch = self.batcher.next_batch()
        start_id = self.vocab.word2id(data.START_DECODING)
        end_id = self.vocab.word2id(data.STOP_DECODING)
        unk_id = self.vocab.word2id(data.UNKNOWN_TOKEN)
        decoded_sents = []
        original_sents = []
        rouge = Rouge()
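        # Encode each batch, run beam search to generate summaries, collect the
        # decoded and reference sentences, then report the average ROUGE-L F1.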
        while batch is not None:
            enc_batch, enc_lens, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t_1 = get_enc_data(
                batch)

            with T.autograd.no_grad():
                enc_batch = self.model.embeds(enc_batch)
                enc_out, enc_hidden = self.model.encoder(enc_batch, enc_lens)

            with T.autograd.no_grad():
                pred_ids = beam_search_on_batch(enc_hidden, enc_out,
                                                enc_padding_mask, c_t_1,
                                                extra_zeros,
                                                enc_batch_extend_vocab,
                                                self.model, start_id, end_id,
                                                unk_id)

            for i in range(len(pred_ids)):
                decoded_words = data.outputids2words(pred_ids[i], self.vocab,
                                                     batch.art_oovs[i])
                if len(decoded_words) < 2:
                    decoded_words = "xxx"
                else:
                    decoded_words = " ".join(decoded_words)
                decoded_sents.append(decoded_words)
                tar = batch.original_abstracts[i]
                original_sents.append(tar)

            batch = self.batcher.next_batch()

        load_file = config.load_model_path.split("/")[-1]

        scores = rouge.get_scores(decoded_sents, original_sents, avg=True)
        rouge_l = scores["rouge-l"]["f"]
        print(load_file, "rouge_l:", "%.4f" % rouge_l)