def __init__(self, config: SimilarityConfig):
    """Create the product encoders and the multimodal projection layer.

    Each sub-module is moved to ``GlobalConfig.device`` right after it is
    constructed.

    Args:
        config (SimilarityConfig): Supplies the product text/image encoder
            configs, plus ``mm_size`` (input width) and
            ``context_vector_size`` (output width) of the linear projection.
    """
    super(Similarity, self).__init__()
    target_device = GlobalConfig.device
    # Encoders for the product's text and image.
    self.text_encoder = TextEncoder(
        config.product_text_encoder_config).to(target_device)
    self.image_encoder = ImageEncoder(
        config.product_image_encoder_config).to(target_device)
    # Maps the concatenated product features to the context-vector size.
    self.linear = nn.Linear(config.mm_size,
                            config.context_vector_size).to(target_device)
class Similarity(nn.Module):
    """Scores a product (text + image) against a dialog context vector.

    The product text (with an SOS token prepended) and image are encoded,
    concatenated, projected to the context-vector size and compared with
    the context through ``cosine_similarity``.
    """

    def __init__(self, config: SimilarityConfig):
        """Build the product encoders and the projection layer.

        Args:
            config (SimilarityConfig): Supplies the product text/image
                encoder configs, ``mm_size`` and ``context_vector_size``.
        """
        super(Similarity, self).__init__()
        self.text_encoder = TextEncoder(config.product_text_encoder_config)
        self.text_encoder = self.text_encoder.to(GlobalConfig.device)
        self.image_encoder = ImageEncoder(config.product_image_encoder_config)
        self.image_encoder = self.image_encoder.to(GlobalConfig.device)
        self.linear = nn.Linear(config.mm_size, config.context_vector_size)
        self.linear = self.linear.to(GlobalConfig.device)

    def forward(self, context, text, text_length, image):
        """Forward.

        Args:
            context: Context (batch_size, ContextEncoderConfig.output_size).
            text: Product text (batch_size, product_text_max_len).
            text_length: Product text length (batch_size, ). Not modified.
            image: Product image (batch_size, 3, image_size, image_size).

        Returns:
            ``cosine_similarity(context, mm)``, where ``mm`` is the projected
            (text, image) product representation.
        """
        batch_size = context.size(0)
        # One SOS token per sample: (batch_size, 1).
        sos = SOS_ID * torch.ones(batch_size, dtype=torch.long).view(-1, 1).to(
            GlobalConfig.device)
        # Prepend SOS. `text` is moved first so `torch.cat` never mixes
        # devices. Result: (batch_size, product_text_max_len + 1).
        text = torch.cat((sos, text.to(GlobalConfig.device)), 1)
        # Out-of-place on purpose: `text_length += 1` would mutate the
        # caller's tensor, so scoring the same products twice (e.g. loss
        # and then eval on the same batch) would over-count lengths.
        text_length = text_length + 1

        encoded_text, _ = self.text_encoder(text, text_length)
        # (batch_size, text_feat_size)
        encoded_image = self.image_encoder(image, encoded_text)
        # (batch_size, image_feat_size)

        mm = torch.cat((encoded_text, encoded_image), 1)
        mm = mm.to(GlobalConfig.device)
        mm = self.linear(mm)
        return cosine_similarity(context, mm)
def build_models(self):
    """Build the text and image encoders, optionally resuming from disk.

    When ``cfg.text_encoder_path`` is non-empty, the text encoder is loaded
    from it and the image encoder from the same path with ``'text_encoder'``
    replaced by ``'image_encoder'``. The returned epoch is then the number
    at the end of the checkpoint's base file name plus one (e.g.
    ``text_encoder12.pth`` -> 13); otherwise it is 0.

    Returns:
        list: ``[text_encoder, image_encoder, epoch]``. Both encoders are
        moved to CUDA when ``cfg.CUDA`` is set.

    Raises:
        ValueError: If a checkpoint path is given but its file name does not
            end in an epoch number.
    """
    import os
    import re

    def load_weights(module, path):
        # Checkpoints hold either a raw state dict or one under 'model'.
        state_dict = torch.load(path, map_location='cpu')
        if 'model' in state_dict.keys():
            module.load_state_dict(state_dict['model'])
        else:
            module.load_state_dict(state_dict)

    resume_path = cfg.text_encoder_path
    epoch = 0

    # Image encoder.
    image_encoder = ImageEncoder(output_channels=cfg.hidden_dim)
    if resume_path != '':
        img_encoder_path = resume_path.replace('text_encoder', 'image_encoder')
        print('Load image encoder from:', img_encoder_path)
        load_weights(image_encoder, img_encoder_path)
    for p in image_encoder.parameters():  # keep the image encoder trainable
        p.requires_grad = True

    # Text encoder.
    text_encoder = TextEncoder(bert_config=self.bert_config)
    if resume_path != '':
        # The epoch number is parsed from the end of the file name, because
        # the former slice bounds (`istart`/`iend`) were never defined.
        stem = os.path.splitext(os.path.basename(resume_path))[0]
        match = re.search(r'(\d+)$', stem)
        if match is None:
            raise ValueError(
                'Cannot parse epoch from text encoder path: {}'.format(
                    resume_path))
        epoch = int(match.group(1)) + 1
        print('Load text encoder from:', resume_path)
        load_weights(text_encoder, resume_path)
    for p in text_encoder.parameters():  # keep the text encoder trainable
        p.requires_grad = True

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
    return [text_encoder, image_encoder, epoch]
def knowledge_attribute_train(context_text_encoder: TextEncoder,
                              context_image_encoder: ImageEncoder,
                              context_encoder: ContextEncoder,
                              train_dataset: Dataset,
                              valid_dataset: Dataset,
                              test_dataset: Dataset,
                              model_file: str,
                              attribute_data: AttributeData,
                              vocab: Dict[str, int],
                              embed_init=None):
    """Knowledge attribute train.

    Jointly trains the context encoders, a ``ToHidden`` projection, an
    attribute key-value memory and a ``TextDecoder``.

    - Validation runs every ``KnowledgeAttributeTrainConfig.valid_freq``
      batches.
    - An improved validation loss saves all modules plus the optimizer to
      ``model_file``.
    - After more than ``patience`` consecutive non-improving validations,
      ``knowledge_attribute_test`` runs and training stops.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        attribute_data (AttributeData): Attribute data.
        vocab (Dict[str, int]): Vocabulary.
        embed_init: Initial embedding (vocab_size, embed_size).
    """
    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=KnowledgeAttributeTrainConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeAttributeTrainConfig.num_data_loader_workers)

    # Model.
    vocab_size = len(vocab)
    attribute_kv_memory_config = AttributeKVMemoryConfig(
        len(attribute_data.key_vocab), len(attribute_data.value_vocab))
    text_decoder_config = KnowledgeTextDecoderConfig(
        vocab_size, MemoryConfig.memory_size, MemoryConfig.output_size,
        embed_init)
    text_length = text_decoder_config.text_length

    to_hidden = ToHidden(text_decoder_config).to(GlobalConfig.device)
    attribute_kv_memory = KVMemory(attribute_kv_memory_config).to(
        GlobalConfig.device)
    text_decoder = TextDecoder(text_decoder_config).to(GlobalConfig.device)

    models = [
        context_text_encoder, context_image_encoder, context_encoder,
        to_hidden, attribute_kv_memory, text_decoder
    ]

    # Model parameters.
    params = list(chain.from_iterable(model.parameters() for model in models))
    optimizer = Adam(params, lr=KnowledgeAttributeTrainConfig.learning_rate)

    epoch_id = 0
    min_valid_loss = None

    # Load saved state. Only the task-specific modules and optimizer are
    # restored here; the context encoders are not touched by this block.
    if isfile(model_file):
        state = torch.load(model_file, map_location=GlobalConfig.device)
        to_hidden.load_state_dict(state['to_hidden'])
        attribute_kv_memory.load_state_dict(state['attribute_kv_memory'])
        text_decoder.load_state_dict(state['text_decoder'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Running (python float) loss between two prints.
    sum_loss = 0
    bad_loss_cnt = 0

    # Switch to train mode.
    for model in models:
        model.train()

    finished = False

    for epoch_id in range(epoch_id,
                          KnowledgeAttributeTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Set gradients to 0.
            optimizer.zero_grad()

            train_data, products = train_data
            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, _ = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)
            # (batch_size, context_vector_size)

            encode_knowledge_func = partial(attribute_kv_memory, keys, values,
                                            pair_length)

            loss, _ = text_loss(to_hidden, text_decoder, text_length, context,
                                texts[-1], text_lengths[-1], hiddens,
                                encode_knowledge_func)
            # `.item()` detaches the value. Summing the tensor itself would
            # keep every batch's autograd graph alive until the next print.
            sum_loss += loss.item() / text_length
            loss.backward()
            optimizer.step()

            # Print loss every `print_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= KnowledgeAttributeTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0

            # Valid every `valid_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.valid_freq == 0:
                valid_loss = knowledge_attribute_valid(
                    context_text_encoder, context_image_encoder,
                    context_encoder, to_hidden, attribute_kv_memory,
                    text_decoder, valid_dataset, text_length)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': KNOWLEDGE_ATTRIBUTE_SUBTASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'to_hidden': to_hidden.state_dict(),
                        'attribute_kv_memory': attribute_kv_memory.state_dict(),
                        'text_decoder': text_decoder.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    # Early stopping: test the current model and stop.
                    if bad_loss_cnt > KnowledgeAttributeTrainConfig.patience:
                        knowledge_attribute_test(
                            context_text_encoder, context_image_encoder,
                            context_encoder, to_hidden, attribute_kv_memory,
                            text_decoder, test_dataset, text_length, vocab)
                        finished = True
                        break
        if finished:
            break
def intention_valid(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    intention: Intention,
                    valid_dataset: Dataset):
    """Evaluate the intention classifier on a sample of validation batches.

    At most ``IntentionValidConfig.num_batches`` shuffled batches are used.
    The four models run in eval mode under ``torch.no_grad()`` and are
    switched back to train mode before returning.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention classifier.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        tuple: (mean NLL loss, mean accuracy) over the evaluated batches.
    """
    networks = (context_text_encoder, context_image_encoder, context_encoder,
                intention)
    loader = DataLoader(
        valid_dataset,
        batch_size=IntentionValidConfig.batch_size,
        shuffle=True,
        num_workers=IntentionValidConfig.num_data_loader_workers)

    loss_total = 0
    accuracy_total = 0
    batches_seen = 0

    for net in networks:
        net.eval()

    with torch.no_grad():
        for index, batch in enumerate(loader):
            if index >= IntentionValidConfig.num_batches:
                break
            batches_seen += 1

            # batch = (texts, text_lengths, images, utterance types), all
            # batch-first as produced by the dataset.
            texts, text_lengths, images, labels = batch
            device = GlobalConfig.device
            texts = texts.to(device)
            text_lengths = text_lengths.to(device)
            images = images.to(device)
            labels = labels.to(device)

            # Make the dialog-turn axis leading; images additionally move the
            # per-turn image axis ahead of the batch axis.
            texts.transpose_(0, 1)
            text_lengths.transpose_(0, 1)
            images.transpose_(0, 1)
            images.transpose_(1, 2)

            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder, texts, text_lengths,
                                        images)
            intent_prob = intention(context)
            loss_total += nll_loss(intent_prob, labels)

            hits = torch.eq(torch.argmax(intent_prob, dim=1), labels)
            accuracy_total += torch.sum(hits).item() * 1.0 / hits.size(0)

    for net in networks:
        net.train()
    return loss_total / batches_seen, accuracy_total / batches_seen
def knowledge_celebrity_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        to_hidden: ToHidden,
        celebrity_memory: Memory,
        text_decoder: TextDecoder,
        valid_dataset: Dataset,
        celebrity_scores,
        text_length: int):
    """Compute the validation text loss of the celebrity-knowledge model.

    At most ``KnowledgeCelebrityValidConfig.num_batches`` shuffled batches
    are evaluated under ``torch.no_grad()`` with every module in eval mode.
    All modules are switched back to train mode before returning.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        valid_dataset (Dataset): Valid dataset.
        celebrity_scores: Celebrity scores, bound as the first argument of
            ``celebrity_memory`` for every batch.
        text_length (int): Text length.

    Returns:
        Mean over the evaluated batches of ``loss / text_length``.
    """
    modules = (context_text_encoder, context_image_encoder, context_encoder,
               to_hidden, celebrity_memory, text_decoder)
    loader = DataLoader(
        valid_dataset,
        batch_size=KnowledgeCelebrityValidConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeCelebrityValidConfig.num_data_loader_workers)

    # The knowledge entry is identical for every batch, so bind it once.
    knowledge_fn = partial(celebrity_memory, celebrity_scores)

    running_loss = 0
    evaluated = 0

    for module in modules:
        module.eval()

    with torch.no_grad():
        for index, batch in enumerate(loader):
            if index >= KnowledgeCelebrityValidConfig.num_batches:
                break
            evaluated += 1

            texts, text_lengths, images, _ = batch
            device = GlobalConfig.device
            texts = texts.to(device)
            text_lengths = text_lengths.to(device)
            images = images.to(device)

            # Dialog-turn axis first; images also put the per-turn image
            # axis ahead of the batch axis.
            texts.transpose_(0, 1)
            text_lengths.transpose_(0, 1)
            images.transpose_(0, 1)
            images.transpose_(1, 2)

            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)
            batch_loss, _ = text_loss(to_hidden, text_decoder, text_length,
                                      context, texts[-1], text_lengths[-1],
                                      hiddens, knowledge_fn)
            running_loss += batch_loss / text_length

    for module in modules:
        module.train()
    return running_loss / evaluated
def recommend_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    train_dataset: Dataset,
                    valid_dataset: Dataset,
                    test_dataset: Dataset,
                    model_file: str,
                    vocab_size: int,
                    embed_init=None):
    """Recommend train.

    Jointly trains the context encoders and a ``Similarity`` model with
    ``recommend_loss``.

    - Validation runs every ``RecommendTrainConfig.valid_freq`` batches.
    - An improved validation loss saves all modules and the optimizer to
      ``model_file``.
    - After more than ``patience`` consecutive non-improving validations,
      ``recommend_test`` runs and training stops.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        vocab_size (int): Vocabulary size.
        embed_init: Initial embedding (vocab_size, embed_size).
    """
    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=RecommendTrainConfig.batch_size,
        shuffle=True,
        num_workers=RecommendTrainConfig.num_data_loader_workers)

    # Model.
    similarity_config = SimilarityConfig(vocab_size, embed_init)
    similarity = Similarity(similarity_config).to(GlobalConfig.device)

    models = [
        context_text_encoder, context_image_encoder, context_encoder,
        similarity
    ]

    # Model parameters.
    params = list(chain.from_iterable(model.parameters() for model in models))
    optimizer = Adam(params, lr=RecommendTrainConfig.learning_rate)

    epoch_id = 0
    min_valid_loss = None

    # Load saved state (similarity + optimizer; the context encoders are not
    # restored by this block).
    if isfile(model_file):
        state = torch.load(model_file, map_location=GlobalConfig.device)
        similarity.load_state_dict(state['similarity'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Running (python float) loss between two prints.
    sum_loss = 0
    bad_loss_cnt = 0

    # Switch to train mode.
    for model in models:
        model.train()

    finished = False

    for epoch_id in range(epoch_id, RecommendTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            context_dialog, pos_products, neg_products = train_data

            texts, text_lengths, images, _ = context_dialog
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)

            batch_size = texts.size(0)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder, texts, text_lengths,
                                        images)
            # (batch_size, context_vector_size)

            loss = recommend_loss(similarity, batch_size, context,
                                  pos_products, neg_products)
            # `.item()` detaches the value. Summing the tensor itself would
            # keep every batch's autograd graph alive until the next print.
            sum_loss += loss.item()

            loss.backward()
            optimizer.step()

            # Print loss every `print_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= RecommendTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0

            # Valid every `valid_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.valid_freq == 0:
                valid_loss = recommend_valid(context_text_encoder,
                                             context_image_encoder,
                                             context_encoder, similarity,
                                             valid_dataset)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': RECOMMEND_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'similarity': similarity.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    # Early stopping: test the current model and stop.
                    if bad_loss_cnt > RecommendTrainConfig.patience:
                        recommend_test(context_text_encoder,
                                       context_image_encoder, context_encoder,
                                       similarity, test_dataset)
                        finished = True
                        break
        if finished:
            break
def train(task: int, model_file_name: str):
    """Train model.

    Loads the common data and the train/valid/test dialogs for ``task`` and
    builds the shared context encoders. When
    ``DatasetConfig.dump_dir/model_file_name`` exists, the encoders are
    restored from it. Control then passes to the task-specific training
    function.

    Args:
        task (int): Task.
        model_file_name (str): Model file name (saved or to be saved).

    Raises:
        ValueError: If the common raw data or any of the train/valid/test
            dialog data files is missing.
    """
    # Check if data exists.
    if not isfile(DatasetConfig.common_raw_data_file):
        raise ValueError('No common raw data.')

    # Load extracted common data.
    common_data: CommonData = load_pkl(DatasetConfig.common_raw_data_file)

    # Dialog data files.
    train_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, TRAIN_MODE)
    valid_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, VALID_MODE)
    test_dialog_data_file = DatasetConfig.get_dialog_filename(task, TEST_MODE)
    if not isfile(train_dialog_data_file):
        raise ValueError('No train dialog data file.')
    if not isfile(valid_dialog_data_file):
        raise ValueError('No valid dialog data file.')
    # The test dialogs are loaded unconditionally below, so check them too.
    if not isfile(test_dialog_data_file):
        raise ValueError('No test dialog data file.')

    # Load extracted dialogs.
    train_dialogs: List[TidyDialog] = load_pkl(train_dialog_data_file)
    valid_dialogs: List[TidyDialog] = load_pkl(valid_dialog_data_file)
    test_dialogs: List[TidyDialog] = load_pkl(test_dialog_data_file)

    # Knowledge data is only needed (and only built) for the knowledge task.
    knowledge_data = KnowledgeData() if task == KNOWLEDGE_TASK else None

    def make_dataset(dialogs):
        # The third argument is passed as None (object ids are not used).
        return Dataset(task, common_data.dialog_vocab, None, dialogs,
                       knowledge_data)

    # Dataset wrap.
    train_dataset = make_dataset(train_dialogs)
    valid_dataset = make_dataset(valid_dialogs)
    test_dataset = make_dataset(test_dialogs)

    print('Train dataset size:', len(train_dataset))
    print('Valid dataset size:', len(valid_dataset))
    print('Test dataset size:', len(test_dataset))

    # Get initial embedding.
    vocab_size = len(common_data.dialog_vocab)
    embed_init = get_embed_init(common_data.glove,
                                vocab_size).to(GlobalConfig.device)

    # Context model configurations.
    context_text_encoder_config = ContextTextEncoderConfig(
        vocab_size, embed_init)
    context_image_encoder_config = ContextImageEncoderConfig()
    context_encoder_config = ContextEncoderConfig()

    # Context models.
    context_text_encoder = TextEncoder(context_text_encoder_config).to(
        GlobalConfig.device)
    context_image_encoder = ImageEncoder(context_image_encoder_config).to(
        GlobalConfig.device)
    context_encoder = ContextEncoder(context_encoder_config).to(
        GlobalConfig.device)

    # Load model file. NOTE: the checkpoint's 'task' entry is not checked,
    # so a checkpoint saved by another task can seed the context encoders.
    model_file = join(DatasetConfig.dump_dir, model_file_name)
    if isfile(model_file):
        state = torch.load(model_file, map_location=GlobalConfig.device)
        context_text_encoder.load_state_dict(state['context_text_encoder'])
        context_image_encoder.load_state_dict(state['context_image_encoder'])
        context_encoder.load_state_dict(state['context_encoder'])

    # Task-specific parts.
    if task == INTENTION_TASK:
        intention_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file)
    elif task == TEXT_TASK:
        text_train(context_text_encoder, context_image_encoder,
                   context_encoder, train_dataset, valid_dataset,
                   test_dataset, model_file, common_data.dialog_vocab,
                   embed_init)
    elif task == RECOMMEND_TASK:
        recommend_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file, vocab_size, embed_init)
    elif task == KNOWLEDGE_TASK:
        knowledge_attribute_train(context_text_encoder, context_image_encoder,
                                  context_encoder, train_dataset,
                                  valid_dataset, test_dataset, model_file,
                                  knowledge_data.attribute_data,
                                  common_data.dialog_vocab, embed_init)
def recommend_valid(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        similarity: Similarity,
        valid_dataset: Dataset):
    """Validate the recommendation model and print cumulative recall.

    At most ``RecommendValidConfig.num_batches`` shuffled batches are
    evaluated under ``torch.no_grad()``. Per batch it accumulates
    ``recommend_loss`` and the rank histogram from ``recommend_eval``. It
    then prints ``total recall@k`` for ``k = 1 .. neg_images_max_num``.
    The context encoders are put in eval mode; ``similarity`` is
    deliberately left in train mode (a possible bug in the ResNet
    implementation was noted). All four modules are in train mode on return.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        similarity (Similarity): Product/context similarity model.
        valid_dataset (Dataset): Valid dataset.

    Returns:
        Mean recommend loss over the evaluated batches.
    """
    device = GlobalConfig.device
    loader = DataLoader(
        valid_dataset,
        batch_size=RecommendValidConfig.batch_size,
        shuffle=True,
        num_workers=RecommendValidConfig.num_data_loader_workers)

    encoders = (context_text_encoder, context_image_encoder, context_encoder)
    for encoder in encoders:
        encoder.eval()

    # rank_counts[r] = number of samples whose positive product got rank r.
    rank_counts = torch.zeros(DatasetConfig.neg_images_max_num + 1,
                              dtype=torch.long).to(device)
    loss_total = 0
    batches_seen = 0
    samples_seen = 0

    with torch.no_grad():
        for index, batch in enumerate(loader):
            if index >= RecommendValidConfig.num_batches:
                break
            batches_seen += 1

            dialog, pos_products, neg_products = batch
            texts, text_lengths, images, _ = dialog
            n_samples = texts.size(0)

            texts = texts.to(device)
            text_lengths = text_lengths.to(device)
            images = images.to(device)

            # Dialog-turn axis first; images also put the per-turn image
            # axis ahead of the batch axis.
            texts.transpose_(0, 1)
            text_lengths.transpose_(0, 1)
            images.transpose_(0, 1)
            images.transpose_(1, 2)

            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder, texts, text_lengths,
                                        images)

            loss_total += recommend_loss(similarity, n_samples, context,
                                         pos_products, neg_products)
            batch_ranks = recommend_eval(similarity, n_samples, context,
                                         pos_products, neg_products)
            samples_seen += n_samples
            rank_counts += batch_ranks

    # cumulative[i] = samples ranked within the top i + 1.
    cumulative = torch.cumsum(rank_counts, 0)
    for i in range(DatasetConfig.neg_images_max_num):
        print('total recall@{} = {}'.format(
            i + 1, cumulative[i].item() / samples_seen))

    for encoder in encoders:
        encoder.train()
    similarity.train()
    return loss_total / batches_seen
def knowledge_test(
        context_text_encoder: TextEncoder,
        context_image_encoder: ImageEncoder,
        context_encoder: ContextEncoder,
        to_hidden: ToHidden,
        attribute_kv_memory: KVMemory,
        text_decoder: TextDecoder,
        test_dataset: Dataset,
        text_length: int,
        vocab: Dict[str, int]):
    """Knowledge attribute test.

    Runs every test batch through the context encoders and, via
    ``text_eval``, the decoder conditioned on the attribute key-value
    memory. ``text_eval`` receives an open handle on
    ``knowledge_attribute.out`` (current working directory); presumably it
    writes the decoded responses there. The models are left in eval mode.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        attribute_kv_memory (KVMemory): Attribute Key-Value Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary (word -> id).
    """
    # Invert the vocabulary: id -> word.
    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeAttributeTestConfig.batch_size,
        num_workers=KnowledgeAttributeTestConfig.num_data_loader_workers
    )

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    to_hidden.eval()
    attribute_kv_memory.eval()
    text_decoder.eval()

    # `with` guarantees the output file is closed even if decoding raises.
    with open('knowledge_attribute.out', 'w', encoding='utf-8') as output_file, \
            torch.no_grad():
        for test_data, products in test_data_loader:
            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, _ = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(
                context_text_encoder,
                context_image_encoder,
                context_encoder,
                texts,
                text_lengths,
                images
            )
            # (batch_size, context_vector_size)

            encode_knowledge_func = partial(attribute_kv_memory, keys, values,
                                            pair_length)

            text_eval(to_hidden, text_decoder, text_length, id2word,
                      context, texts[-1], hiddens,
                      encode_knowledge_func, output_file=output_file)
def load_network(self):
    """Build every network and optionally load pretrained weights.

    Weights are loaded (onto CPU storage) from ``cfg.NET_G``, ``cfg.NET_D``,
    ``cfg.ENCODER``, ``cfg.DECODER`` and ``cfg.IMAGE_ENCODER``; an empty
    string skips that network. ``disc_latent`` has no checkpoint setting and
    is always freshly built.

    Returns:
        tuple: ``(image_encoder, image_generator, text_encoder,
        text_generator, disc_image, disc_latent)``, moved to CUDA when
        ``cfg.CUDA`` is set.
    """
    image_generator = ImageGenerator()
    image_generator.apply(weights_init)

    disc_image = DiscriminatorImage()
    disc_image.apply(weights_init)

    emb_dim = 300
    text_encoder = TextEncoder(emb_dim, self.txt_emb, 1, dropout=0.0)

    attn_model = 'general'
    text_generator = TextGenerator(attn_model, emb_dim,
                                   len(self.txt_dico.id2word), self.txt_emb,
                                   n_layers=1, dropout=0.0)

    image_encoder = ImageEncoder()
    image_encoder.apply(weights_init)

    disc_latent = DiscriminatorLatent(emb_dim)

    # Each non-empty path is loaded into the module next to it. The former
    # code loaded into undefined names (netG/netD/encoder/decoder), raising
    # NameError. NOTE(review): the NET_G -> image_generator,
    # NET_D -> disc_image, ENCODER -> text_encoder and
    # DECODER -> text_generator mapping is inferred from names; confirm
    # against how those checkpoints were saved.
    checkpoints = [
        (cfg.NET_G, image_generator),
        (cfg.NET_D, disc_image),
        (cfg.ENCODER, text_encoder),
        (cfg.DECODER, text_generator),
        (cfg.IMAGE_ENCODER, image_encoder),
    ]
    for path, module in checkpoints:
        if path != '':
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
            module.load_state_dict(state_dict)
            print('Load from: ', path)

    if cfg.CUDA:
        image_encoder.cuda()
        image_generator.cuda()
        text_encoder.cuda()
        text_generator.cuda()
        disc_image.cuda()
        disc_latent.cuda()

    return (image_encoder, image_generator, text_encoder, text_generator,
            disc_image, disc_latent)
def valid(epoch=1,
          checkpoints_dir='./checkpoints',
          use_bert=False,
          data_path=None,
          out_path='../prediction_result/valid_pred.json',
          output_ndcg=True):
    """Score the validation set with a saved checkpoint; optionally report NDCG@5.

    Loads ``model-epoch{epoch}.pth`` from ``checkpoints_dir`` and scores every
    (query, product) pair. It then writes, per query id, the products
    sorted by descending score to ``out_path`` as JSON.

    Args:
        epoch (int): Epoch number of the checkpoint to load.
        checkpoints_dir (str): Directory holding the checkpoints.
        use_bert (bool): Forwarded to the dataset and the text encoder.
        data_path: Forwarded unchanged to ``ValidDataset``.
        out_path (str): Where the ranked predictions are written.
        output_ndcg (bool): If True, compare against
            ``../data/valid/valid_answer.json`` and compute NDCG@5.

    Returns:
        The mean NDCG@5 over ground-truth queries when ``output_ndcg`` is
        True, else None.
    """
    print("valid epoch{}".format(epoch))
    kdd_dataset = ValidDataset(data_path, use_bert=use_bert)
    loader = DataLoader(kdd_dataset,
                        collate_fn=collate_fn_valid,
                        batch_size=128,
                        shuffle=False,
                        num_workers=8)
    tbar = tqdm(loader)

    text_encoder = TextEncoder(kdd_dataset.unknown_token + 1,
                               1024,
                               256,
                               use_bert=use_bert).cuda()
    image_encoder = ImageEncoder(input_dim=2048, output_dim=1024,
                                 nhead=4).cuda()
    score_model = ScoreModel(1024, 256).cuda()

    checkpoints = torch.load(
        os.path.join(checkpoints_dir, 'model-epoch{}.pth'.format(epoch)))
    text_encoder.load_state_dict(checkpoints['query'])
    image_encoder.load_state_dict(checkpoints['item'])
    score_model.load_state_dict(checkpoints['score'])

    image_encoder.eval()
    text_encoder.eval()
    score_model.eval()

    # query id (str) -> list of (product id, score).
    outputs = {}
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for (query_id, product_id, query, query_len, features, boxes,
             _category, obj_len) in tbar:
            query, query_len = query.cuda(), query_len.cuda()
            query, hidden = text_encoder(query, query_len)
            features, boxes, obj_len = (features.cuda(), boxes.cuda(),
                                        obj_len.cuda())
            features = image_encoder(features, boxes, obj_len)
            score = score_model(query, hidden, query_len, features)
            score = score.detach().cpu().numpy()
            for q_id, p_id, s in zip(query_id.numpy(), product_id.numpy(),
                                     score):
                outputs.setdefault(str(q_id), []).append((p_id, s))

    # Rank products per query by descending score (JSON-friendly types).
    for k, v in outputs.items():
        ranked = sorted(v, key=lambda x: x[1], reverse=True)
        outputs[k] = [(str(pid), float(s)) for pid, s in ranked]

    with open(out_path, 'w') as f:
        json.dump(outputs, f)

    if not output_ndcg:
        return None

    pred = read_json(out_path)
    gt = read_json('../data/valid/valid_answer.json')
    k = 5
    total = 0
    for key, val in gt.items():
        ground_truth_ids = [str(x) for x in val]
        predictions = [x[0] for x in pred[key][:k]]
        ref_vec = [1.0] * len(ground_truth_ids)
        pred_vec = [
            1.0 if pid in ground_truth_ids else 0.0 for pid in predictions
        ]
        total += get_ndcg(pred_vec, ref_vec, k)
    score = total / len(gt)
    print('ndcg@%d: %.4f' % (k, score))
    return score
# Validation loader for the MOH-X split; order is fixed (shuffle=False) and
# batches are assembled with TextDataset.collate_fn.
val_dataloader_mohx = DataLoader(dataset=val_dataset_mohx, batch_size=batch_size, shuffle=False, collate_fn=TextDataset.collate_fn)
""" 3. Model training """
''' 3. 1 set up model, loss criterion, optimizer '''
# Instantiate the model
# Two text encoders with identical hyper-parameters: one for the
# explanation text (Exp_model), one for the query (Query_model).
# NOTE(review): with hidden_size=256 and bidir=True the outputs are
# presumably 512-d, matching the 512-d inputs of the attention models
# below; confirm against TextEncoder.
Exp_model = TextEncoder(embedding_dim=1024, hidden_size=256, num_layers=1, bidir=True, dropout1=0.5)
Query_model = TextEncoder(embedding_dim=1024, hidden_size=256, num_layers=1, bidir=True, dropout1=0.5)
# Attention with 256-d output.
Attn_model = AttentionModel(para_encoder_input_dim=512, query_dim=512, output_dim=256)
# Attention used by the paragraph encoder (512-d output).
para_encoder_attn_model = AttentionModel(para_encoder_input_dim=512, query_dim=512, output_dim=512)
# NOTE(review): this call's remaining arguments are not visible in this
# section of the source.
para_encoder = ParaEncoder(input_dim=1024,
def knowledge_celebrity_test(context_text_encoder: TextEncoder,
                             context_image_encoder: ImageEncoder,
                             context_encoder: ContextEncoder,
                             to_hidden: ToHidden,
                             celebrity_memory: Memory,
                             text_decoder: TextDecoder,
                             test_dataset: Dataset,
                             celebrity_scores,
                             text_length: int,
                             vocab: Dict[str, int]):
    """Knowledge celebrity test.

    Puts all models in eval mode, encodes the dialog context of every test
    batch and hands it to ``text_eval`` together with the celebrity memory
    (bound to ``celebrity_scores``). ``text_eval`` writes to the file
    ``knowledge_celebrity.out``.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        to_hidden (ToHidden): Context to hidden.
        celebrity_memory (Memory): Celebrity Memory.
        text_decoder (TextDecoder): Text decoder.
        test_dataset (Dataset): Test dataset.
        celebrity_scores: Celebrity scores, passed as the first argument of
            ``celebrity_memory``.
        text_length (int): Text length.
        vocab (Dict[str, int]): Vocabulary.
    """
    # Inverse vocabulary; assumes word ids are exactly 0..len(vocab)-1
    # (TODO confirm), otherwise the index assignment fails.
    id2word: List[str] = [None] * len(vocab)
    for word, wid in vocab.items():
        id2word[wid] = word

    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=KnowledgeCelebrityTestConfig.batch_size,
        num_workers=KnowledgeCelebrityTestConfig.num_data_loader_workers)

    # Switch to eval mode.
    for model in (context_text_encoder, context_image_encoder,
                  context_encoder, to_hidden, celebrity_memory,
                  text_decoder):
        model.eval()

    # The knowledge source does not depend on the batch, so bind it once.
    encode_knowledge_func = partial(celebrity_memory, celebrity_scores)

    # `with` guarantees the output file is closed even if decoding raises.
    with open('knowledge_celebrity.out', 'w') as output_file, \
            torch.no_grad():
        for test_data in test_data_loader:
            texts, text_lengths, images, _utter_types = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder, texts,
                                              text_lengths, images)
            # (batch_size, context_vector_size)

            # texts[-1] is the last utterance of each dialog after the
            # transpose above.
            text_eval(to_hidden, text_decoder, text_length, id2word,
                      context, texts[-1], hiddens, encode_knowledge_func,
                      output_file=output_file)
device = torch.device("cuda", local_rank) checkpoints_dir = './checkpoints' start_epoch = 0 use_bert = True if not os.path.exists(checkpoints_dir) and local_rank == 0: os.makedirs(checkpoints_dir) kdd_dataset = Dataset(use_bert=use_bert) sampler = DistributedSampler(kdd_dataset) loader = DataLoader(kdd_dataset, collate_fn=collate_fn, batch_size=130, sampler=sampler, num_workers=15) nhead = 4 text_encoder = TextEncoder(kdd_dataset.unknown_token + 1, 1024, 256, use_bert=use_bert).cuda() image_encoder = ImageEncoder(input_dim=2048, output_dim=1024, nhead=nhead) image_encoder.load_pretrained_weights( path='../user_data/image_encoder_large.pth') image_encoder = image_encoder.cuda() score_model = ScoreModel(1024, 256).cuda() # text_generator = TextGenerator(text_encoder.embed.num_embeddings).cuda() # score_model = ScoreModel(30522, 256, num_heads=1).cuda() # category_embedding = CategoryEmbedding(256).cuda() optimizer = Adam(image_encoder.get_params() + text_encoder.get_params() + score_model.get_params()) if start_epoch > 0 and local_rank == 0: checkpoints = torch.load(
def intention_test(context_text_encoder: TextEncoder,
                   context_image_encoder: ImageEncoder,
                   context_encoder: ContextEncoder,
                   intention: Intention,
                   test_dataset: Dataset):
    """Intention test.

    Predicts the utterance type (argmax of ``intention``'s output) for every
    test dialog and prints, per batch, the predictions, the targets, the
    batch accuracy and the cumulative accuracy over all samples seen so far.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        intention (Intention): Intention.
        test_dataset (Dataset): Test dataset.
    """
    # Test dataset loader.
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=IntentionTestConfig.batch_size,
        shuffle=False,
        num_workers=IntentionTestConfig.num_data_loader_workers)

    # Sample-level counters: averaging per-batch accuracies would
    # over-weight a smaller final batch.
    total_correct = 0
    total_count = 0

    # Switch to eval mode.
    context_text_encoder.eval()
    context_image_encoder.eval()
    context_encoder.eval()
    intention.eval()

    with torch.no_grad():
        for test_data in test_data_loader:
            texts, text_lengths, images, utter_types = test_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder, texts,
                                        text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            intentions = torch.argmax(intent_prob, dim=1)
            eqs = torch.eq(intentions, utter_types)
            num_correct = torch.sum(eqs).item()
            batch_count = eqs.size(0)
            accuracy = num_correct * 1.0 / batch_count
            total_correct += num_correct
            total_count += batch_count

            # Print.
            print('pred:', intentions)
            print('true:', utter_types)
            print('# correct:', num_correct)
            print('accuracy:', accuracy)
            print('total accuracy:', total_correct * 1.0 / total_count)
def intention_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    train_dataset: Dataset,
                    valid_dataset: Dataset,
                    test_dataset: Dataset,
                    model_file: str):
    """Intention train.

    Trains a freshly built ``Intention`` head jointly with the three context
    encoders (all their parameters are optimised by one Adam optimizer).
    Every ``valid_freq`` batches the models are validated. Whenever the
    validation loss improves, the full state is saved to ``model_file``.
    Once the loss fails to improve more than ``patience`` consecutive times,
    ``intention_test`` is run and training stops. If ``model_file`` already
    exists, the intention head, optimizer, epoch and best loss are restored
    from it before training.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
    """
    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=IntentionTrainConfig.batch_size,
        shuffle=True,
        num_workers=IntentionTrainConfig.num_data_loader_workers)

    # Model.
    intention_config = IntentionConfig()
    intention = Intention(intention_config).to(GlobalConfig.device)
    models = [
        context_text_encoder, context_image_encoder, context_encoder,
        intention
    ]

    # Model parameters.
    params = list(
        chain.from_iterable(list(model.parameters()) for model in models))
    optimizer = Adam(params, lr=IntentionTrainConfig.learning_rate)

    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        # NOTE(review): the encoder state dicts saved below are not restored
        # here; presumably the caller loads them before calling — confirm.
        intention.load_state_dict(state['intention'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0.0
    bad_loss_cnt = 0

    # Switch to train mode.
    for model in models:
        model.train()

    finished = False

    for epoch_id in range(epoch_id, IntentionTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            texts, text_lengths, images, utter_types = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder, texts,
                                        text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            loss = nll_loss(intent_prob, utter_types)
            # Accumulate a Python float. Adding the tensor itself would keep
            # every batch's autograd graph alive until the next reset.
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()

            # Print loss every `IntentionTrainConfig.print_freq` batches.
            if (batch_id + 1) % IntentionTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= IntentionTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0.0

            # Valid every `IntentionTrainConfig.valid_freq` batches.
            if (batch_id + 1) % IntentionTrainConfig.valid_freq == 0:
                valid_loss, accuracy = intention_valid(context_text_encoder,
                                                       context_image_encoder,
                                                       context_encoder,
                                                       intention,
                                                       valid_dataset)
                # Validation may leave the models in eval mode; training
                # must resume in train mode (idempotent if already so).
                for model in models:
                    model.train()
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \taccuracy: {} \ttime: {}'.format(
                    valid_loss, accuracy, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': INTENTION_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'intention': intention.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    # Early stopping: evaluate on the test set and finish.
                    if bad_loss_cnt > IntentionTrainConfig.patience:
                        intention_test(context_text_encoder,
                                       context_image_encoder,
                                       context_encoder, intention,
                                       test_dataset)
                        finished = True
                        break
        if finished:
            break