Example #1
0
def intention_test(context_text_encoder: TextEncoder,
                   context_image_encoder: ImageEncoder,
                   context_encoder: ContextEncoder, intention: Intention,
                   test_dataset: Dataset):
    """Intention test.

    Runs the intention classifier over the whole test dataset (in loader
    order, no shuffling) and prints, for every batch, the predicted and true
    utterance types, the number of correct predictions, the batch accuracy
    and the running overall accuracy.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention.
        test_dataset (Dataset): Test dataset.

    Returns:
        float: Overall accuracy (correct predictions / test samples), or 0.0
        when the loader yields no samples.

    Note:
        The modules are switched to eval mode and left in eval mode.
    """

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=IntentionTestConfig.batch_size,
        shuffle=False,
        num_workers=IntentionTestConfig.num_data_loader_workers)

    # Counted per sample rather than averaged per batch, so a smaller final
    # batch does not skew the overall accuracy.
    total_correct = 0
    total_samples = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    intention.eval()

    with torch.no_grad():
        for test_data in test_data_loader:

            texts, text_lengths, images, utter_types = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder, context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            intentions = torch.argmax(intent_prob, dim=1)
            eqs = torch.eq(intentions, utter_types)
            num_correct = torch.sum(eqs).item()
            num_samples = eqs.size(0)
            accuracy = num_correct * 1.0 / num_samples
            total_correct += num_correct
            total_samples += num_samples

            # Print.
            print('pred:', intentions)
            print('true:', utter_types)
            print('# correct:', num_correct)
            print('accuracy:', accuracy)
            print('total accuracy:', total_correct / total_samples)

    return total_correct / total_samples if total_samples else 0.0
Example #2
0
def intention_valid(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder, intention: Intention,
                    valid_dataset: Dataset):
    """Intention valid.

    Evaluates the intention classifier on at most
    ``IntentionValidConfig.num_batches`` shuffled validation batches, then
    switches every module back to train mode (also when evaluation raises).

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``: the ``nll_loss`` values and the
        per-batch accuracies, each averaged over the evaluated batches.

    Raises:
        ZeroDivisionError: If no batch was evaluated.
    """
    from itertools import islice

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=IntentionValidConfig.batch_size,
        shuffle=True,
        num_workers=IntentionValidConfig.num_data_loader_workers)

    modules = (context_text_encoder, context_image_encoder, context_encoder,
               intention)

    sum_loss = 0
    sum_accuracy = 0
    num_batches = 0

    # Switch to eval mode.
    for module in modules:
        module.eval()

    try:
        with torch.no_grad():
            # Only valid `IntentionValidConfig.num_batches` batches. `islice`
            # stops without requesting one more batch from the loader, whereas
            # `enumerate` + `break` loads an extra batch just to discard it.
            for valid_data in islice(valid_data_loader,
                                     IntentionValidConfig.num_batches):
                num_batches += 1

                texts, text_lengths, images, utter_types = valid_data
                # Sizes:
                # texts: (batch_size, dialog_context_size + 1,
                #         dialog_text_max_len)
                # text_lengths: (batch_size, dialog_context_size + 1)
                # images: (batch_size, dialog_context_size + 1,
                #          pos_images_max_num, 3, image_size, image_size)
                # utter_types: (batch_size, )

                # To device.
                texts = texts.to(GlobalConfig.device)
                text_lengths = text_lengths.to(GlobalConfig.device)
                images = images.to(GlobalConfig.device)
                utter_types = utter_types.to(GlobalConfig.device)

                texts.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size, dialog_text_max_len)

                text_lengths.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size)

                images.transpose_(0, 1)
                images.transpose_(1, 2)
                # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
                #  image_size, image_size)

                # Encode context.
                context, _ = encode_context(context_text_encoder,
                                            context_image_encoder,
                                            context_encoder,
                                            texts, text_lengths, images)
                # (batch_size, context_vector_size)

                intent_prob = intention(context)
                # (batch_size, utterance_type_size)

                loss = nll_loss(intent_prob, utter_types)
                sum_loss += loss

                eqs = torch.eq(torch.argmax(intent_prob, dim=1), utter_types)
                accuracy = torch.sum(eqs).item() * 1.0 / eqs.size(0)
                sum_accuracy += accuracy
    finally:
        # Switch back to train mode, even if evaluation raised.
        for module in modules:
            module.train()

    return sum_loss / num_batches, sum_accuracy / num_batches
Example #3
0
def knowledge_celebrity_test(context_text_encoder: TextEncoder,
                             context_image_encoder: ImageEncoder,
                             context_encoder: ContextEncoder,
                             to_hidden: ToHidden, celebrity_memory: Memory,
                             text_decoder: TextDecoder, test_dataset: Dataset,
                             celebrity_scores, text_length: int,
                             vocab: Dict[str, int]):
    """Knowledge celebrity test.

    Runs ``text_eval`` on every test batch, using ``celebrity_memory`` over
    ``celebrity_scores`` as the knowledge encoder, and hands it an open
    handle to ``knowledge_celebrity.out`` (created/truncated in the current
    working directory) as ``output_file``.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        celebrity_scores: Celebrity scores.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary (word -> id).

    Note:
        The modules are switched to eval mode and left in eval mode.
    """

    # Inverse vocabulary: id -> word.
    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeCelebrityTestConfig.batch_size,
        num_workers=KnowledgeCelebrityTestConfig.num_data_loader_workers)

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    celebrity_memory.eval()
    text_decoder.eval()

    # The knowledge entry is the same for every batch, so the knowledge
    # encoder callable is built once.
    encode_knowledge_func = partial(celebrity_memory, celebrity_scores)

    # `with` closes the output file even if decoding raises.
    with open('knowledge_celebrity.out', 'w') as output_file, torch.no_grad():
        for test_data in test_data_loader:
            texts, text_lengths, images, _ = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types (unused here): (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)
            # (batch_size, context_vector_size)

            # texts[-1] is the last utterance of each dialog.
            text_eval(to_hidden,
                      text_decoder,
                      text_length,
                      id2word,
                      context,
                      texts[-1],
                      hiddens,
                      encode_knowledge_func,
                      output_file=output_file)
def knowledge_celebrity_valid(
        context_text_encoder: TextEncoder, context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder, to_hidden: ToHidden,
        celebrity_memory: Memory, text_decoder: TextDecoder,
        valid_dataset: Dataset, celebrity_scores, text_length: int):
    """Knowledge celebrity valid.

    Evaluates ``text_loss`` on at most
    ``KnowledgeCelebrityValidConfig.num_batches`` shuffled validation
    batches, using ``celebrity_memory`` over ``celebrity_scores`` as the
    knowledge encoder, then switches every module back to train mode (also
    when evaluation raises).

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        valid_dataset (Dataset): Valid dataset.
        celebrity_scores: Celebrity scores.
        text_length (int): Text length.

    Returns:
        The per-batch ``text_loss / text_length`` averaged over the evaluated
        batches.

    Raises:
        ZeroDivisionError: If no batch was evaluated.
    """
    from itertools import islice

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=KnowledgeCelebrityValidConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeCelebrityValidConfig.num_data_loader_workers)

    modules = (context_text_encoder, context_image_encoder, context_encoder,
               to_hidden, celebrity_memory, text_decoder)

    sum_loss = 0
    num_batches = 0

    # Switch to eval mode.
    for module in modules:
        module.eval()

    # The knowledge entry is the same for every batch, so the knowledge
    # encoder callable is built once.
    encode_knowledge_func = partial(celebrity_memory, celebrity_scores)

    try:
        with torch.no_grad():
            # Only valid `KnowledgeCelebrityValidConfig.num_batches` batches.
            # `islice` stops without requesting one more batch from the
            # loader, whereas `enumerate` + `break` loads an extra batch just
            # to discard it.
            for valid_data in islice(valid_data_loader,
                                     KnowledgeCelebrityValidConfig.num_batches):
                num_batches += 1

                texts, text_lengths, images, _ = valid_data
                # Sizes:
                # texts: (batch_size, dialog_context_size + 1,
                #         dialog_text_max_len)
                # text_lengths: (batch_size, dialog_context_size + 1)
                # images: (batch_size, dialog_context_size + 1,
                #          pos_images_max_num, 3, image_size, image_size)
                # utter_types (unused here): (batch_size, )

                # To device.
                texts = texts.to(GlobalConfig.device)
                text_lengths = text_lengths.to(GlobalConfig.device)
                images = images.to(GlobalConfig.device)

                texts.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size, dialog_text_max_len)

                text_lengths.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size)

                images.transpose_(0, 1)
                images.transpose_(1, 2)
                # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
                #  image_size, image_size)

                # Encode context.
                context, hiddens = encode_context(context_text_encoder,
                                                  context_image_encoder,
                                                  context_encoder, texts,
                                                  text_lengths, images)
                # (batch_size, context_vector_size)

                # The second return value of text_loss is not needed here.
                loss, _ = text_loss(to_hidden, text_decoder, text_length,
                                    context, texts[-1], text_lengths[-1],
                                    hiddens, encode_knowledge_func)
                sum_loss += loss / text_length
    finally:
        # Switch back to train mode, even if evaluation raised.
        for module in modules:
            module.train()

    return sum_loss / num_batches
Example #5
0
def recommend_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        similarity: Similarity,
        valid_dataset: Dataset):
    """Recommend valid.

    Evaluates ``recommend_loss`` on at most
    ``RecommendValidConfig.num_batches`` shuffled validation batches, prints
    recall@1 .. recall@``DatasetConfig.neg_images_max_num`` computed from the
    rank counts returned by ``recommend_eval``, then switches the modules to
    train mode (also when evaluation raises).

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        similarity (Similarity): Similarity module passed to
            ``recommend_loss`` and ``recommend_eval``.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        The ``recommend_loss`` values averaged over the evaluated batches.

    Raises:
        ZeroDivisionError: If no batch was evaluated.
    """
    from itertools import islice

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=RecommendValidConfig.batch_size,
        shuffle=True,
        num_workers=RecommendValidConfig.num_data_loader_workers
    )

    sum_loss = 0
    num_batches = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    # similarity.eval() is intentionally skipped: there might be a bug in the
    # implementation of resnet. similarity is still set to train mode below.

    # num_ranks[r] accumulates the counts returned by recommend_eval, so the
    # prefix sum up to index k - 1 gives the hits for recall@k.
    num_ranks = torch.zeros(DatasetConfig.neg_images_max_num + 1,
                            dtype=torch.long)
    num_ranks = num_ranks.to(GlobalConfig.device)
    total_samples = 0

    try:
        with torch.no_grad():
            # Only valid `RecommendValidConfig.num_batches` batches. `islice`
            # stops without requesting one more batch from the loader,
            # whereas `enumerate` + `break` loads an extra batch just to
            # discard it.
            for valid_data in islice(valid_data_loader,
                                     RecommendValidConfig.num_batches):
                num_batches += 1

                context_dialog, pos_products, neg_products = valid_data

                texts, text_lengths, images, _ = context_dialog
                # Sizes:
                # texts: (batch_size, dialog_context_size + 1,
                #         dialog_text_max_len)
                # text_lengths: (batch_size, dialog_context_size + 1)
                # images: (batch_size, dialog_context_size + 1,
                #          pos_images_max_num, 3, image_size, image_size)
                # utter_types (unused here): (batch_size, )

                batch_size = texts.size(0)

                # To device.
                texts = texts.to(GlobalConfig.device)
                text_lengths = text_lengths.to(GlobalConfig.device)
                images = images.to(GlobalConfig.device)

                texts.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size, dialog_text_max_len)

                text_lengths.transpose_(0, 1)
                # (dialog_context_size + 1, batch_size)

                images.transpose_(0, 1)
                images.transpose_(1, 2)
                # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
                #  image_size, image_size)

                # Encode context.
                context, _ = encode_context(
                    context_text_encoder,
                    context_image_encoder,
                    context_encoder,
                    texts,
                    text_lengths,
                    images
                )
                # (batch_size, context_vector_size)

                loss = recommend_loss(similarity, batch_size, context,
                                      pos_products, neg_products)
                sum_loss += loss

                num_rank = recommend_eval(
                    similarity,
                    batch_size,
                    context,
                    pos_products,
                    neg_products
                )

                total_samples += batch_size
                num_ranks += num_rank

        for i in range(DatasetConfig.neg_images_max_num):
            print('total recall@{} = {}'.format(
                i + 1,
                torch.sum(num_ranks[:i + 1]).item() / total_samples))
    finally:
        # Switch to train mode, even if evaluation raised.
        context_text_encoder.train()
        context_image_encoder.train()
        context_encoder.train()
        similarity.train()

    return sum_loss / num_batches
def knowledge_test(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        to_hidden: ToHidden,
        attribute_kv_memory: KVMemory,
        text_decoder: TextDecoder,
        test_dataset: Dataset,
        text_length: int,
        vocab: Dict[str, int]):
    """Knowledge attribute test.

    Runs ``text_eval`` on every test batch, using ``attribute_kv_memory``
    over each batch's product (keys, values, pair_length) as the knowledge
    encoder, and hands it an open handle to ``knowledge_attribute.out``
    (created/truncated in the current working directory) as
    ``output_file``.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        attribute_kv_memory (KVMemory): Attribute Key-Value Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary (word -> id).

    Note:
        The modules are switched to eval mode and left in eval mode.
    """

    # Inverse vocabulary: id -> word.
    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeAttributeTestConfig.batch_size,
        num_workers=KnowledgeAttributeTestConfig.num_data_loader_workers
    )

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    attribute_kv_memory.eval()
    text_decoder.eval()

    # `with` closes the output file even if decoding raises.
    with open('knowledge_attribute.out', 'w') as output_file, torch.no_grad():
        for test_data in test_data_loader:
            dialog, products = test_data

            # Product attribute key-value pairs for the memory.
            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, _ = dialog

            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types (unused here): (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(
                context_text_encoder,
                context_image_encoder,
                context_encoder,
                texts,
                text_lengths,
                images
            )
            # (batch_size, context_vector_size)

            # The memory depends on this batch's products, so it is bound
            # per batch.
            encode_knowledge_func = partial(attribute_kv_memory,
                                            keys, values, pair_length)

            text_eval(to_hidden, text_decoder, text_length, id2word,
                      context, texts[-1], hiddens,
                      encode_knowledge_func, output_file=output_file)
Example #7
0
    # NOTE(review): excerpt of a larger script -- the enclosing function and
    # the other numbered steps (1-6, 8) are not part of this snippet. `vocab`
    # must already exist and expose a sized `word2index` mapping.
    """ 7) Create Model """
    VOCAB_SIZE = len(vocab.word2index)
    # Layer sizes; presumably these have to match the checkpoints loaded in
    # step 9 -- confirm before changing them.
    IMAGE_EMB_DIM = 256
    WORD_EMB_DIM = 256
    HIDDEN_DIM = 512
    # Word embedding table, kept as a separate module with its own
    # checkpoint file (EMBEDDING_WEIGHT_FILE below).
    word_embedding = torch.nn.Embedding(
        num_embeddings=VOCAB_SIZE,
        embedding_dim=WORD_EMB_DIM,
    )
    image_encoder = ImageEncoder(out_dim=IMAGE_EMB_DIM)
    # Caption decoder whose output classes are the vocabulary entries.
    image_decoder = CaptionRNN(num_classes=VOCAB_SIZE,
                               word_emb_dim=WORD_EMB_DIM,
                               img_emb_dim=IMAGE_EMB_DIM,
                               hidden_dim=HIDDEN_DIM)
    # All three modules are put in eval mode (inference setup).
    word_embedding.eval()
    image_encoder.eval()
    image_decoder.eval()
    """ 9) Load Weights """
    # Set to False to keep the freshly initialised weights.
    LOAD_WEIGHTS = True
    # Epoch-3 checkpoints; paths are relative to the working directory.
    EMBEDDING_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-embedding-epoch-3.pt'
    ENCODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-encoder-epoch-3.pt'
    DECODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-decoder-epoch-3.pt'
    if LOAD_WEIGHTS:
        print("Loading pretrained weights...")
        # NOTE(review): torch.load is called without map_location, so the
        # tensors presumably come back on the device they were saved from,
        # and `device` is only chosen after this -- confirm the modules are
        # moved to it later in the script.
        word_embedding.load_state_dict(torch.load(EMBEDDING_WEIGHT_FILE))
        image_encoder.load_state_dict(torch.load(ENCODER_WEIGHT_FILE))
        image_decoder.load_state_dict(torch.load(DECODER_WEIGHT_FILE))
    """ 10) Device Setup"""
    # NOTE(review): hard-coded second GPU; assumes at least two CUDA devices.
    device = 'cuda:1'
    device = torch.device(device)
def valid(epoch=1,
          checkpoints_dir='./checkpoints',
          use_bert=False,
          data_path=None,
          out_path='../prediction_result/valid_pred.json',
          output_ndcg=True):
    """Score the validation set with a saved checkpoint, optionally NDCG@5.

    Loads ``model-epoch{epoch}.pth`` from ``checkpoints_dir`` and scores every
    (query, product) pair of the validation set. It writes, per query id,
    the list of ``(product_id, score)`` pairs sorted by descending score as
    JSON to ``out_path``. If ``output_ndcg`` is true, it compares the top-5
    of each query with ``../data/valid/valid_answer.json``.

    Args:
        epoch: Epoch number selecting the checkpoint file.
        checkpoints_dir: Directory containing the checkpoint files.
        use_bert: Forwarded to ``ValidDataset`` and ``TextEncoder``.
        data_path: Forwarded to ``ValidDataset`` unchanged (may be ``None``).
        out_path: Path of the JSON prediction file to write.
        output_ndcg: Whether to compute, print and return NDCG@5.

    Returns:
        The mean NDCG@5 over the ground-truth queries when ``output_ndcg`` is
        true, otherwise ``None``.
    """
    print("valid epoch{}".format(epoch))
    # NOTE(review): `data_path` is forwarded as-is, including None; confirm
    # that ValidDataset provides a default location in that case.
    kdd_dataset = ValidDataset(data_path, use_bert=use_bert)
    loader = DataLoader(kdd_dataset,
                        collate_fn=collate_fn_valid,
                        batch_size=128,
                        shuffle=False,
                        num_workers=8)
    tbar = tqdm(loader)
    text_encoder = TextEncoder(kdd_dataset.unknown_token + 1,
                               1024,
                               256,
                               use_bert=use_bert).cuda()
    image_encoder = ImageEncoder(input_dim=2048, output_dim=1024,
                                 nhead=4).cuda()
    score_model = ScoreModel(1024, 256).cuda()
    checkpoints = torch.load(
        os.path.join(checkpoints_dir, 'model-epoch{}.pth'.format(epoch)))
    text_encoder.load_state_dict(checkpoints['query'])
    image_encoder.load_state_dict(checkpoints['item'])
    score_model.load_state_dict(checkpoints['score'])

    outputs = {}
    image_encoder.eval()
    text_encoder.eval()
    score_model.eval()
    # Inference only: without no_grad every batch would build an autograd
    # graph that is never used, wasting GPU memory and time.
    with torch.no_grad():
        for (query_id, product_id, query, query_len, features, boxes,
             category, obj_len) in tbar:
            query, query_len = query.cuda(), query_len.cuda()
            query, hidden = text_encoder(query, query_len)
            features, boxes, obj_len = (features.cuda(), boxes.cuda(),
                                        obj_len.cuda())
            features = image_encoder(features, boxes, obj_len)
            pair_scores = score_model(query, hidden, query_len, features)
            pair_scores = pair_scores.data.cpu().numpy()

            # Group (product id, score) pairs by query id.
            for q_id, p_id, s in zip(query_id.data.numpy(),
                                     product_id.data.numpy(), pair_scores):
                outputs.setdefault(str(q_id), []).append((p_id, s))

    # Rank each query's products by descending score and convert to
    # JSON-serialisable types.
    for k, v in outputs.items():
        v = sorted(v, key=lambda x: x[1], reverse=True)
        outputs[k] = [(str(x[0]), float(x[1])) for x in v]

    with open(out_path, 'w') as f:
        json.dump(outputs, f)

    if not output_ndcg:
        return None

    pred = read_json(out_path)
    gt = read_json('../data/valid/valid_answer.json')
    k = 5
    total_ndcg = 0
    for key, val in gt.items():
        ground_truth_ids = [str(x) for x in val]
        relevant = set(ground_truth_ids)
        predictions = [x[0] for x in pred[key][:k]]
        ref_vec = [1.0] * len(ground_truth_ids)
        pred_vec = [1.0 if pid in relevant else 0.0 for pid in predictions]
        total_ndcg += get_ndcg(pred_vec, ref_vec, k)
    ndcg = total_ndcg / len(gt)
    print('ndcg@%d: %.4f' % (k, ndcg))
    return ndcg