def intention_test(context_text_encoder: TextEncoder,
                   context_image_encoder: ImageEncoder,
                   context_encoder: ContextEncoder,
                   intention: Intention,
                   test_dataset: Dataset):
    """Intention test.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention.
        test_dataset (Dataset): Test dataset.

    """

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=IntentionTestConfig.batch_size,
        shuffle=False,
        num_workers=IntentionTestConfig.num_data_loader_workers)

    sum_accuracy = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    intention.eval()

    with torch.no_grad():
        for batch_id, valid_data in enumerate(test_data_loader):
            texts, text_lengths, images, utter_types = valid_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            intentions = torch.argmax(intent_prob, dim=1)
            eqs = torch.eq(intentions, utter_types)
            num_correct = torch.sum(eqs).item()
            accuracy = num_correct * 1.0 / eqs.size(0)
            sum_accuracy += accuracy

            # Print.
            print('pred:', intentions)
            print('true:', utter_types)
            print('# correct:', num_correct)
            print('accuracy:', accuracy)
            print('total accuracy:', sum_accuracy / (batch_id + 1))
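# NOTE: `intention_test` averages per-batch accuracies, so a smaller final
# batch carries the same weight as a full one in the running total. A minimal,
# hypothetical sketch of a sample-weighted alternative (not part of the
# original code) accumulates raw counts instead:
from typing import List

def weighted_accuracy(num_correct_per_batch: List[int],
                      batch_sizes: List[int]) -> float:
    """Accuracy weighted by the true number of samples in each batch."""
    return sum(num_correct_per_batch) / max(sum(batch_sizes), 1)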
def intention_valid(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    intention: Intention,
                    valid_dataset: Dataset):
    """Intention validation.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention.
        valid_dataset (Dataset): Valid dataset.

    """

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=IntentionValidConfig.batch_size,
        shuffle=True,
        num_workers=IntentionValidConfig.num_data_loader_workers)

    sum_loss = 0
    sum_accuracy = 0
    num_batches = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    intention.eval()

    with torch.no_grad():
        for batch_id, valid_data in enumerate(valid_data_loader):
            # Only validate `IntentionValidConfig.num_batches` batches.
            if batch_id >= IntentionValidConfig.num_batches:
                break
            num_batches += 1

            texts, text_lengths, images, utter_types = valid_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            loss = nll_loss(intent_prob, utter_types)
            sum_loss += loss

            eqs = torch.eq(torch.argmax(intent_prob, dim=1), utter_types)
            accuracy = torch.sum(eqs).item() * 1.0 / eqs.size(0)
            sum_accuracy += accuracy

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    intention.train()

    return sum_loss / num_batches, sum_accuracy / num_batches
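# NOTE: `torch.nn.functional.nll_loss` expects log-probabilities, so the
# `Intention` head presumably ends with a log-softmax over utterance types.
# A minimal sketch of such a head (hypothetical; the real `Intention` class
# may differ in depth and layer sizes):
class IntentionSketch(torch.nn.Module):
    def __init__(self, context_vector_size: int, utterance_type_size: int):
        super().__init__()
        self.linear = torch.nn.Linear(context_vector_size, utterance_type_size)

    def forward(self, context):
        # Returns (batch_size, utterance_type_size) log-probabilities,
        # directly consumable by nll_loss and by argmax for prediction.
        return torch.nn.functional.log_softmax(self.linear(context), dim=1)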
def knowledge_celebrity_test(context_text_encoder: TextEncoder,
                             context_image_encoder: ImageEncoder,
                             context_encoder: ContextEncoder,
                             to_hidden: ToHidden,
                             celebrity_memory: Memory,
                             text_decoder: TextDecoder,
                             test_dataset: Dataset,
                             celebrity_scores,
                             text_length: int,
                             vocab: Dict[str, int]):
    """Knowledge celebrity test.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        celebrity_scores: Celebrity scores.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary.

    """

    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeCelebrityTestConfig.batch_size,
        num_workers=KnowledgeCelebrityTestConfig.num_data_loader_workers)

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    celebrity_memory.eval()
    text_decoder.eval()

    output_file = open('knowledge_celebrity.out', 'w')

    with torch.no_grad():
        for batch_id, test_data in enumerate(test_data_loader):
            texts, text_lengths, images, utter_types = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder,
                                              texts, text_lengths, images)
            # (batch_size, context_vector_size)

            knowledge_entry = celebrity_scores
            encode_knowledge_func = partial(celebrity_memory, knowledge_entry)

            text_eval(to_hidden, text_decoder, text_length, id2word,
                      context, texts[-1],
                      hiddens, encode_knowledge_func,
                      output_file=output_file)

    output_file.close()
def knowledge_celebrity_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        to_hidden: ToHidden,
        celebrity_memory: Memory,
        text_decoder: TextDecoder,
        valid_dataset: Dataset,
        celebrity_scores,
        text_length: int):
    """Knowledge celebrity validation.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        valid_dataset (Dataset): Valid dataset.
        celebrity_scores: Celebrity scores.
        text_length (int): Text length.

    """

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=KnowledgeCelebrityValidConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeCelebrityValidConfig.num_data_loader_workers)

    sum_loss = 0
    num_batches = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    celebrity_memory.eval()
    text_decoder.eval()

    with torch.no_grad():
        for batch_id, valid_data in enumerate(valid_data_loader):
            # Only validate `KnowledgeCelebrityValidConfig.num_batches` batches.
            if batch_id >= KnowledgeCelebrityValidConfig.num_batches:
                break
            num_batches += 1

            texts, text_lengths, images, utter_types = valid_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            # utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder,
                                              texts, text_lengths, images)
            # (batch_size, context_vector_size)

            knowledge_entry = celebrity_scores
            encode_knowledge_func = partial(celebrity_memory, knowledge_entry)

            loss, n_totals = text_loss(to_hidden, text_decoder, text_length,
                                       context, texts[-1], text_lengths[-1],
                                       hiddens, encode_knowledge_func)
            sum_loss += loss / text_length

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    to_hidden.train()
    celebrity_memory.train()
    text_decoder.train()

    return sum_loss / num_batches
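# NOTE: dividing by `text_length` above suggests the value returned by
# `knowledge_celebrity_valid` approximates a mean per-token negative
# log-likelihood. Under that assumption (not verified against `text_loss`),
# a validation perplexity can be derived from it -- a minimal sketch:
import math

def perplexity_from_loss(avg_token_nll: float) -> float:
    """exp of the mean per-token NLL; lower is better."""
    return math.exp(avg_token_nll)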
def recommend_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        similarity: Similarity,
        valid_dataset: Dataset):
    """Recommend validation.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        similarity (Similarity): Similarity model.
        valid_dataset (Dataset): Valid dataset.

    """

    # Valid dataset loader.
    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=RecommendValidConfig.batch_size,
        shuffle=True,
        num_workers=RecommendValidConfig.num_data_loader_workers
    )

    sum_loss = 0
    num_batches = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    # similarity.eval()  # There might be a bug in the implementation of ResNet.

    num_ranks = torch.zeros(DatasetConfig.neg_images_max_num + 1,
                            dtype=torch.long)
    num_ranks = num_ranks.to(GlobalConfig.device)
    total_samples = 0

    with torch.no_grad():
        for batch_id, valid_data in enumerate(valid_data_loader):
            # Only validate `RecommendValidConfig.num_batches` batches.
            if batch_id >= RecommendValidConfig.num_batches:
                break
            num_batches += 1

            context_dialog, pos_products, neg_products = valid_data
            texts, text_lengths, images, utter_types = context_dialog
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            batch_size = texts.size(0)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            # utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(
                context_text_encoder,
                context_image_encoder,
                context_encoder,
                texts, text_lengths, images
            )
            # (batch_size, context_vector_size)

            loss = recommend_loss(similarity, batch_size, context,
                                  pos_products, neg_products)
            sum_loss += loss

            num_rank = recommend_eval(
                similarity, batch_size, context,
                pos_products, neg_products
            )

            total_samples += batch_size
            num_ranks += num_rank

    for i in range(DatasetConfig.neg_images_max_num):
        print('total recall@{} = {}'.format(
            i + 1,
            torch.sum(num_ranks[:i + 1]).item() / total_samples))

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    similarity.train()

    return sum_loss / num_batches
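# NOTE: `recommend_eval` appears to return, per batch, a histogram over the
# rank of the positive product among the candidates (index 0 = ranked first),
# which `recommend_valid` accumulates in `num_ranks`. Under that assumption,
# recall@k is the cumulative fraction of samples whose positive product lands
# in the top k -- a minimal sketch mirroring the print loop above:
def recall_at_k(num_ranks: torch.Tensor, total_samples: int, k: int) -> float:
    """Fraction of samples whose positive item is ranked within the top k."""
    return torch.sum(num_ranks[:k]).item() / max(total_samples, 1)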
def knowledge_test(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        to_hidden: ToHidden,
        attribute_kv_memory: KVMemory,
        text_decoder: TextDecoder,
        test_dataset: Dataset,
        text_length: int,
        vocab: Dict[str, int]):
    """Knowledge attribute test.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        attribute_kv_memory (KVMemory): Attribute Key-Value Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary.

    """

    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeAttributeTestConfig.batch_size,
        num_workers=KnowledgeAttributeTestConfig.num_data_loader_workers
    )

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    attribute_kv_memory.eval()
    text_decoder.eval()

    output_file = open('knowledge_attribute.out', 'w')

    with torch.no_grad():
        for batch_id, test_data in enumerate(test_data_loader):
            context_dialog, products = test_data

            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, utter_types = context_dialog
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)

            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)

            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(
                context_text_encoder,
                context_image_encoder,
                context_encoder,
                texts, text_lengths, images
            )
            # (batch_size, context_vector_size)

            encode_knowledge_func = partial(attribute_kv_memory,
                                            keys, values, pair_length)

            text_eval(to_hidden, text_decoder, text_length, id2word,
                      context, texts[-1],
                      hiddens, encode_knowledge_func,
                      output_file=output_file)

    output_file.close()
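# NOTE: both knowledge routines pre-bind the memory's knowledge tensors with
# `functools.partial`, so `text_eval` / `text_loss` can call the knowledge
# encoder with whatever remaining arguments they supply (e.g. a query vector)
# without knowing where the knowledge came from. A minimal sketch of the
# pattern (the `query` argument name is an assumption):
from functools import partial

def call_with_bound_knowledge(attribute_kv_memory, keys, values, pair_length,
                              query):
    encode_knowledge_func = partial(attribute_kv_memory, keys, values,
                                    pair_length)
    # Equivalent to attribute_kv_memory(keys, values, pair_length, query).
    return encode_knowledge_func(query)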
""" 7) Create Model """ VOCAB_SIZE = len(vocab.word2index) IMAGE_EMB_DIM = 256 WORD_EMB_DIM = 256 HIDDEN_DIM = 512 word_embedding = torch.nn.Embedding( num_embeddings=VOCAB_SIZE, embedding_dim=WORD_EMB_DIM, ) image_encoder = ImageEncoder(out_dim=IMAGE_EMB_DIM) image_decoder = CaptionRNN(num_classes=VOCAB_SIZE, word_emb_dim=WORD_EMB_DIM, img_emb_dim=IMAGE_EMB_DIM, hidden_dim=HIDDEN_DIM) word_embedding.eval() image_encoder.eval() image_decoder.eval() """ 9) Load Weights """ LOAD_WEIGHTS = True EMBEDDING_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-embedding-epoch-3.pt' ENCODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-encoder-epoch-3.pt' DECODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-decoder-epoch-3.pt' if LOAD_WEIGHTS: print("Loading pretrained weights...") word_embedding.load_state_dict(torch.load(EMBEDDING_WEIGHT_FILE)) image_encoder.load_state_dict(torch.load(ENCODER_WEIGHT_FILE)) image_decoder.load_state_dict(torch.load(DECODER_WEIGHT_FILE)) """ 10) Device Setup""" device = 'cuda:1' device = torch.device(device)
def valid(epoch=1,
          checkpoints_dir='./checkpoints',
          use_bert=False,
          data_path=None,
          out_path='../prediction_result/valid_pred.json',
          output_ndcg=True):
    """Score every (query, product) pair in the validation set, dump the
    ranked predictions to `out_path`, and optionally report NDCG@5 against
    the ground-truth answers."""
    print("valid epoch{}".format(epoch))
    kdd_dataset = ValidDataset(data_path, use_bert=use_bert)
    loader = DataLoader(kdd_dataset,
                        collate_fn=collate_fn_valid,
                        batch_size=128,
                        shuffle=False,
                        num_workers=8)
    tbar = tqdm(loader)

    text_encoder = TextEncoder(kdd_dataset.unknown_token + 1, 1024, 256,
                               use_bert=use_bert).cuda()
    image_encoder = ImageEncoder(input_dim=2048, output_dim=1024,
                                 nhead=4).cuda()
    score_model = ScoreModel(1024, 256).cuda()
    # category_embedding = model.CategoryEmbedding(768).cuda()

    checkpoints = torch.load(
        os.path.join(checkpoints_dir, 'model-epoch{}.pth'.format(epoch)))
    text_encoder.load_state_dict(checkpoints['query'])
    image_encoder.load_state_dict(checkpoints['item'])
    score_model.load_state_dict(checkpoints['score'])

    outputs = {}
    image_encoder.eval()
    text_encoder.eval()
    score_model.eval()

    for query_id, product_id, query, query_len, features, boxes, category, \
            obj_len in tbar:
        query, query_len = query.cuda(), query_len.cuda()
        query, hidden = text_encoder(query, query_len)
        features, boxes, obj_len = (features.cuda(), boxes.cuda(),
                                    obj_len.cuda())
        features = image_encoder(features, boxes, obj_len)
        score = score_model(query, hidden, query_len, features)
        score = score.data.cpu().numpy()
        for q_id, p_id, s in zip(query_id.data.numpy(),
                                 product_id.data.numpy(), score):
            outputs.setdefault(str(q_id), [])
            outputs[str(q_id)].append((p_id, s))

    # Sort each query's products by score, descending, and write predictions.
    for k, v in outputs.items():
        v = sorted(v, key=lambda x: x[1], reverse=True)
        v = [(str(x[0]), float(x[1])) for x in v]
        outputs[k] = v
    with open(out_path, 'w') as f:
        json.dump(outputs, f)

    if output_ndcg:
        pred = read_json(out_path)
        gt = read_json('../data/valid/valid_answer.json')
        score = 0
        k = 5
        for key, val in gt.items():
            ground_truth_ids = [str(x) for x in val]
            predictions = [x[0] for x in pred[key][:k]]
            ref_vec = [1.0] * len(ground_truth_ids)
            pred_vec = [
                1.0 if pid in ground_truth_ids else 0.0
                for pid in predictions
            ]
            score += get_ndcg(pred_vec, ref_vec, k)
        score = score / len(gt)
        print('ndcg@%d: %.4f' % (k, score))
        return score
    return None
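# NOTE: `get_ndcg` is not shown here. Given the 0/1 relevance vectors built
# above, a standard NDCG@k (DCG with a log2 position discount, normalised by
# the ideal DCG) would look roughly like this hypothetical sketch; the real
# helper may differ in edge-case handling:
import math

def ndcg_at_k(pred_vec, ref_vec, k):
    """NDCG@k for binary relevance lists; returns 0.0 if the ideal DCG is 0."""
    dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(pred_vec[:k]))
    idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ref_vec[:k]))
    return dcg / idcg if idcg > 0 else 0.0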