Пример #1
0
def pretrain():
    """Pre-train ``MyPLVCG`` or evaluate its best pre-training checkpoint.

    All settings come from the module-level ``parser``. Config JSON files and
    output directories are resolved under ``opt.workspace_path``; corpus,
    image, vocab and merges files under ``opt.input_path``.

    * ``opt.eval`` falsy: build the dev and train ``Dataset`` objects,
      optionally resume from ``opt.model_file`` (relative to ``save_dir``)
      and call ``train(..., type='pretrain')``.
    * ``opt.eval`` truthy: load ``<save_dir>/best-model.pt``, run
      ``inference`` on the dev set in ``'pretrain'`` mode and pass the
      results to ``print_results``.
    """
    opt = parser.parse_args()

    pretrain_cfg = PretrainConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.pretrain_cfg_file))
    model_cfg = ModelConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.model_cfg_file))

    img_file = os.path.join(opt.input_path, opt.img_file)
    corpus_file = os.path.join(opt.input_path, opt.corpus_file)
    eval_corpus_file = os.path.join(opt.input_path, opt.eval_corpus_file)
    vocab_file = os.path.join(opt.input_path, opt.vocab_file)
    video_type_map_file = os.path.join(opt.input_path, opt.video_type_map_file)
    merges_file = os.path.join(opt.input_path, opt.merges_file)
    preprocess_dir = os.path.join(opt.workspace_path, opt.preprocess_dir)
    save_dir = os.path.join(opt.workspace_path, opt.save_dir)
    log_dir = os.path.join(opt.workspace_path, opt.log_dir)
    # makedirs also creates missing parent directories (os.mkdir would raise)
    # and is a no-op when the directory already exists.
    for directory in (preprocess_dir, save_dir, log_dir):
        os.makedirs(directory, exist_ok=True)
    model_file = (os.path.join(save_dir, opt.model_file)
                  if opt.model_file is not None else None)

    # log_filename is already the full path inside log_dir. Joining it with
    # opt.log_dir a second time produced a nonexistent nested path whenever
    # workspace_path was relative.
    log_filename = os.path.join(
        log_dir, "{}log.txt".format("eval_" if opt.eval else ""))
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.FileHandler(log_filename),
                                  logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    logger.info(opt)

    tokenizer = MyBartTokenizer(vocab_file, merges_file)

    dev_data = Dataset(vocab_file, eval_corpus_file, img_file, video_type_map_file,
                       preprocess_dir, model_cfg, pretrain_cfg, imgs=None,
                       is_training=False, type='pretrain')
    dev_data.load_dataset(tokenizer)
    dev_data.load_dataloader()

    if not opt.eval:
        # Reuse the images already loaded for the dev set.
        train_data = Dataset(vocab_file, corpus_file, img_file, video_type_map_file,
                             preprocess_dir, model_cfg, pretrain_cfg, imgs=dev_data.imgs,
                             is_training=True, type='pretrain')
        train_data.load_dataset(tokenizer)
        train_data.load_dataloader()

    model = MyPLVCG(model_cfg, dev_data.video_type_weight, type="pretrain")

    if not opt.eval:
        # Train, optionally resuming from a saved state dict. map_location="cpu"
        # lets GPU-saved checkpoints load on CPU-only hosts; load_state_dict
        # copies into the model's own parameters either way.
        if model_file is not None:
            model.load_state_dict(torch.load(model_file, map_location="cpu"))

        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=pretrain_cfg.lr, eps=pretrain_cfg.adam_epsilon)

        train(pretrain_cfg, logger, save_dir, model, train_data, dev_data, optimizer,
              type='pretrain')
    else:
        # Evaluation of the best checkpoint written to save_dir.
        checkpoint = os.path.join(save_dir, 'best-model.pt')
        logger.info("Loading checkpoint from {}".format(checkpoint))
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        if torch.cuda.is_available():
            model.to(torch.device("cuda"))
        model.eval()

        with torch.no_grad():
            # Pass the mode string explicitly. There is no local named `type`
            # in this function, so a bare `type` would be the builtin class.
            total_loss, predictions, predictions_type, logits = inference(
                pretrain_cfg, model, dev_data, 'pretrain')

        print_results(save_dir, dev_data, 0, total_loss, predictions, predictions_type,
                      dev_data.comments, dev_data.contexts, dev_data.video_types)
Пример #2
0
def generate():
    """Generate comments for the test corpus with a fine-tuned ``MyPLVCG``.

    All settings come from the module-level ``parser``. Loads
    ``opt.model_file`` (relative to ``save_dir``) and runs
    ``test_generation`` on the test ``Dataset``. For each sample it writes
    three things to ``<generate_dir>/generated.txt``: the decoded context
    (cut at the first ``<PAD>``), the ground-truth comments (each cut at
    ``<EOS>``) and the generated comments (each cut at ``<EOS>``).
    """
    opt = parser.parse_args()

    generate_cfg = GenerateConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.generate_cfg_file))
    model_cfg = ModelConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.model_cfg_file))

    img_file = os.path.join(opt.input_path, opt.img_file)
    test_corpus_file = os.path.join(opt.input_path, opt.test_corpus_file)
    vocab_file = os.path.join(opt.input_path, opt.vocab_file)
    merges_file = os.path.join(opt.input_path, opt.merges_file)
    video_type_map_file = os.path.join(opt.input_path, opt.video_type_map_file)
    preprocess_dir = os.path.join(opt.workspace_path, opt.preprocess_dir)
    save_dir = os.path.join(opt.workspace_path, opt.save_dir)
    generate_dir = os.path.join(opt.workspace_path, opt.generate_dir)
    log_dir = os.path.join(opt.workspace_path, opt.log_dir)
    # Both directories are written to below; create them if missing.
    for directory in (generate_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    model_file = os.path.join(save_dir, opt.model_file)

    # log_filename is already the full path inside log_dir. Do not join it
    # with opt.log_dir again.
    log_filename = os.path.join(log_dir, "generate_log.txt")
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.FileHandler(log_filename),
                                  logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    logger.info(opt)

    tokenizer = MyBartTokenizer(vocab_file, merges_file)

    test_data = Dataset(vocab_file, test_corpus_file, img_file, video_type_map_file,
                        preprocess_dir, model_cfg, generate_cfg, imgs=None,
                        is_training=False, type='generate')
    test_data.load_gengrate_dataset(tokenizer)
    test_data.load_dataloader()

    model = MyPLVCG(model_cfg, test_data.video_type_weight, type="fine_tuning")
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    # The model is moved to CUDA below when available.
    model.load_state_dict(torch.load(model_file, map_location="cpu"))
    if torch.cuda.is_available():
        model.to(torch.device("cuda"))
    model.eval()

    with torch.no_grad():
        generated = test_generation(generate_cfg, model, test_data)

    # str.partition(marker)[0] keeps the text before the first marker, or the
    # whole string when the marker is absent. The `with` block guarantees the
    # output file is flushed and closed.
    with open(os.path.join(generate_dir, 'generated.txt'), "w", encoding='utf8') as res_f:
        for gen, gt, ct in zip(generated, test_data.comments, test_data.contexts):
            ct_decode = test_data.decode(ct).partition("<PAD>")[0]
            res_f.write("%s\n\nground_truth:\n" % (ct_decode))
            for g in gt:
                g_decode = test_data.decode(g).partition("<EOS>")[0]
                res_f.write("\t%s\n" % (g_decode))
            res_f.write("\ngenerated:\n")
            for s in gen:
                res_f.write("\t%s\n" % (s.partition("<EOS>")[0]))
            res_f.write("\n=============================\n\n")
Пример #3
0
def classification():
    """Train or evaluate ``MyClassificationPLVCG``.

    All settings come from the module-level ``parser``. Unless
    ``opt.without_pretrain`` is set, the weights in ``opt.pretrain_file``
    whose names also exist in the classification model are copied into it.

    * ``opt.eval`` falsy: optionally resume from ``opt.model_file`` (relative
      to ``save_dir``) and call ``train(..., type='classification')``.
    * ``opt.eval`` truthy: load ``<save_dir>/best-model.pt`` and run
      ``inference`` in ``'classification'`` mode once. Then, for the
      thresholds 0.30, 0.35, ..., 0.75, it does three things: print
      precision/recall/F1 from ``metrics`` on ``logit_pred(logits,
      threshold)``, and call ``print_classification_res``.
    """
    opt = parser.parse_args()

    classification_cfg = ClassificationConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.classification_cfg_file))
    model_cfg = ModelConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.model_cfg_file))

    img_file = os.path.join(opt.input_path, opt.img_file)
    corpus_file = os.path.join(opt.input_path, opt.corpus_file)
    eval_corpus_file = os.path.join(opt.input_path, opt.eval_corpus_file)
    vocab_file = os.path.join(opt.input_path, opt.vocab_file)
    merges_file = os.path.join(opt.input_path, opt.merges_file)
    preprocess_dir = os.path.join(opt.workspace_path, opt.preprocess_dir)
    video_type_map_file = os.path.join(opt.input_path, opt.video_type_map_file)
    save_dir = os.path.join(opt.workspace_path, opt.save_dir)
    log_dir = os.path.join(opt.workspace_path, opt.log_dir)
    # Create the output directories, as pretrain() does, so that logging and
    # checkpoint saving do not fail on a fresh workspace.
    for directory in (save_dir, log_dir):
        os.makedirs(directory, exist_ok=True)
    model_file = (os.path.join(save_dir, opt.model_file)
                  if opt.model_file is not None else None)
    pretrain_file = os.path.join(opt.workspace_path, opt.pretrain_file)

    # log_filename is already the full path inside log_dir. Do not join it
    # with opt.log_dir again.
    log_filename = os.path.join(
        log_dir, "{}log.txt".format("eval_" if opt.eval else ""))
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(log_filename),
            logging.StreamHandler()
        ])
    logger = logging.getLogger(__name__)
    logger.info(opt)

    tokenizer = MyBartTokenizer(vocab_file, merges_file)

    dev_data = Dataset(vocab_file,
                       eval_corpus_file,
                       img_file,
                       video_type_map_file,
                       preprocess_dir,
                       model_cfg,
                       classification_cfg,
                       imgs=None,
                       is_training=False,
                       type='fine_tuning')
    dev_data.load_classification_dataset(tokenizer)
    dev_data.load_dataloader()

    if not opt.eval:
        # Reuse the images already loaded for the dev set.
        train_data = Dataset(vocab_file,
                             corpus_file,
                             img_file,
                             video_type_map_file,
                             preprocess_dir,
                             model_cfg,
                             classification_cfg,
                             imgs=dev_data.imgs,
                             is_training=True,
                             type='fine_tuning')
        train_data.load_classification_dataset(tokenizer)
        train_data.load_dataloader()

    model = MyClassificationPLVCG(model_cfg,
                                  classification_cfg.negative_num,
                                  type="classification")

    if not opt.without_pretrain:
        print("Loading pretrain file...")
        model_dict = model.state_dict()
        pretrained_dict = torch.load(pretrain_file, map_location="cpu")
        # Copy only parameters/buffers that the classification model also
        # has. Keys absent from the model are dropped, and keys missing from
        # the pretrain file keep their freshly initialised values.
        model_dict.update(
            {k: v for k, v in pretrained_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    if not opt.eval:
        # Train, optionally resuming from a saved state dict.
        if model_file is not None:
            model.load_state_dict(torch.load(model_file, map_location="cpu"))

        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=classification_cfg.lr,
                          eps=classification_cfg.adam_epsilon)
        train(classification_cfg,
              logger,
              save_dir,
              model,
              train_data,
              dev_data,
              optimizer,
              type='classification')

    else:
        # Evaluation. Reuse the model built above with type="classification",
        # the same configuration used for training and matching the
        # 'classification' inference mode. A strict load_state_dict replaces
        # every weight, so any pretrain weights loaded above do not survive.
        checkpoint = os.path.join(save_dir, 'best-model.pt')
        logger.info("Loading checkpoint from {}".format(checkpoint))
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        if torch.cuda.is_available():
            model.to(torch.device("cuda"))
        model.eval()

        with torch.no_grad():
            total_loss, predictions, logits = inference(
                classification_cfg, model, dev_data, 'classification')

        rids = dev_data.dataset.rids
        for i in range(10):
            threshold = 0.3 + i * 0.05
            print(threshold)
            predictions = logit_pred(logits, threshold)

            precision, recall, f1 = metrics(predictions,
                                            classification_cfg.negative_num)
            print("precision:%f \t recall:%f \t f1:%f" %
                  (precision, recall, f1))
            # NOTE(review): called once per threshold with the same save_dir
            # and epoch 0. Confirm whether later calls overwrite earlier output.
            print_classification_res(save_dir, dev_data, 0, total_loss,
                                     predictions, dev_data.comments,
                                     dev_data.contexts, rids, logits,
                                     classification_cfg.negative_num)
Пример #4
0
def ranking():
    """Rank candidate comments for each test sample and report retrieval metrics.

    All settings come from the module-level ``parser``. Builds a
    ``MyClassificationPLVCG`` (when ``opt.model_from == 'classification'``)
    or a ``MyPLVCG`` and loads ``opt.model_file`` (relative to ``save_dir``)
    into it. Per-candidate scores come from one of two places:

    * ``opt.load`` truthy: read from
      ``<rank_dir>/rank_score_<model_from>.json``.
    * otherwise: computed by ``test_rank`` and cached to that same file.

    Outputs:

    * ``<rank_dir>/out.txt`` gets per-sample details: the hit rank, the
      decoded context cut at ``<PAD>``, and one line per ranked candidate.
    * ``<rank_dir>/rank_res.txt`` gets the summary metrics, which are also
      printed: r@1, r@5 and r@10 (as percentages), mean rank and mean
      reciprocal rank.
    """
    opt = parser.parse_args()

    rank_config = RankConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.rank_cfg_file))
    model_cfg = ModelConfig.load_from_json(
        os.path.join(opt.workspace_path, opt.model_cfg_file))

    img_file = os.path.join(opt.input_path, opt.img_file)
    test_corpus_file = os.path.join(opt.input_path, opt.test_corpus_file)
    vocab_file = os.path.join(opt.input_path, opt.vocab_file)
    merges_file = os.path.join(opt.input_path, opt.merges_file)
    video_type_map_file = os.path.join(opt.input_path, opt.video_type_map_file)
    preprocess_dir = os.path.join(opt.workspace_path, opt.preprocess_dir)
    rank_dir = os.path.join(opt.workspace_path, opt.rank_dir)
    log_dir = os.path.join(opt.workspace_path, opt.log_dir)
    save_dir = os.path.join(opt.workspace_path, opt.save_dir)
    model_file = os.path.join(save_dir, opt.model_file)
    # Both directories are written to below; create them if missing.
    for directory in (rank_dir, log_dir):
        os.makedirs(directory, exist_ok=True)

    # log_filename is already the full path inside log_dir. Do not join it
    # with opt.log_dir again.
    log_filename = os.path.join(log_dir, "{}log.txt".format("rank_"))
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.FileHandler(log_filename),
                                  logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    logger.info(opt)

    tokenizer = MyBartTokenizer(vocab_file, merges_file)

    test_data = Dataset(vocab_file, test_corpus_file, img_file, video_type_map_file,
                        preprocess_dir, model_cfg, rank_config, imgs=None,
                        is_training=False, type='test')
    test_data.load_test_dataset(tokenizer)
    test_data.load_dataloader()

    from_classification = opt.model_from == 'classification'
    if from_classification:
        model = MyClassificationPLVCG(model_cfg, type='test')
    else:
        model = MyPLVCG(model_cfg, type='test')
    logger.info("Loading checkpoint from {}".format(model_file))
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    # The model is moved to CUDA below when available.
    model.load_state_dict(torch.load(model_file, map_location="cpu"))
    if torch.cuda.is_available():
        model.to(torch.device("cuda"))
    model.eval()

    score_file = os.path.join(rank_dir, 'rank_score_%s.json' % (opt.model_from))
    if opt.load:
        # Read with the same encoding the cache is written with.
        with open(score_file, "r", encoding='utf8') as f:
            scores, pred_list = json.load(f)
        # Candidate indices ordered by descending score.
        ranks = [sorted(range(len(score)), key=lambda k: score[k], reverse=True)
                 for score in scores]
    else:
        with torch.no_grad():
            ranks, scores, pred_list = test_rank(rank_config, model, test_data,
                                                 type='classification')
        scores = [np.array(s.cpu()).tolist() for s in scores]
        with open(score_file, 'w', encoding='utf8') as f_scores:
            json.dump([scores, pred_list], f_scores)

    hits_1 = 0
    hits_5 = 0
    hits_10 = 0
    mean_rank = 0
    mean_reciprocal_rank = 0

    with open(os.path.join(rank_dir, 'out.txt'), 'w', encoding='utf8') as f_outs:
        for i, rank in enumerate(ranks):
            gt_dic = test_data.gts[i]
            pred_b = pred_list[i]

            # Candidate comments in ranked order. gt_dic keys are the
            # candidates, and its values are written as %d below.
            comments = list(gt_dic.keys())
            candidate = [comments[idx] for idx in rank]
            f_outs.write("\n========================\n")

            hit_rank = calc_hit_rank(candidate, gt_dic)

            f_outs.write("%d\n" % (hit_rank))
            cont = test_data.decode(test_data.contexts[i]).partition("<PAD>")[0]
            f_outs.write("%s\n" % (cont))
            for j, idx in enumerate(rank):
                comment = comments[idx]
                if from_classification:
                    # A classification run has one prediction per sample.
                    f_outs.write("%d %d %d %f %d %s || %d\n" % (
                        i, j, idx, scores[i][idx], gt_dic[comment], comment, pred_b))
                else:
                    f_outs.write("%d %d %d %f %d %s || %s\n" % (
                        i, j, idx, scores[i][idx], gt_dic[comment], comment, pred_b[idx]))

            mean_rank += hit_rank
            mean_reciprocal_rank += 1.0 / hit_rank
            hits_1 += int(hit_rank <= 1)
            hits_5 += int(hit_rank <= 5)
            hits_10 += int(hit_rank <= 10)

    total = len(test_data.gts)
    if total == 0:
        # Avoid ZeroDivisionError when the test set is empty.
        logger.warning("No test samples found; skipping ranking metrics.")
        return

    summary = "\t r@1:%f \t r@5:%f \t r@10:%f \t mr:%f \t mrr:%f" % (
        hits_1 / total * 100, hits_5 / total * 100, hits_10 / total * 100,
        mean_rank / total, mean_reciprocal_rank / total)
    print(summary)
    with open(os.path.join(rank_dir, 'rank_res.txt'), 'w', encoding='utf8') as f_o:
        f_o.write(summary)