Example #1
def launch(model_params, checkpoint_path, device='cuda'):
    print('model_params:\t', model_params)

    max_length = model_params['bptt']

    tokenizer = get_default_tokenizer()

    eos_token = tokenizer.token_to_id('[SEP]')
    eod_token = tokenizer.token_to_id('[DOC_SEP]')
    vocab_size = tokenizer._tokenizer.get_vocab_size()

    assert eos_token is not None, 'Invalid tokenizer files - EOS token cannot be null'

    # Model

    from models import TransformerModel, LSTMModel

    model_type = model_params.get('model_type', 'transformer')
    assert model_type in ['transformer', 'lstm']

    if model_type == 'transformer':
        model = TransformerModel(ntoken=vocab_size, **model_params)
    else:
        model = LSTMModel(ntoken=vocab_size, **model_params)

    model = model.to(device)

    if checkpoint_path and path.exists(checkpoint_path):
        print(f'Loading checkpoint from {checkpoint_path}')
        checkpoint_state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint_state)

    @torch.no_grad()
    def _generate(input_ids=None,
                  max_length=max_length,
                  do_sample=True,
                  num_beams=5,
                  temperature=1.3,
                  top_k=50,
                  top_p=1.0,
                  repetition_penalty=1.2,
                  eos_token_ids=[eos_token, eod_token],
                  length_penalty=1.0,
                  num_return_sequences=1,
                  vocab_size=vocab_size):
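        # A beam-search / sampling generation loop that appears to follow the
        # generate() helpers of older versions of the transformers library:
        # the prompt is expanded to num_beams hypotheses, extended one token
        # at a time, and finished hypotheses are collected in generated_hyps.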
        pad_token_id = 0
        model.eval()

        batch_size = 1
        cur_len = input_ids.shape[1]

        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams,
                                                  cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams,
                                                cur_len)

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams,
                           max_length,
                           length_penalty,
                           early_stopping=False) for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams),
                                  dtype=torch.float,
                                  device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = None

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:

            outputs = model(input_ids.t())
            outputs = outputs.permute(1, 0, 2)
            # print(input_ids)
            # print(torch.argmax(outputs))

            scores = outputs[:, -1, :]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                for i in range(batch_size * num_beams):
                    for previous_token in set(input_ids[i].tolist()):
                        # if score < 0, the repetition penalty has to be multiplied to reduce the previous token probability
                        if scores[i, previous_token] < 0:
                            scores[i, previous_token] *= repetition_penalty
                        else:
                            scores[i, previous_token] /= repetition_penalty

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                scores = top_k_top_p_filtering(
                    scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)

                try:
                    next_words = torch.multinomial(
                        torch.softmax(scores, dim=-1),
                        num_samples=2,
                        replacement=True)  # (batch_size * num_beams, 2)
                except RuntimeError:
                    # multinomial fails if filtering left no token with non-zero probability
                    print((torch.softmax(scores, dim=-1) > 0).sum())
                    raise ValueError(
                        'No tokens left to sample from after top-k/top-p filtering')
                # Compute next scores
                _scores = F.log_softmax(
                    scores, dim=-1)  # (batch_size * num_beams, vocab_size)
                _scores = torch.gather(
                    _scores, -1, next_words)  # (batch_size * num_beams, 2)
                next_scores = _scores + beam_scores[:, None].expand_as(
                    _scores)  # (batch_size * num_beams, 2)
                # Match shape of greedy beam search
                next_words = next_words.view(
                    batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
                next_scores = next_scores.view(
                    batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
            else:
                # do greedy beam search
                scores = F.log_softmax(
                    scores, dim=-1)  # (batch_size * num_beams, vocab_size)
                assert scores.size() == (batch_size * num_beams, vocab_size)
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
                _scores = scores + beam_scores[:, None].expand_as(
                    scores)  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beams together (we are keeping the top hypotheses across beams)
                _scores = _scores.view(
                    batch_size, num_beams *
                    vocab_size)  # (batch_size, num_beams * vocab_size)
                next_scores, next_words = torch.topk(_scores,
                                                     2 * num_beams,
                                                     dim=1,
                                                     largest=True,
                                                     sorted=True)

            assert next_scores.size() == next_words.size() == (batch_size,
                                                               2 * num_beams)

            # next batch beam content
            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for batch_ex in range(batch_size):

                # if we are done with this sentence
                done[batch_ex] = done[batch_ex] or generated_hyps[
                    batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                    next_batch_beam.extend([(0, pad_token_id, 0)] *
                                           num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, score in zip(next_words[batch_ex],
                                      next_scores[batch_ex]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size

                    # end of sentence, or next word
                    if word_id.item(
                    ) in eos_token_ids or cur_len + 1 == max_length:
                        generated_hyps[batch_ex].add(
                            input_ids[batch_ex * num_beams +
                                      beam_id, :cur_len].clone(), score.item())
                    else:
                        next_sent_beam.append(
                            (score, word_id, batch_ex * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # update next beam content
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_length
                                               else num_beams)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)
                                      ] * num_beams  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_ex + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)

            # re-order internal states
            if past:
                reordered_past = []
                for layer_past in past:
                    # get the correct batch idx from layer past batch dim
                    # batch dim of `past` and `mems` is at 2nd position
                    reordered_layer_past = [
                        layer_past[:, i].unsqueeze(1).clone().detach()
                        for i in beam_idx
                    ]
                    reordered_layer_past = torch.cat(reordered_layer_past,
                                                     dim=1)
                    # check that shape matches
                    assert reordered_layer_past.shape == layer_past.shape
                    reordered_past.append(reordered_layer_past)
                past = tuple(reordered_past)

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(batch_size):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = input_ids.new(batch_size)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            if len(hypotheses.hyp) == 0:
                continue

            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = input_ids.new(batch_size,
                                tgt_len.max().item()).fill_(pad_token_id)
        for i, hypo in enumerate(best):
            decoded[i, :tgt_len[i] - 1] = hypo
            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

        return decoded

    model_input = LEADING_TEXT

    while True:
        user_prompt = input(' >>> ')

        if user_prompt == 'exit':
            exit()

        else:
            num_return_sequences = 1

            model_input += ' [P0] ' + user_prompt + ' [SEP] [P1] '

            input_ids = tokenizer.encode(model_input).ids
            input_ids = torch.LongTensor(input_ids).unsqueeze(0)
            input_ids = input_ids.to(device)

            output = _generate(input_ids=input_ids,
                               max_length=min(max_length,
                                              input_ids.size(1) + 40))

            if num_return_sequences != 1:
                output = output.view(input_ids.size(0), num_return_sequences, -1)

            response = tokenizer.decode(output[0].cpu().tolist(),
                                        skip_special_tokens=False)

            eod_token = '[DOC_SEP]'

            if eod_token in response:
                response = response[response.index(eod_token):]

            start_token = '[P1]'
            sep_token = '[SEP]'

            if start_token in response:
                start_idx = response.index(start_token) + len(start_token) + 1
                response = response[start_idx:]

            if sep_token in response:
                sep_idx = response.index(sep_token)
                response = response[:sep_idx]

            model_input += response + f' {sep_token} '

            print('Bot: ' + response)
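The function above relies on helpers defined elsewhere in the project (get_default_tokenizer, BeamHypotheses, LEADING_TEXT, top_k_top_p_filtering) and on module-level imports such as torch, torch.nn.functional as F, and os.path. As a reference only, the sketch below shows top_k_top_p_filtering in the form widely used in the transformers library; the project's own version may differ.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('inf'),
                          min_tokens_to_keep=1):
    # Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    # logits: (batch_size, vocab_size)
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Remove all tokens with a logit below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            sorted_indices_to_remove[..., :min_tokens_to_keep - 1] = 0
        # Shift the mask right so the first token above the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices,
                                                             sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits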
Example #2
File: inference.py  Project: tbmoon/VUC
def main(args):

    since = time.time()
    output_dir = os.path.join(os.getcwd(), 'outputs')
    os.makedirs(output_dir, exist_ok=True)

    data_loaders = get_dataloader(
        input_dir=args.input_dir,
        which_challenge='3rd_challenge',
        phases=['test'],
        max_frame_length=args.max_frame_length,
        max_vid_label_length=args.max_vid_label_length,
        max_seg_label_length=args.max_seg_label_length,
        rgb_feature_size=args.rgb_feature_size,
        audio_feature_size=args.audio_feature_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    model = TransformerModel(
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        rgb_feature_size=args.rgb_feature_size,
        audio_feature_size=args.audio_feature_size,
        d_rgb=args.d_rgb,
        d_audio=args.d_audio,
        d_model=args.d_model,
        d_ff=args.d_ff,
        d_proj=args.d_proj,
        n_attns=args.n_attns,
        num_classes=args.num_classes,
        dropout=args.dropout)
    model = model.to(device)

    checkpoint = torch.load(os.path.join(os.getcwd(), 'models/model-epoch-04.ckpt'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    df_outputs = {i: pd.DataFrame(columns=['vid_id', 'vid_label_pred', 'vid_prob', 'seg_label_pred', 'seg_prob']) \
                      for i in range(1, args.num_classes+1)}
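    # One DataFrame per class id (1..num_classes); each row will hold a video id,
    # its predicted video-level label/probability, and the top segment predictions.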

    for idx, (vid_ids, frame_lengths, frame_rgbs, frame_audios, vid_labels, seg_labels, seg_times) \
        in enumerate(data_loaders['test']):           

        if idx % 10 == 0:
            print('idx:', idx)

        # frame_rgbs: [batch_size, frame_length, rgb_feature_size]
        # frame_audios: [batch_size, frame_length, audio_feature_size]
        frame_rgbs = frame_rgbs.to(device)
        frame_audios = frame_audios.to(device)
        batch_size = frame_audios.size(0)

        # vid_probs: [batch_size, num_classes]
        # attn_idc: [batch_size, num_classes]
        # scores: [batch_size, max_seg_length, n_attns]
        # attn_weights: [batch_size, max_seg_length, n_attns]
        vid_probs, attn_idc, scores, attn_weights, conv_loss = model(frame_rgbs, frame_audios, device)

        # vid_probs: [batch_size, vid_pred_length]
        # vid_label_preds: [batch_size, vid_pred_length]
        vid_probs, vid_label_preds = torch.topk(vid_probs, args.vid_pred_length)
        vid_label_preds = vid_label_preds + 1

        # attn_idc: [batch_size, num_classes+1]
        zeros = torch.zeros(batch_size, 1).long().to(device)
        attn_idc = torch.cat((zeros, attn_idc), dim=1)

        # selected_attn_idc: [batch_size, vid_pred_length]
        selected_attn_idc = torch.gather(attn_idc, 1, vid_label_preds)

        # attn_weights: [batch_size, n_attns, max_seg_length]
        attn_weights = attn_weights.transpose(1, 2)

        # selected_attn_weights: [batch_size, vid_pred_length, max_seg_length]
        selected_attn_weights = batched_index_select(attn_weights, 1, selected_attn_idc)

        # seg_probs: [batch_size, vid_pred_length, seg_pred_length] 
        # seg_label_preds: [batch_size, vid_pred_length, seg_pred_length] 
        seg_probs, seg_label_preds = torch.topk(selected_attn_weights, args.seg_pred_length)
        seg_label_preds = seg_label_preds + 1

        # seg_prob_min, seg_prob_max: [batch_size, vid_pred_length]
        seg_prob_min, _ = seg_probs.min(dim=2)
        seg_prob_max, _ = seg_probs.max(dim=2)

        # seg_prob_min, seg_prob_max: [batch_size, vid_pred_length, seg_pred_length]
        seg_prob_min = seg_prob_min.unsqueeze(2).expand(batch_size, args.vid_pred_length, args.seg_pred_length)
        seg_prob_max = seg_prob_max.unsqueeze(2).expand(batch_size, args.vid_pred_length, args.seg_pred_length)

        # seg_probs: [batch_size, vid_pred_length, seg_pred_length]
        seg_probs = (seg_probs - seg_prob_min) / (seg_prob_max - seg_prob_min + 1e-6)

        # To save predictions, converted to numpy data.
        vid_probs = vid_probs.cpu().detach().numpy()
        vid_label_preds = vid_label_preds.cpu().numpy()
        seg_probs = seg_probs.cpu().detach().numpy()
        seg_label_preds = seg_label_preds.cpu().numpy()

        for i in range(batch_size):
            for j in range(args.vid_pred_length):
                vid_label_pred = vid_label_preds[i][j]
                df_outputs[vid_label_pred] = df_outputs[vid_label_pred].append(
                    {'vid_id': vid_ids[i],
                     'vid_label_pred': vid_label_pred,
                     'vid_prob': vid_probs[i][j],
                     'seg_label_pred': list(seg_label_preds[i][j]),
                     'seg_prob': list(seg_probs[i][j])}, ignore_index=True)

    for i in range(1, args.num_classes+1):
        df_outputs[i].to_csv(os.path.join(output_dir, '%04d.csv' % i), index=False)

    time_elapsed = time.time() - since
    print('=> Running time in an epoch: {:.0f}h {:.0f}m {:.0f}s'
          .format(time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
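batched_index_select is also not defined in this file. Assuming it gathers one attention column per predicted video label, a common implementation that matches the shapes used above ([batch_size, n_attns, max_seg_length] indexed by [batch_size, vid_pred_length]) looks roughly like the sketch below; the project's own helper may differ.

import torch

def batched_index_select(input, dim, index):
    # index: [batch_size, k]; selects k slices along `dim` for each batch element,
    # e.g. input [B, n_attns, L] with dim=1 and index [B, k] -> output [B, k, L].
    views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, input.dim())]
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(input, dim, index)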