Code Example #1
class Similarity(nn.Module):
    def __init__(self, config: SimilarityConfig):
        super().__init__()

        # Product-side text and image encoders, moved to the configured device.
        self.text_encoder = TextEncoder(
            config.product_text_encoder_config).to(GlobalConfig.device)
        self.image_encoder = ImageEncoder(
            config.product_image_encoder_config).to(GlobalConfig.device)

        # Project the multimodal feature into the context vector space.
        self.linear = nn.Linear(
            config.mm_size, config.context_vector_size).to(GlobalConfig.device)
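As context for how a head like this is typically used once both encoders produce features, here is a minimal, self-contained sketch; the dimensions, the two separate projections, and the cosine scoring are assumptions for illustration, not taken from the project above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimilaritySketch(nn.Module):
    # Hypothetical stand-in: projects text and image features into a shared
    # space and scores each pair by cosine similarity.
    def __init__(self, text_dim: int = 512, image_dim: int = 512,
                 context_dim: int = 256):
        super().__init__()
        self.text_proj = nn.Linear(text_dim, context_dim)
        self.image_proj = nn.Linear(image_dim, context_dim)

    def forward(self, text_feat: torch.Tensor,
                image_feat: torch.Tensor) -> torch.Tensor:
        t = F.normalize(self.text_proj(text_feat), dim=-1)
        v = F.normalize(self.image_proj(image_feat), dim=-1)
        return (t * v).sum(dim=-1)  # cosine similarity per pair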
Code Example #2
File: trainer_v2.py  Project: mshaikh2/IPMI2021
    def build_models(self):
        # ################### encoders ################################### #
        image_encoder = ImageEncoder(output_channels=cfg.hidden_dim)
        if cfg.text_encoder_path != '':
            img_encoder_path = cfg.text_encoder_path.replace('text_encoder', 'image_encoder')
            print('Load image encoder from:', img_encoder_path)
            state_dict = torch.load(img_encoder_path, map_location='cpu')
            if 'model' in state_dict:
                image_encoder.load_state_dict(state_dict['model'])
            else:
                image_encoder.load_state_dict(state_dict)
        # Keep the image encoder trainable (gradients enabled).
        for p in image_encoder.parameters():
            p.requires_grad = True

        # image_encoder.eval()
        epoch = 0

        ###################################################################
        text_encoder = TextEncoder(bert_config=self.bert_config)
        if cfg.text_encoder_path != '':
            # istart/iend (defined elsewhere in the project) slice the epoch
            # number out of the checkpoint path; training resumes at the next epoch.
            epoch = int(cfg.text_encoder_path[istart:iend]) + 1
            text_encoder_path = cfg.text_encoder_path
            print('Load text encoder from:', text_encoder_path)
            state_dict = torch.load(text_encoder_path, map_location='cpu')
            if 'model' in state_dict:
                text_encoder.load_state_dict(state_dict['model'])
            else:
                text_encoder.load_state_dict(state_dict)
        # Keep the text encoder trainable (gradients enabled).
        for p in text_encoder.parameters():
            p.requires_grad = True

        # ########################################################### #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            
        return [text_encoder, image_encoder, epoch]
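The istart:iend slice above relies on indices defined elsewhere in the project; a more self-contained way to recover the resume epoch from a checkpoint path is a small regex helper, sketched here under the assumption that checkpoints are named like 'text_encoder_12.pth'.

import re

def epoch_from_path(path: str) -> int:
    # Pull the trailing digits before the file extension, e.g.
    # '.../text_encoder_12.pth' -> 12, and resume at the next epoch.
    m = re.search(r'(\d+)\.\w+$', path)
    return int(m.group(1)) + 1 if m else 0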
Code Example #3
def train(task: int, model_file_name: str):
    """Train model.
    Args:
        task (int): Task.
        model_file_name (str): Model file name (saved or to be saved).
    """

    # Check if data exists.
    if not isfile(DatasetConfig.common_raw_data_file):
        raise ValueError('No common raw data.')

    # Load extracted common data.
    common_data: CommonData = load_pkl(DatasetConfig.common_raw_data_file)

    # Dialog data files.
    train_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, TRAIN_MODE)
    valid_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, VALID_MODE)
    test_dialog_data_file = DatasetConfig.get_dialog_filename(task, TEST_MODE)
    if not isfile(train_dialog_data_file):
        raise ValueError('No train dialog data file.')
    if not isfile(valid_dialog_data_file):
        raise ValueError('No valid dialog data file.')

    # Load extracted dialogs.
    train_dialogs: List[TidyDialog] = load_pkl(train_dialog_data_file)
    valid_dialogs: List[TidyDialog] = load_pkl(valid_dialog_data_file)
    test_dialogs: List[TidyDialog] = load_pkl(test_dialog_data_file)

    if task == KNOWLEDGE_TASK:
        knowledge_data = KnowledgeData()

    # Dataset wrap.
    train_dataset = Dataset(
        task,
        common_data.dialog_vocab,
        None,  #common_data.obj_id,
        train_dialogs,
        knowledge_data if task == KNOWLEDGE_TASK else None)
    valid_dataset = Dataset(
        task,
        common_data.dialog_vocab,
        None,  #common_data.obj_id,
        valid_dialogs,
        knowledge_data if task == KNOWLEDGE_TASK else None)
    test_dataset = Dataset(
        task,
        common_data.dialog_vocab,
        None,  #common_data.obj_id,
        test_dialogs,
        knowledge_data if task == KNOWLEDGE_TASK else None)

    print('Train dataset size:', len(train_dataset))
    print('Valid dataset size:', len(valid_dataset))
    print('Test dataset size:', len(test_dataset))

    # Get initial embedding.
    vocab_size = len(common_data.dialog_vocab)
    embed_init = get_embed_init(common_data.glove,
                                vocab_size).to(GlobalConfig.device)

    # Context model configurations.
    context_text_encoder_config = ContextTextEncoderConfig(
        vocab_size, embed_init)
    context_image_encoder_config = ContextImageEncoderConfig()
    context_encoder_config = ContextEncoderConfig()

    # Context models.
    context_text_encoder = TextEncoder(context_text_encoder_config)
    context_text_encoder = context_text_encoder.to(GlobalConfig.device)
    context_image_encoder = ImageEncoder(context_image_encoder_config)
    context_image_encoder = context_image_encoder.to(GlobalConfig.device)
    context_encoder = ContextEncoder(context_encoder_config)
    context_encoder = context_encoder.to(GlobalConfig.device)

    # Load model file.
    model_file = join(DatasetConfig.dump_dir, model_file_name)
    if isfile(model_file):
        state = torch.load(model_file)
        # if task != state['task']:
        #     raise ValueError("Task doesn't match.")
        context_text_encoder.load_state_dict(state['context_text_encoder'])
        context_image_encoder.load_state_dict(state['context_image_encoder'])
        context_encoder.load_state_dict(state['context_encoder'])

    # Task-specific parts.
    if task == INTENTION_TASK:
        intention_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file)
    elif task == TEXT_TASK:
        text_train(context_text_encoder, context_image_encoder,
                   context_encoder, train_dataset, valid_dataset, test_dataset,
                   model_file, common_data.dialog_vocab, embed_init)
    elif task == RECOMMEND_TASK:
        recommend_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file, vocab_size, embed_init)
    elif task == KNOWLEDGE_TASK:
        knowledge_attribute_train(context_text_encoder, context_image_encoder,
                                  context_encoder, train_dataset,
                                  valid_dataset, test_dataset, model_file,
                                  knowledge_data.attribute_data,
                                  common_data.dialog_vocab, embed_init)
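The load path above implies checkpoints store one state dict per component under fixed keys; a minimal sketch of the matching save step (only the keys are taken from the load code, the helper itself is assumed):

import torch

def save_checkpoint(model_file: str, task: int, context_text_encoder,
                    context_image_encoder, context_encoder) -> None:
    # Mirrors the keys train() reads back; 'task' matches the
    # commented-out consistency check in the load path.
    torch.save({
        'task': task,
        'context_text_encoder': context_text_encoder.state_dict(),
        'context_image_encoder': context_image_encoder.state_dict(),
        'context_encoder': context_encoder.state_dict(),
    }, model_file)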
Code Example #4
    device = torch.device("cuda", local_rank)
    checkpoints_dir = './checkpoints'
    start_epoch = 0
    use_bert = True
    if not os.path.exists(checkpoints_dir) and local_rank == 0:
        os.makedirs(checkpoints_dir)
    kdd_dataset = Dataset(use_bert=use_bert)
    sampler = DistributedSampler(kdd_dataset)
    loader = DataLoader(kdd_dataset,
                        collate_fn=collate_fn,
                        batch_size=130,
                        sampler=sampler,
                        num_workers=15)
    nhead = 4
    text_encoder = TextEncoder(kdd_dataset.unknown_token + 1,
                               1024,
                               256,
                               use_bert=use_bert).cuda()
    image_encoder = ImageEncoder(input_dim=2048, output_dim=1024, nhead=nhead)
    image_encoder.load_pretrained_weights(
        path='../user_data/image_encoder_large.pth')
    image_encoder = image_encoder.cuda()
    score_model = ScoreModel(1024, 256).cuda()
    # text_generator = TextGenerator(text_encoder.embed.num_embeddings).cuda()
    # score_model = ScoreModel(30522, 256, num_heads=1).cuda()
    # category_embedding = CategoryEmbedding(256).cuda()

    optimizer = Adam(image_encoder.get_params() + text_encoder.get_params() +
                     score_model.get_params())

    if start_epoch > 0 and local_rank == 0:
        # Resume weights saved by a previous run; the 'model-epoch{}.pth'
        # naming is assumed from the matching valid() code (Example #6).
        checkpoints = torch.load(
            os.path.join(checkpoints_dir,
                         'model-epoch{}.pth'.format(start_epoch)))
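One thing the snippet stops short of showing: with DistributedSampler, the sampler's epoch must be advanced every epoch, otherwise each epoch reuses the same shuffle across ranks. A minimal continuation sketch (num_epochs and the loop body are assumptions):

num_epochs = 10  # assumed; not shown in the snippet

for epoch in range(start_epoch, num_epochs):
    sampler.set_epoch(epoch)  # reshuffle the distributed shards each epoch
    for batch in loader:
        ...  # forward, loss, backward, optimizer.step() as usual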
Code Example #5
    def load_network(self):
        
        image_generator = ImageGenerator()
        image_generator.apply(weights_init)
        
        disc_image = DiscriminatorImage()
        disc_image.apply(weights_init)
        
        emb_dim = 300
        text_encoder = TextEncoder(emb_dim, self.txt_emb, 1, dropout=0.0)

        attn_model = 'general'
        text_generator = TextGenerator(attn_model, emb_dim,
                                       len(self.txt_dico.id2word),
                                       self.txt_emb, n_layers=1, dropout=0.0)
        
        image_encoder = ImageEncoder()
        image_encoder.apply(weights_init)
        
        disc_latent = DiscriminatorLatent(emb_dim)

        # Optionally restore pretrained weights for each component.
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            image_generator.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)

        if cfg.NET_D != '':
            state_dict = torch.load(cfg.NET_D,
                                    map_location=lambda storage, loc: storage)
            disc_image.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)

        if cfg.ENCODER != '':
            state_dict = torch.load(cfg.ENCODER,
                                    map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load from: ', cfg.ENCODER)

        if cfg.DECODER != '':
            state_dict = torch.load(cfg.DECODER,
                                    map_location=lambda storage, loc: storage)
            text_generator.load_state_dict(state_dict)
            print('Load from: ', cfg.DECODER)

        if cfg.IMAGE_ENCODER != '':
            state_dict = torch.load(cfg.IMAGE_ENCODER,
                                    map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            print('Load from: ', cfg.IMAGE_ENCODER)
            
        if cfg.CUDA:
            image_encoder.cuda()
            image_generator.cuda()
            text_encoder.cuda()
            text_generator.cuda()
            disc_image.cuda()
            disc_latent.cuda()
            
        return image_encoder, image_generator, text_encoder, text_generator, disc_image, disc_latent
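weights_init is not shown in the snippet; modules initialized via .apply(weights_init) conventionally use a DCGAN-style initializer like the sketch below, which is an assumption about this project's function, not its actual code.

import torch.nn as nn

def weights_init(m: nn.Module) -> None:
    # DCGAN-style init: N(0, 0.02) for conv weights; N(1, 0.02) with
    # zero bias for batch-norm layers; other modules are left untouched.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)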
Code Example #6
def valid(epoch=1,
          checkpoints_dir='./checkpoints',
          use_bert=False,
          data_path=None,
          out_path='../prediction_result/valid_pred.json',
          output_ndcg=True):
    print("valid epoch{}".format(epoch))
    if data_path is not None:
        kdd_dataset = ValidDataset(data_path, use_bert=use_bert)
    else:
        kdd_dataset = ValidDataset(data_path, use_bert=use_bert)
    loader = DataLoader(kdd_dataset,
                        collate_fn=collate_fn_valid,
                        batch_size=128,
                        shuffle=False,
                        num_workers=8)
    tbar = tqdm(loader)
    text_encoder = TextEncoder(kdd_dataset.unknown_token + 1,
                               1024,
                               256,
                               use_bert=use_bert).cuda()
    image_encoder = ImageEncoder(input_dim=2048, output_dim=1024,
                                 nhead=4).cuda()
    score_model = ScoreModel(1024, 256).cuda()
    # category_embedding = model.CategoryEmbedding(768).cuda()
    checkpoints = torch.load(
        os.path.join(checkpoints_dir, 'model-epoch{}.pth'.format(epoch)))
    text_encoder.load_state_dict(checkpoints['query'])
    image_encoder.load_state_dict(checkpoints['item'])
    score_model.load_state_dict(checkpoints['score'])
    outputs = {}
    image_encoder.eval()
    text_encoder.eval()
    score_model.eval()
    for query_id, product_id, query, query_len, features, boxes, category, obj_len in tbar:
        query, query_len = query.cuda(), query_len.cuda()
        query, hidden = text_encoder(query, query_len)
        features, boxes, obj_len = (features.cuda(), boxes.cuda(),
                                    obj_len.cuda())
        features = image_encoder(features, boxes, obj_len)
        score = score_model(query, hidden, query_len, features)
        score = score.data.cpu().numpy()

        # print(score2)

        for q_id, p_id, s in zip(query_id.data.numpy(),
                                 product_id.data.numpy(), score):
            outputs.setdefault(str(q_id), [])
            outputs[str(q_id)].append((p_id, s))

    for k, v in outputs.items():
        v = sorted(v, key=lambda x: x[1], reverse=True)
        v = [(str(x[0]), float(x[1])) for x in v]
        outputs[k] = v

    with open(out_path, 'w') as f:
        json.dump(outputs, f)
    if output_ndcg:
        pred = read_json(out_path)
        gt = read_json('../data/valid/valid_answer.json')
        score = 0
        k = 5
        for key, val in gt.items():
            ground_truth_ids = [str(x) for x in val]
            predictions = [x[0] for x in pred[key][:k]]
            ref_vec = [1.0] * len(ground_truth_ids)

            pred_vec = [
                1.0 if pid in ground_truth_ids else 0.0 for pid in predictions
            ]
            score += get_ndcg(pred_vec, ref_vec, k)
            # print(key)
            # print([pid for pid in predictions if pid not in ground_truth_ids])
            # print('========')
            # score += len(set(predictions).intersection(ground_truth_ids)) / len(ground_truth_ids)
        score = score / len(gt)
        print('ndcg@%d: %.4f' % (k, score))
        return score
    else:
        return None
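get_ndcg is imported from elsewhere; with binary relevance vectors like pred_vec and ref_vec above, NDCG@k reduces to DCG(pred)/DCG(ideal). A sketch under that assumption:

import math

def get_ndcg(pred_vec, ref_vec, k):
    # DCG truncated at k with binary gains: sum of rel_i / log2(i + 2).
    dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(pred_vec[:k]))
    idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ref_vec[:k]))
    return dcg / idcg if idcg > 0 else 0.0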
Code Example #7
    val_dataloader_mohx = DataLoader(dataset=val_dataset_mohx,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     collate_fn=TextDataset.collate_fn)
    """
    3. Model training
    """
    '''
    3.1 Set up model, loss criterion, optimizer
    '''
    # Instantiate the model

    Exp_model = TextEncoder(embedding_dim=1024,
                            hidden_size=256,
                            num_layers=1,
                            bidir=True,
                            dropout1=0.5)
    Query_model = TextEncoder(embedding_dim=1024,
                              hidden_size=256,
                              num_layers=1,
                              bidir=True,
                              dropout1=0.5)
    Attn_model = AttentionModel(para_encoder_input_dim=512,
                                query_dim=512,
                                output_dim=256)
    para_encoder_attn_model = AttentionModel(para_encoder_input_dim=512,
                                             query_dim=512,
                                             output_dim=512)

    para_encoder = ParaEncoder(input_dim=1024,