Example 1
 def detail_forward(self, incoming):
     inp = Storage()
     batch_size = inp.batch_size = incoming.data.batch_size
     inp.init_h = incoming.conn.init_h
     inp.post = incoming.hidden.h
     inp.post_length = incoming.data.post_length
     inp.embLayer = incoming.resp.embLayer
     inp.dm = self.param.volatile.dm
     inp.max_sent_length = self.args.max_sent_length
     inp.data = incoming.data
     incoming.gen = gen = Storage()
     self.freerun(inp, gen)
     dm = self.param.volatile.dm
     w_o = gen.w_o.detach().cpu().numpy()
     incoming.result.resp_str = resp_str = \
         [" ".join(dm.convert_ids_to_tokens(w_o[:, i].tolist()))
          for i in range(batch_size)]
     incoming.result.golden_str = golden_str = \
         [" ".join(dm.convert_ids_to_tokens(incoming.data.resp[:, i].detach().cpu().numpy().tolist()))
          for i in range(batch_size)]
     incoming.result.post_str = post_str = \
         [" ".join(dm.convert_ids_to_tokens(incoming.data.post[:, i].detach().cpu().numpy().tolist()))
          for i in range(batch_size)]
     incoming.result.show_str = "\n".join([
         "post: " + a + "\n" + "resp: " + b + "\n" + "golden: " + c + "\n"
         for a, b, c in zip(post_str, resp_str, golden_str)
     ])
Example 2
 def _preprocess_batch(self, data):
     incoming = Storage()
     incoming.data = data = Storage(data)
     data.batch_size = data.post.shape[0]
     # length * batch_size
     data.post = cuda(torch.LongTensor(data.post.transpose(1, 0)))
     # length * batch_size
     data.resp = cuda(torch.LongTensor(data.resp.transpose(1, 0)))
     return incoming
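
A minimal sketch of the shape convention this preprocessing produces. The ids and shapes below are made up, and the repo's `cuda` helper is omitted:

import numpy as np
import torch

# Hypothetical raw batch from the dataloader: batch_size x length,
# two sentences padded to length 4.
post = np.array([[2, 15, 8, 3],
                 [2, 42, 3, 0]])

# _preprocess_batch transposes to length x batch_size before wrapping in a LongTensor,
# which is the layout the recurrent modules in the other examples expect.
post_t = torch.LongTensor(post.transpose(1, 0))
print(post_t.shape)  # torch.Size([4, 2])
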
Example 3
    def test(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        # Perplexity
        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            data = incoming.data
            data.resp_allvocabs = LongTensor(incoming.data.resp_allvocabs)
            data.resp_length = incoming.data.resp_length
            data.gen_log_prob = gen_log_prob.transpose(1, 0)
            metric1.forward(data)
        res = metric1.close()

        metric = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            metric.forward(data)
        res.update(metric.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            for metric_name, value in res.items():
                if isinstance(value, (float, str)):
                    logging.info("\t{}:\t{}".format(metric_name, value))
                    f.write("{}:\t{}\n".format(metric_name, value))
            for i in range(len(res['post'])):
                f.write("post:\t%s\n" % " ".join(res['post'][i]))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)

        return {
            key: val
            for key, val in res.items() if isinstance(val, (str, int, float))
        }
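
For intuition, the teacher-forcing metric is fed `gen_log_prob` (log-softmaxed logits) together with the reference responses and reports perplexity. A minimal sketch of the quantity involved, with made-up shapes; this is an illustration, not cotk's exact metric implementation:

import torch
import torch.nn.functional as F

# Hypothetical logits for 3 decoding steps over a 5-word vocabulary.
logits = torch.randn(3, 5)
log_prob = F.log_softmax(logits, dim=-1)
resp = torch.tensor([2, 4, 1])  # reference token ids for the 3 steps

nll = -log_prob[torch.arange(3), resp].mean()  # average per-token negative log-likelihood
perplexity = torch.exp(nll)
print(perplexity.item())
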
Example 4
    def train(self, batch_num):
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            if i % args.batch_num_per_gradient == 0:
                self.zero_grad()  # start of a new gradient-accumulation window

            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()
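
For reference, a minimal self-contained sketch of the accumulate-then-step pattern this loop follows. The `nn.Linear` model, the random data, and the division by the accumulation count are illustrative choices, not the repo's code:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [torch.randn(8, 4) for _ in range(6)]
accum = 2  # plays the role of args.batch_num_per_gradient

for i, x in enumerate(data):
    if i % accum == 0:
        optimizer.zero_grad()              # start of an accumulation window
    loss = model(x).pow(2).mean() / accum  # scale so the update matches one larger batch
    loss.backward()                        # gradients accumulate across the window
    if (i + 1) % accum == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()                   # one update per `accum` batches
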
Example 5
    def evaluate(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update(
            {key: get_mean(result_arr, key)
             for key in result_arr[0]})
        detail_arr.perplexity_avg_on_batch = np.exp(detail_arr.word_loss)
        return detail_arr
Example 6
 def forward(self, incoming):
     incoming.result = Storage()
     self.embLayer.forward(incoming)
     self.postEncoder.forward(incoming)
     self.connectLayer.forward(incoming)
     self.genNetwork.forward(incoming)
     incoming.result.loss = incoming.result.word_loss
     if torch.isnan(incoming.result.loss).any().item():
         logging.info("Nan detected")
         logging.info(incoming.result)
         raise FloatingPointError("Nan detected")
Example 7
 def forward(self, incoming):
     inp = Storage()
     inp.resp_length = incoming.data.resp_length
     inp.embedding = incoming.resp.embedding
     inp.post = incoming.hidden.h
     inp.post_length = incoming.data.post_length
     inp.init_h = incoming.conn.init_h
     incoming.gen = gen = Storage()
     self.teacherForcing(inp, gen)
     gen.w = self.wLinearLayer(gen.h)
     if self.args.pointer_gen:
         # calc distribution on voc with pointer gen
         voc_att = self.sum_on_voc_index(gen.att, gen.w, incoming.data.post)
         gen.w = gen.p * gen.w + (1 - gen.p) * voc_att
     w_o_f = flattenSequence(gen.w, incoming.data.resp_length - 1)
     data_f = flattenSequence(incoming.data.resp[1:],
                              incoming.data.resp_length - 1)
     incoming.result.word_loss = self.lossCE(w_o_f, data_f)
     if self.args.coverage:
         cov_loss = flattenSequence(
             torch.min(gen.cov, gen.att).sum(2),
             incoming.data.resp_length - 1) * self.args.cov_loss_wt
         incoming.result.word_loss += cov_loss.mean()
     incoming.result.perplexity = torch.exp(incoming.result.word_loss)
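
The `sum_on_voc_index` call above is where attention weights become a copy distribution over the vocabulary. A minimal sketch of that mixture (See et al., 2017) using `scatter_add_` with batch-first shapes for brevity; the repo's helper and its length x batch x vocab layout may differ:

import torch

batch_size, src_len, vocab_size = 2, 4, 10
att = torch.softmax(torch.randn(batch_size, src_len), dim=-1)        # attention over source tokens
vocab_dist = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
post = torch.randint(0, vocab_size, (batch_size, src_len))           # source token ids
p_gen = torch.rand(batch_size, 1)                                    # generation probability

# Scatter the attention mass onto the vocabulary positions of the source tokens.
copy_dist = torch.zeros(batch_size, vocab_size).scatter_add_(1, post, att)

# Pointer-generator mixture: p * P_vocab + (1 - p) * P_copy.
final_dist = p_gen * vocab_dist + (1 - p_gen) * copy_dist
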
Example 8
 def forward(self, incoming):
     incoming.post = Storage()
     incoming.post.embedding = self.embLayer(incoming.data.post)
     incoming.resp = Storage()
     incoming.resp.embedding = self.embLayer(incoming.data.resp)
     incoming.resp.embLayer = self.embLayer
Example 9
 def detail_forward(self, incoming):
     incoming.result = Storage()
     self.embLayer.forward(incoming)
     self.postEncoder.forward(incoming)
     self.connectLayer.forward(incoming)
     self.genNetwork.detail_forward(incoming)
Example 10
def run(*argv):
    parser = argparse.ArgumentParser(
        description=
        "A PyTorch implementation of the paper Get To The Point (https://arxiv.org/abs/1704.04368)"
    )

    # Model parameters
    parser.add_argument('--restore',
                        type=str,
                        default=None,
                        help='Checkpoint name to load. \
                        "NAME_last" for the last checkpoint of the model named NAME; "NAME_best" for its best checkpoint. \
                        You can also use "last" and "best", which by default refer to the last model you ran. \
                        It can also be a URL starting with "http". \
                        Attention: "NAME_last" and "NAME_best" are not guaranteed to work when 2 models with the same name run at the same time; \
                        "last" and "best" are not guaranteed to work when 2 models run at the same time. \
                        Default: None (don\'t load anything)')
    parser.add_argument(
        '--name',
        type=str,
        help=
        "Name for experiment. Logs will be saved in a directory with this name, under log_root.",
        default=time.ctime(time.time()).replace(" ", "_"))
    parser.add_argument('--datapath',
                        type=str,
                        default="../data/#CNN",
                        help='path to the dataset')
    parser.add_argument('--dataset',
                        type=str,
                        default='CNN',
                        help='Dataloader class')
    parser.add_argument('--mode',
                        type=str,
                        default="train",
                        help='must be one of train/eval')
    parser.add_argument(
        '--model_dir',
        type=str,
        default="./model",
        help='Checkpoints directory for model. Default: ./model')
    parser.add_argument('--checkpoint_steps', type=int, help="", default=20)
    parser.add_argument('--checkpoint_max_to_keep',
                        type=int,
                        help="",
                        default=5)
    parser.add_argument(
        '--log_dir',
        type=str,
        default="./tensorboard",
        help='Log directory for tensorboard. Default: ./tensorboard')
    parser.add_argument('--cuda', type=bool, help="", default=False)
    parser.add_argument('--cache', type=str, help="", default="")
    parser.add_argument('--seed', type=int, help='', default=42)
    parser.add_argument('--restore_optimizer',
                        type=bool,
                        default=True,
                        help='')
    parser.add_argument('--debug',
                        action='store_true',
                        help='Enter debug mode (using ptvsd).')

    # Network parameters
    parser.add_argument(
        '--pointer_gen',
        type=bool,
        help=
        'If True, use pointer-generator model. If False, use baseline model.',
        default=True)
    parser.add_argument(
        '--coverage',
        type=bool,
        help=
        'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.',
        default=False)
    parser.add_argument('--embedding_size',
                        type=int,
                        help='dimension of word embeddings',
                        default=128)
    parser.add_argument('--eh_size',
                        type=int,
                        help='dimension of RNN encoder hidden states',
                        default=256)
    parser.add_argument('--dh_size',
                        type=int,
                        help='dimension of RNN decoder hidden states',
                        default=256)
    parser.add_argument('--epochs',
                        type=int,
                        default=30,
                        help="Epochs for training. Default: 30")
    parser.add_argument('--batch_per_epoch',
                        type=int,
                        default=2,
                        help="Batches per epoch. Default: 2")
    parser.add_argument('--batch_num_per_gradient',
                        type=int,
                        default=2,
                        help="")
    parser.add_argument('--batch_size',
                        type=int,
                        help='minibatch size',
                        default=16)
    parser.add_argument('--grad_clip', type=int, help='', default=5)
    parser.add_argument('--lr', type=float, help='learning rate', default=.015)
    parser.add_argument('--droprate', type=float, help="", default=.0)
    parser.add_argument('--batchnorm', type=bool, help="", default=False)
    parser.add_argument('--max_sent_length', type=int, help="", default=100)
    parser.add_argument('--max_doc_length', type=int, help="", default=400)
    parser.add_argument('--min_vocab_times', type=int, help="", default=50)

    # Decode parameters
    parser.add_argument(
        '--cov_loss_wt',
        type=float,
        help=
        'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.',
        default=1.)
    parser.add_argument('--show_sample', type=list, help="", default=[0])
    parser.add_argument(
        '--decode_mode',
        type=str,
        choices=['max', 'sample', 'gumbel', 'samplek', 'beam'],
        default='beam',
        help=
        'The decoding strategy for free-run generation. Choices: max, sample, gumbel(=sample), \
                        samplek(sample from top-k), beam(beam search). Default: beam'
    )
    parser.add_argument(
        '--top_k',
        type=int,
        default=10,
        help='The top_k when decode_mode == "beam" or "samplek"')
    parser.add_argument(
        '--out_dir',
        type=str,
        default="./output",
        help='Output directory for test output. Default: ./output')
    parser.add_argument(
        '--length_penalty',
        type=float,
        default=0.7,
        help='The beam search penalty for short sentences. The penalty gets \
                        larger as this value becomes smaller.')

    args = Storage()
    for key, val in vars(parser.parse_args(argv)).items():
        args[key] = val

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    main(args)
Example 11
def main(args, load_exclude_set=None, restoreCallback=None):
    load_exclude_set = load_exclude_set or []

    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
        datefmt='%H:%M:%S')

    if args.debug:
        debug()

    logging.info(json.dumps(args, indent=2))
    cuda_init(0, args.cuda)

    volatile = Storage()
    volatile.load_exclude_set = load_exclude_set
    volatile.restoreCallback = restoreCallback

    data_arg = Storage()
    data_arg.file_id = args.datapath
    data_arg.min_vocab_times = args.min_vocab_times
    data_arg.max_doc_length = args.max_doc_length
    data_arg.invalid_vocab_times = args.min_vocab_times

    data_class = TextSummarization.load_class(args.dataset)
    # assuming try_cache(callable, args_tuple, cache_dir, name) calls callable(*args_tuple)
    volatile.dm = (try_cache(data_class, (data_arg,), args.cache_dir,
                             data_class.__name__)
                   if args.cache else data_class(**data_arg))

    param = Storage()
    param.args = args
    param.volatile = volatile

    model = PointerGen(param)
    if args.mode == "train":
        model.train_process()
    elif args.mode == "eval":
        test_res = model.test_process()
        with open("./result.json", "w") as f:
            json.dump(test_res, f)
    else:
        raise ValueError("Unknown mode")
Example 12
    def _beamsearch(self,
                    inp,
                    top_k,
                    nextStep,
                    wLinearLayerCallback,
                    input_callback=None,
                    no_unk=True,
                    length_penalty=0.7):
        # inp contains: batch_size, dm, embLayer, max_sent_length, [init_h]
        # input_callback(i, embedding): override this to modify the word embedding at position i
        # nextStep(embedding, flag): pass the embedding to the RNN and get gru_h; flag[i]==1 means the i-th sentence has ended
        # wLinearLayerCallback(gru_h): input gru_h and return logits over the vocabulary

        # output: gen.w_o, gen.emb, gen.length

        #start_id = inp.dm.go_id if no_unk else 0

        batch_size = inp.batch_size
        dm = inp.dm
        first_emb = inp.embLayer(LongTensor([dm.go_id
                                             ])).repeat(batch_size, top_k, 1)
        w_pro = []
        w_o = []
        emb = []
        flag = zeros(batch_size, top_k).int()
        EOSmet = []
        score = zeros(batch_size, top_k)
        score[:, 1:] = -1e9
        now_length = zeros(batch_size, top_k)
        back_index = []
        regroup = LongTensor([i for i in range(top_k)]).repeat(batch_size, 1)

        next_emb = first_emb
        #nextStep = self.init_forward(batch_size, inp.get("init_h", None))

        for i in range(inp.max_sent_length):
            now = next_emb
            if input_callback:
                now = input_callback(i, now)

            # batch_size, top_k, hidden_size

            gru_h = nextStep(now, flag, regroup=regroup)
            w = wLinearLayerCallback(gru_h,
                                     inp)  # batch_size, top_k, vocab_size

            if no_unk:
                w[:, :, dm.unk_id] = -1e9
            w = w.log_softmax(dim=-1)
            w_pro.append(w.exp())

            new_score = (score.unsqueeze(-1) + w * (1-flag.float()).unsqueeze(-1)) / \
                ((now_length.float() + 1 - flag.float()).unsqueeze(-1) ** length_penalty)
            new_score[:, :, 1:] = new_score[:, :, 1:] - \
                flag.float().unsqueeze(-1) * 1e9
            _, index = new_score.reshape(batch_size, -1).topk(
                top_k, dim=-1, largest=True, sorted=True)  # batch_size, top_k

            new_score = (score.unsqueeze(-1) + w *
                         (1 - flag.float()).unsqueeze(-1)).reshape(
                             batch_size, -1)

            score = torch.gather(new_score, dim=1, index=index)

            vocab_size = w.shape[-1]
            regroup = index // vocab_size  # batch_size, top_k

            back_index.append(regroup)
            w = torch.fmod(index, vocab_size)  # batch_size, top_k

            flag = torch.gather(flag, dim=1, index=regroup)

            now_length = torch.gather(now_length, dim=1,
                                      index=regroup) + 1 - flag.float()

            w_x = w.clone()
            w_x[w_x >= dm.vocab_size] = dm.unk_id

            next_emb = inp.embLayer(w_x)
            w_o.append(w)
            emb.append(next_emb)

            EOSmet.append(flag)

            flag = flag | (w == dm.eos_id).int()
            if torch.sum(flag).detach().cpu().numpy() == batch_size * top_k:
                break

        # back tracking
        gen = Storage()
        back_EOSmet = []
        gen.w_o = []
        gen.emb = []
        now_index = LongTensor([i for i in range(top_k)]).repeat(batch_size, 1)

        for i, index in reversed(list(enumerate(back_index))):
            gen.w_o.append(torch.gather(w_o[i], dim=1, index=now_index))
            gen.emb.append(
                torch.gather(emb[i],
                             dim=1,
                             index=now_index.unsqueeze(-1).expand_as(emb[i])))
            back_EOSmet.append(torch.gather(EOSmet[i], dim=1, index=now_index))
            now_index = torch.gather(index, dim=1, index=now_index)

        back_EOSmet = 1 - torch.stack(list(reversed(back_EOSmet)))
        gen.w_o = torch.stack(list(reversed(gen.w_o))) * back_EOSmet.long()
        gen.emb = torch.stack(list(reversed(gen.emb))) * \
            back_EOSmet.float().unsqueeze(-1)
        gen.length = torch.sum(back_EOSmet, 0).detach().cpu().numpy()

        return gen
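
To make the scoring above concrete: candidates that have not yet emitted EOS are ranked by the accumulated log-probability divided by the (incremented) hypothesis length raised to `length_penalty`, while the raw unnormalized sum is what gets stored back into `score`. A small numeric sketch with made-up values:

import torch

length_penalty = 0.7
score = torch.tensor([-1.2, -2.5])         # accumulated log-prob of two live hypotheses
step_logprob = torch.tensor([-0.3, -0.1])  # log-prob of a candidate next token for each
length = torch.tensor([3.0, 3.0])          # current hypothesis lengths

ranking_score = (score + step_logprob) / (length + 1) ** length_penalty
print(ranking_score)  # used only to pick the top-k; `score` itself stays unnormalized
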
Example 13
    def _freerun(self,
                 inp,
                 nextStep,
                 wLinearLayerCallback,
                 mode='max',
                 input_callback=None,
                 no_unk=True,
                 top_k=10):
        # inp contains: batch_size, dm, embLayer, max_sent_length, [init_h]
        # input_callback(i, embedding): override this to modify the word embedding at position i
        # nextStep(embedding, flag): pass the embedding to the RNN and get gru_h; flag[i]==1 means the i-th sentence has ended
        # wLinearLayerCallback(gru_h): input gru_h and return logits over the vocabulary

        # output: gen.w_o, gen.emb, gen.length

        start_id = inp.dm.go_id if no_unk else 0

        batch_size = inp.batch_size
        dm = inp.dm

        first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)

        gen = Storage()
        gen.w_pro = []
        gen.w_o = []
        gen.emb = []
        flag = zeros(batch_size).int()
        EOSmet = []

        next_emb = first_emb
        #nextStep = self.init_forward(batch_size, inp.get("init_h", None))

        for i in range(inp.max_sent_length):
            now = next_emb
            if input_callback:
                now = input_callback(i, now)

            gru_h = nextStep(now, flag)
            #if isinstance(gru_h, tuple):
            #    gru_h = gru_h[0]

            w = wLinearLayerCallback(gru_h, inp)
            gen.w_pro.append(w.softmax(dim=-1))
            # TODO: didn't consider copynet

            if mode == "max":
                w = torch.argmax(w[:, start_id:], dim=1) + start_id
                next_emb = inp.embLayer(w)
            elif mode == "gumbel" or mode == "sample":
                w_onehot = gumbel_max(w[:, start_id:])
                w = torch.argmax(w_onehot, dim=1) + start_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) *
                    inp.embLayer.weight[start_id:], 1)
            elif mode == "samplek":
                _, index = w[:,
                             start_id:].topk(top_k,
                                             dim=-1,
                                             largest=True,
                                             sorted=True)  # batch_size, top_k

                mask = torch.zeros_like(w[:,
                                          start_id:]).scatter_(-1, index, 1.0)
                w_onehot = gumbel_max_with_mask(w[:, start_id:], mask)
                w = torch.argmax(w_onehot, dim=1) + start_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) *
                    inp.embLayer.weight[start_id:], 1)

            gen.w_o.append(w)
            gen.emb.append(next_emb)

            EOSmet.append(flag)
            flag = flag | (w == dm.eos_id).int()
            if torch.sum(flag).detach().cpu().numpy() == batch_size:
                break

        EOSmet = 1 - torch.stack(EOSmet)
        gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
        gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
        gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()

        return gen
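
`gumbel_max` and `gumbel_max_with_mask` come from the repo's helpers. Assuming they implement the standard Gumbel-max trick, an unmasked version would look roughly like the sketch below (the function name is made up for illustration):

import torch
import torch.nn.functional as F

def gumbel_max_sketch(logits, eps=1e-10):
    """Sample a one-hot vector from Categorical(softmax(logits)) via the Gumbel-max trick."""
    u = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)  # Gumbel(0, 1) samples
    idx = torch.argmax(logits + gumbel_noise, dim=-1)     # argmax of perturbed logits
    return F.one_hot(idx, num_classes=logits.shape[-1]).float()
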