Example 1
import torch
import torch.nn.functional as F


def sample_sequence(caption, history, tokenizer, model, args, current_output=None, video=None):
    """Sample a response token by token, feeding each choice back in."""
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []
    for i in range(args.max_length):
        # Rebuild the full input (caption + history + tokens so far) each step.
        instance, sequence = build_input_from_segments(caption, history, current_output, tokenizer, with_eos=False, drop_caption=False)

        input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0)
        input_embs = model.transformer.wte(input_ids)
        if video is not None:
            # Prepend projected video features and pad the token types with
            # the video segment id (SPECIAL_TOKENS[-2]).
            input_embs = torch.cat([model.video_ff(video), input_embs], dim=1)
            video_type_ids = torch.ones((1, video.size(1)), dtype=torch.long, device=args.device) \
                * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2])
            token_type_ids = torch.cat([video_type_ids, token_type_ids], dim=1)

        logits = model(input_embs, token_type_ids=token_type_ids)
        if args.model == "gpt2":
            logits = logits[0]  # GPT-2 returns a tuple; keep the LM logits
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        # Before min_length, resample until a non-special token comes up.
        if i < args.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output
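
sample_sequence relies on top_filtering, which is defined elsewhere in its repository. A minimal sketch compatible with the call above, assuming 1-D logits and the usual top-k/nucleus (top-p) semantics; the repository's own version may differ in detail:

import torch
import torch.nn.functional as F


def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Mask a 1-D logits tensor with top-k and/or nucleus (top-p) filtering.

    A minimal sketch matching the call site above, not the original code.
    """
    if top_k > 0:
        # Keep only the top_k highest logits; mask everything below the
        # k-th largest value.
        kth_value = torch.topk(logits, top_k)[0][-1]
        logits[logits < kth_value] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability
        # exceeds top_p; mask out the rest.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_mask[1:] = sorted_mask[:-1].clone()
        sorted_mask[0] = False
        logits[sorted_indices[sorted_mask]] = filter_value
    return logits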
Example 2
import copy

import numpy as np
import torch
import torch.nn.functional as F


def beam_search(caption, history, tokenizer, model, args, current_output=None, video=None):
    """Beam-search decoding; returns the best completed hypothesis."""
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []
    # Each hypothesis is (generated tokens, cumulative log-prob, full state).
    hyplist = [([], 0., current_output)]
    best_state = None
    comp_hyplist = []

    for i in range(args.max_length):
        new_hyplist = []
        argmin = 0  # index of the worst hypothesis currently on the beam
        for out, lp, st in hyplist:
            instance, sequence = build_input_from_segments(caption, history, st, tokenizer, with_eos=False, drop_caption=False)

            input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
            token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0)
            input_embs = model.transformer.wte(input_ids)
            if video is not None:
                input_embs = torch.cat([model.video_ff(video), input_embs], dim=1)
                video_type_ids = torch.ones((1, video.size(1)), dtype=torch.long, device=args.device) \
                    * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2])
                token_type_ids = torch.cat([video_type_ids, token_type_ids], dim=1)

            logits = model(input_embs, token_type_ids=token_type_ids)
            if args.model == "gpt2":
                logits = logits[0]
            logp = F.log_softmax(logits, dim=-1)[:, -1, :]
            lp_vec = logp.cpu().data.numpy() + lp
            lp_vec = np.squeeze(lp_vec)
            if i >= args.min_length:
                # Close the hypothesis with <eos>, applying a length penalty.
                new_lp = lp_vec[tokenizer.eos_token_id] + args.penalty * (len(out) + 1)
                comp_hyplist.append((out, new_lp))
                if best_state is None or best_state < new_lp:
                    best_state = new_lp
            # Expand the hypothesis with the highest-scoring tokens.
            for o in np.argsort(lp_vec)[::-1]:
                if o == tokenizer.unk_token_id or o == tokenizer.eos_token_id:
                    continue
                new_lp = lp_vec[o]
                if len(new_hyplist) == args.beam_size:
                    if new_hyplist[argmin][1] < new_lp:
                        # Evict the worst hypothesis on the full beam.
                        new_st = copy.deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist[argmin] = (out + [o], new_lp, new_st)
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    else:
                        # Candidates are sorted, so the rest score lower; stop.
                        break
                else:
                    new_st = copy.deepcopy(st)
                    new_st.append(int(o))
                    new_hyplist.append((out + [o], new_lp, new_st))
                    if len(new_hyplist) == args.beam_size:
                        argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
        hyplist = new_hyplist
    if len(comp_hyplist) > 0:
        # Return the single best completed hypothesis.
        maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
        return maxhyps
    else:
        return [([], 0)]
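
Note the return format: unlike sample_sequence, beam_search yields a one-element list of (token ids, score) pairs. A hypothetical call site, where the Namespace fields mirror what the function reads and caption, history, tokenizer, model and video are assumed to come from the surrounding evaluation script:

from argparse import Namespace

# Hypothetical settings; all values here are illustrative.
args = Namespace(device="cuda", model="gpt2", max_length=20, min_length=1,
                 beam_size=5, penalty=0.3)
hyps = beam_search(caption, history, tokenizer, model, args, video=video)
best_ids, best_score = hyps[0]
print(tokenizer.decode(best_ids, skip_special_tokens=True))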
Example 3
import copy

import numpy as np
import torch


def generate_responce(model, dialog_list, id2feature, tokenizer):
    """Decode a response for the first eligible dialog in dialog_list."""
    bos, eos, speaker1, speaker2, img, tag = tokenizer.convert_tokens_to_ids(
        SPECIAL_TOKENS[:-1])
    with torch.no_grad():
        for dialog in dialog_list:
            if len(dialog['history']) < 2:
                continue
            history = copy.deepcopy(dialog['history'][:-1])
            answer = copy.deepcopy(dialog['history'][-1])
            # Swap image ids for their precomputed feature vectors.
            for i in range(len(history)):
                if 'img_id' in history[i]:
                    history[i]['img_id'] = id2feature[history[i]['img_id']]
            if 'img_id' in answer:
                answer['img_id'] = id2feature[answer['img_id']]
            history_txt, history_img, token_type_ids, _ = build_input_from_segments(
                history, answer, tokenizer)
            if history_txt[-1] == tag:
                history_txt[-1] = img
            # Append a speaker token to prompt the reply; the two extra type
            # ids cover that token and the <bos> added just below.
            if answer['speaker_id'] == '[speaker1]':
                history_txt += [speaker2]
                token_type_ids += [speaker2] * 2
            else:
                history_txt += [speaker1]
                token_type_ids += [speaker1] * 2
            history_txt += [bos]

            history_txt = torch.LongTensor(history_txt)
            history_img = torch.from_numpy(np.array(history_img)).float()
            token_type_ids = torch.Tensor(token_type_ids).long()

            # `device` is assumed to be defined at module level.
            history_txt, history_img, token_type_ids = history_txt.to(
                device), history_img.to(device), token_type_ids.to(device)
            history_txt_embs = model.transformer.wte(history_txt)
            history_img_embs = model.image_off(history_img).squeeze(1)

            # Interleave text and image embeddings into one input sequence.
            input_embs, img_features = input_construct(history_txt_embs,
                                                       history_img_embs,
                                                       token_type_ids,
                                                       tokenizer)
            input_embs, img_features = input_embs.to(device), img_features.to(
                device)

            # res = sample_sequence(input_embs, token_type_ids, model, tokenizer)
            res = greedy_decode(input_embs, token_type_ids, model, tokenizer)
            # res = tokenizer.decode(res, skip_special_tokens=True)
            break  # only the first dialog is decoded
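
Two things differ from the other examples: the greedy_decode called here takes (input_embs, token_type_ids, model, tokenizer), not the signature shown in Example 4, and SPECIAL_TOKENS must hold an image-dialogue inventory whose first six entries unpack to (bos, eos, speaker1, speaker2, img, tag). A plausible definition, labeled as an assumption since the repository's actual strings may differ:

# Assumed image-dialogue token list: the function unpacks six ids from
# SPECIAL_TOKENS[:-1], so a pad token is expected to close the list.
SPECIAL_TOKENS = ["[BOS]", "[EOS]", "[speaker1]", "[speaker2]",
                  "[IMG]", "[TAG]", "[PAD]"]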
Example 4
import numpy as np
import torch


def greedy_decode(caption,
                  history,
                  tokenizer,
                  model,
                  args,
                  current_output=None,
                  video=None):
    """Greedy decoding: always take the highest-probability next token."""
    # current_output is accepted for signature parity with sample_sequence
    # but is unused here; generation always starts from an empty list.
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    ys = []

    for i in range(args.max_length):
        instance, sequence = build_input_from_segments(caption,
                                                       history,
                                                       ys,
                                                       tokenizer,
                                                       with_eos=False,
                                                       drop_caption=False)

        input_ids = torch.tensor(instance["input_ids"],
                                 device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"],
                                      device=args.device).unsqueeze(0)
        input_embs = model.transformer.wte(input_ids)
        if video is not None:
            # Prepend projected video features and pad the token types with
            # the video segment id (SPECIAL_TOKENS[-2]).
            input_embs = torch.cat([model.video_ff(video), input_embs], dim=1)
            video_type_ids = torch.ones((1, video.size(1)),
                                        dtype=torch.long,
                                        device=args.device) \
                * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2])
            token_type_ids = torch.cat([video_type_ids, token_type_ids], dim=1)

        logits = model(input_embs, token_type_ids=token_type_ids)
        if args.model == "gpt2":
            logits = logits[0][0]  # unpack the output tuple, drop the batch dim
        logits = logits.cpu().data.numpy()
        next_word = np.argsort(logits[-1])[-1]  # argmax over the vocabulary
        if next_word == special_tokens_ids[1]:  # stop at the <eos> id
            break
        ys.append(int(next_word))
    return ys
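
The video-dialogue examples (1, 2 and 4) index the token list from both ends: special_tokens_ids[1] serves as the stop token and SPECIAL_TOKENS[-2] as the video segment id. A representative inventory and its registration, following the common transfer-learning-conv-ai pattern; the exact strings are assumptions, not the repository's confirmed values:

# Assumed inventory: SPECIAL_TOKENS[1] is <eos> (the stop token checked in
# greedy_decode) and SPECIAL_TOKENS[-2] is the <video> segment id.
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>",
                  "<cap>", "<video>", "<pad>"]
SPECIAL_TOKENS_DICT = {
    "bos_token": "<bos>", "eos_token": "<eos>", "pad_token": "<pad>",
    "additional_special_tokens": ["<speaker1>", "<speaker2>", "<cap>", "<video>"],
}

# Register the tokens with the tokenizer and grow the embedding matrix to match.
tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
model.resize_token_embeddings(len(tokenizer))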