Example #1
# Imports inferred from usage below; project-local helpers (get_args,
# get_model, train, test, test_single_example, the dataloader module,
# print_gpu_obj, fmt_time_pretty, calc_rank_scores_at_k, calculate_acc,
# f1_metrics_msg_stance, metrics_msg_viral and SelfAdjDiceLoss) are assumed
# to come from this repository and are not reproduced here.
import gc
import sys
import time
import logging

import numpy as np
import torch
import torch.optim as optim
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.metrics import precision_recall_fscore_support as f1_help  # assumed alias

def main():
    random_seed = 1
    torch.backends.cudnn.enabled = False    # disable cuDNN for deterministic behaviour
    torch.manual_seed(random_seed)          # fix the RNG seed for reproducibility
    torch.cuda.empty_cache()                # release any cached GPU memory
    
    time1 = time.time()
    args = get_args()
    TRNG_MB_SIZE =  args.batch_train
    TEST_MB_SIZE =  args.batch_test
    EPOCHS =        args.epochs
    LEARNING_RATE = args.learning_rate
    OPTIM =         args.optimizer
    
    DO_TRAIN =      args.do_train
    DO_TEST =       args.do_test
    
    MODEL_NAME =    args.model_name
    EXP_NAME =      args.exp_name
    
    #DEV_DATA =      args.dev_data
    TRAIN_DATA =    args.train_data
    TEST_DATA =     args.test_data
    KFOLDS =        args.k_folds
    FOLDS2RUN =     args.folds2run
    
    DEBUG =         args.debug
    LOG_INTERVAL =  args.log_interval
    ''' ===================================================='''
    ''' ---------- Parse additional arguments here ---------'''
    LOSS_FN =       args.loss_fn
    W_SAMPLE =      args.w_sample
    PRETRAIN =      args.pretrain_model
    EPOCHS2GIVEUP = args.epochs2giveup
    DROPOUT =       args.dropout
    LAYERS =        args.layers
    V_ATTR =        args.viral_attr
    V_LOG =         args.viral_log
    W_ATTR =        args.weight_attr 
    TASK =          args.task
    MTT_WEIGHT =    args.mtt_weight
    ABLATION =      args.ablation
    ''' ===================================================='''
    
    model_savefile = './log_files/saved_models/'+EXP_NAME+'_'+MODEL_NAME+'.bin'   # to save/load model from
    plotfile = './log_files/'+EXP_NAME+'_'+MODEL_NAME+'.png'            # to plot losses
    if DO_TRAIN:
        logfile_name = './log_files/'+EXP_NAME+'_'+MODEL_NAME+'.log'    # for recording training progress
    else:
        logfile_name = './log_files/'+EXP_NAME+'_'+MODEL_NAME+'.test'
    
    file_handler = logging.FileHandler(filename=logfile_name)       # for saving into a log file
    stdout_handler = logging.StreamHandler(sys.stdout)              # for printing onto terminal
    stderr_handler = logging.StreamHandler(sys.stderr)              # for printing errors onto terminal
    
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt= '%m/%d/%Y %H:%M:%S',
                        handlers=[file_handler, stdout_handler], level=logging.INFO)
    # basicConfig is a no-op once the root logger has handlers, so the stderr
    # handler is attached directly and limited to ERROR records, which are
    # then echoed on stderr in addition to the file and stdout
    stderr_handler.setLevel(logging.ERROR)
    logging.getLogger().addHandler(stderr_handler)
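    # with this format, log lines look like:
    #   01/31/2021 13:45:10 - INFO - __main__ -   Starting training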
    
    logger = logging.getLogger(__name__)
    logger.info('================= ' + MODEL_NAME + ' ==================')
    logger.info('----------------- Hyperparameters ------------------')
    for eachline in vars(args).items():
        logger.info(eachline)
        
    logger.info('--------------- Getting dataframes -----------------')
    test_df = torch.load(TEST_DATA)
    full_train_df = torch.load(TRAIN_DATA)
    
    if DEBUG:   # debug mode: shrink the data and reuse the test set for training
        test_df = test_df[0:40]
        full_train_df = test_df
    
    if V_ATTR == 'likes':
        viral_score = test_df.favorite_count
    elif V_ATTR == 'retweets':
        viral_score = test_df.retweets_count
    else:
        raise Exception('V_ATTR not found: ' + V_ATTR)
    
    top_percentiles = [10, 20, 30, 40, 50]          # percentiles to analyse
    for pctile in top_percentiles:
        thr = np.percentile(viral_score,            # get threshold score;
                            100-pctile)             # np.percentile counts from the bottom, so the top p% cut-off sits at 100-p
        top_ranked = (viral_score >= thr)           # mark posts at/above the threshold as viral
        string = 'top_ranked_'+str(pctile)          # column title
        test_df[string] = top_ranked                # stick labels into dataframe
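    # e.g. with pctile = 20 over 1000 posts, thr is the 80th-percentile score,
    # so roughly the 200 highest-scoring posts get top_ranked_20 = True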
    
    test_dl = dataloader.df_2_dl_v6(test_df, 
                                    batch_size=TEST_MB_SIZE, 
                                    randomize=False,
                                    viral_attr=V_ATTR,
                                    logger=logger,
                                    ablation=ABLATION)
    
    TESTLENGTH = len(test_df)
    if DO_TRAIN:
        logger.info('-------------- Setting loss function  --------------')
        # weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 20.0, 1.0, 10.0, 10.0, 1.0, 1.0, 1.0]).to(gpu)
        # loss_fn = torch.nn.CrossEntropyLoss(weight=weights, reduction='mean')
        if LOSS_FN == 'dice':
            logger.info('chose dice')
            loss_fn_s = SelfAdjDiceLoss(reduction='mean')           # for stance
        elif LOSS_FN == 'ce_loss':
            logger.info('chose ce_loss')
            loss_fn_s = torch.nn.CrossEntropyLoss(reduction='mean') # for stance
        elif LOSS_FN == 'w_ce_loss':            
            logger.info('chose w_ce_loss')
            # count number of examples per category for stance
            stance_counts = torch.tensor([.1, .1, .1, .1])          # start at 0.1 to avoid division by zero
            for stance in full_train_df.number_labels_4_types:      # for each label
                stance_counts[stance] += 1                          # count occurrences
            stance_weights = 1.0 / stance_counts                    # inverse counts to get weights
            stance_weights = stance_weights / stance_weights.mean() # normalize so mean is 1
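            # e.g. stance counts of [900.1, 60.1, 30.1, 10.1] give raw weights
            # [0.0011, 0.0166, 0.0333, 0.0990], which mean-normalize to roughly
            # [0.03, 0.44, 0.89, 2.64], so rarer stances are weighted up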
            
            logger.info('stance loss weights')
            logger.info(stance_weights)
            
            loss_fn_s = torch.nn.CrossEntropyLoss(reduction='mean', # loss function for stance
                                                  weight=stance_weights.cuda()) 
        else:
            raise Exception('Loss function not found: ' + LOSS_FN)
        
        loss_fn_v = torch.nn.MSELoss(reduction='mean')
        
        kfold_helper = KFold(n_splits=KFOLDS)
        kfolds_ran = 0
        kfolds_devs = []
        kfolds_tests = []
        
        for train_idx, dev_idx in kfold_helper.split(full_train_df):
            logger.info('--------------- Running KFOLD %d / %d ----------------' % (kfolds_ran+1, KFOLDS))
            logger.info(print_gpu_obj())
            if FOLDS2RUN == 0:  # for debugging purposes
                train_df = full_train_df
                dev_df = full_train_df
            else:
                train_df = full_train_df.iloc[train_idx]
                dev_df = full_train_df.iloc[dev_idx]
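                # KFold.split yields positional row indices, which is why
                # .iloc (rather than .loc) is the right accessor here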
            
            logger.info('------------ Converting to dataloaders -------------')
            train_dl = dataloader.df_2_dl_v6(train_df, 
                                             batch_size=TRNG_MB_SIZE, 
                                             randomize=True, 
                                             weighted_sample=W_SAMPLE, 
                                             weight_attr=W_ATTR,
                                             viral_attr=V_ATTR,
                                             logger=logger,
                                             ablation=ABLATION)
            dev_dl = dataloader.df_2_dl_v6(dev_df, 
                                           batch_size=TEST_MB_SIZE, 
                                           randomize=False, 
                                           weighted_sample=False,
                                           viral_attr=V_ATTR,
                                           logger=logger,
                                           ablation=ABLATION)
            
            logger.info('--------------- Getting fresh model ----------------')
            model = get_model(logger, MODEL_NAME, DROPOUT, LAYERS)
            model.cuda()
            model = torch.nn.DataParallel(model)
            if PRETRAIN != '':  # reload pretrained model 
                logger.info('loading pretrained model file ' + PRETRAIN)
                saved_params = torch.load(PRETRAIN)
                model.load_state_dict(saved_params)
                del saved_params
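            # note: DataParallel registers parameters under a "module." prefix,
            # so checkpoints saved from the wrapped model (as done later in this
            # loop) only load back cleanly into a model wrapped the same way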
            
            logger.info('----------------- Setting optimizer ----------------')
            if OPTIM=='adam':
                optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
            else:
                raise Exception('Optimizer not found: ' + OPTIM)
            
            logger.info('------ Running a random test before training -------')
            _, _, _, _, random_test_idx = test_single_example(model=model, 
                                                              datalen=TESTLENGTH, 
                                                              dataloader=test_dl, 
                                                              logger=logger, 
                                                              log_interval=LOG_INTERVAL, 
                                                              v_log=V_LOG,
                                                              index=-1, show=True)
            
            logger.info('---------------- Starting training -----------------')
            plotfile_fold = plotfile.replace('.png', '_fold'+str(kfolds_ran)+'.png')
            model_savefile_fold = model_savefile.replace('.bin', '_fold'+str(kfolds_ran)+'.bin')
            fold_metrics = train(model=model, train_dl=train_dl, dev_dl=dev_dl, 
                                 logger=logger, log_interval=LOG_INTERVAL, epochs=EPOCHS,
                                 loss_fn_s=loss_fn_s, loss_fn_v=loss_fn_v, optimizer=optimizer, 
                                 v_log=V_LOG, top_percentiles=top_percentiles, 
                                 plotfile=plotfile_fold, modelfile=model_savefile_fold,
                                 epochs_giveup=EPOCHS2GIVEUP,
                                 task=TASK, mtt_weight=MTT_WEIGHT)
            
            kfolds_devs.append(fold_metrics)
            
            # reload best models
            saved_params = torch.load(model_savefile_fold)
            model.load_state_dict(saved_params)
            del saved_params # this is a huge memory sucker
            with torch.no_grad(): # run some tests post training
                logger.info('------ Running same random test post training ------')
                test_single_example(model=model, 
                                    datalen=TESTLENGTH, 
                                    dataloader=test_dl, 
                                    logger=logger, 
                                    log_interval=LOG_INTERVAL,
                                    v_log=V_LOG,
                                    index=random_test_idx, 
                                    show=True)
                
                logger.info('------- Running on test set after training  --------')
                test_results = test(model=model, 
                                    dataloader=test_dl,
                                    logger=logger,
                                    log_interval=LOG_INTERVAL,
                                    v_log=V_LOG,
                                    print_string='test')
                
                y_pred_s = test_results[0]   # shape=(n,). elements are ints.
                y_pred_v = test_results[1]   # shape=(n,). elements are floats
                y_true_s = test_results[2]   # shape=(n,). elements are ints
                y_true_v = test_results[3]   # shape=(n,). elements are floats
                
                f1_metrics_s = f1_help(y_true_s, y_pred_s,  # calculate f1 scores for stance
                                       average=None,        # None returns per-class scores
                                       labels=[0,1,2,3])    # number of classes = 4
                metrics_v = calc_rank_scores_at_k(y_true_v,
                                                  y_pred_v,
                                                  top_percentiles)
                
                prec_s, rec_s, f1s_s, supp_s = f1_metrics_s
                acc_s = calculate_acc(y_pred_s, y_true_s)
                msg_s = f1_metrics_msg_stance(prec_s, rec_s, f1s_s, supp_s, acc_s)
                
                prec_v, supp_v, ndcg_v = metrics_v
                r2e_v = r2_score(y_true_v, y_pred_v)
                mse_v = mean_squared_error(y_true_v, y_pred_v)
                msg_v = metrics_msg_viral(prec_v, supp_v, ndcg_v, top_percentiles, r2e_v, mse_v)
                
                logger.info(msg_s + msg_v)
                kfolds_tests.append([f1_metrics_s, acc_s, r2e_v, mse_v, msg_s+msg_v])
                time2 = time.time()
                logger.info(fmt_time_pretty(time1, time2))
            # ===================================================================
            # need to do these steps to force garbage collection to work properly;
            # without them, deleting the model doesn't seem to free GPU memory
            model.to('cpu') 
            del optimizer, model, train_dl, dev_dl
            gc.collect()
            torch.cuda.empty_cache()
            # ===================================================================
            
            kfolds_ran += 1
            if kfolds_ran >= FOLDS2RUN:
                break
        # finished kfolds, print everything once more, calculate the average f1 metrics
        f1s_s = []  # to accumulate stance f1 scores
        r2es_v = []  # to accumulate viral r2 scores
        mses_v = []  # to accumulate viral mse scores
        accs_s = [] # to accumulate stance accuracy scores
        
        for i in range(len(kfolds_devs)):
            fold_dev_results = kfolds_devs[i]
            fold_test_results = kfolds_tests[i]
            dev_msg = fold_dev_results[-1]
            test_msg = fold_test_results[-1]
            msg_2_print =               '\n******************** Fold %d results ********************\n' % i 
            msg_2_print = msg_2_print + '------------------------ Dev set ------------------------' + dev_msg 
            msg_2_print = msg_2_print + '------------------------ Test set ------------------------' + test_msg
            logger.info(msg_2_print)

            f1_s_metrics = fold_test_results[0]
            acc_s = fold_test_results[1]            
            r2e_v = fold_test_results[2]
            mse_v = fold_test_results[3]
            
            f1_s = np.average(f1_s_metrics[2])  # get individual class f1 scores, then avg
            f1s_s.append(f1_s)                  # store macro f1 
            accs_s.append(acc_s)                # store accuracy
            r2es_v.append(r2e_v)                # store the r2e
            mses_v.append(mse_v)                # store the mse 
            
        
        f1_s_avg = np.average(f1s_s)
        f1_s_std = np.std(f1s_s)
        r2_v_avg = np.average(r2es_v)
        r2_v_std = np.std(r2es_v)
        mse_v_avg = np.average(mses_v)
        mse_v_std = np.std(mses_v)
        acc_s_avg = np.average(accs_s)
        acc_s_std = np.std(accs_s)
        
        msg = '\nPerf across folds\n'
        msg+= 'avg_f1_stance\t%.4f\n' % f1_s_avg
        msg+= 'std_f1_stance\t%.4f\n' % f1_s_std
        msg+= 'avg_r2_viral\t%.4f\n' % r2_v_avg
        msg+= 'std_r2_viral\t%.4f\n' % r2_v_std
        msg+= 'avg_mse_viral\t%.4f\n' % mse_v_avg
        msg+= 'std_mse_viral\t%.4f\n' % mse_v_std
        msg+= 'avg_acc_stance\t%.4f\n' % acc_s_avg
        msg+= 'std_acc_stance\t%.4f\n' % acc_s_std
        logger.info(msg)
        
    if DO_TEST:
        logger.info('------------------ Getting model -------------------')
        model = get_model(logger, MODEL_NAME, DROPOUT, LAYERS)
        model.cuda()
        model = torch.nn.DataParallel(model)
        if PRETRAIN != '':  # reload pretrained model 
            logger.info('loading pretrained model file ' + PRETRAIN)
            saved_params = torch.load(PRETRAIN)
            model.load_state_dict(saved_params)
        
        test_results = test(model=model, 
                            dataloader=test_dl,
                            logger=logger,
                            log_interval=LOG_INTERVAL,
                            v_log=V_LOG,
                            print_string='test')
        
        y_pred_s = test_results[0]   # shape=(n,). elements are ints.
        y_pred_v = test_results[1]   # shape=(n,). elements are floats
        y_true_s = test_results[2]   # shape=(n,). elements are ints
        y_true_v = test_results[3]   # shape=(n,). elements are floats
        
        f1_metrics_s = f1_help(y_true_s, y_pred_s,  # calculate f1 scores for stance
                               average=None,        # None returns per-class scores
                               labels=[0,1,2,3])    # number of classes = 4
        metrics_v = calc_rank_scores_at_k(y_true_v,
                                          y_pred_v,
                                          top_percentiles)
        
        prec_s, rec_s, f1s_s, supp_s = f1_metrics_s
        acc_s = calculate_acc(y_pred_s, y_true_s)
        msg_s = f1_metrics_msg_stance(prec_s, rec_s, f1s_s, supp_s, acc_s)
        
        prec_v, supp_v, ndcg_v = metrics_v        
        r2e_v = r2_score(y_true_v, y_pred_v)
        mse_v = mean_squared_error(y_true_v, y_pred_v)
        msg_v = metrics_msg_viral(prec_v, supp_v, ndcg_v, top_percentiles, r2e_v, mse_v)
        
        logger.info(msg_s + msg_v)
        
    time2 = time.time()
    logger.info(fmt_time_pretty(time1, time2))
    return
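
# A minimal entry point, assuming this module is meant to be run as a script
# (hypothetical; no such guard appears in the original listing):
if __name__ == '__main__':
    main()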
Example #2
# NOTE: this snippet is truncated upstream: the imports below, the loop over
# the (PERPLEX, PCA_DIM) grid, the time1 stopwatch, and the opening of this
# per-cluster scatter call are reconstructed, with variable names (X, X1, X2,
# df, model, colors) assumed from the visible code; fmt_time_pretty is a
# project-local helper.
import time
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances_argmin_min

        # plot each cluster's points in its own colour
        for cluster in range(model.n_clusters):
            members = (model.labels_ == cluster)
            plt.scatter(X1[members],
                        X2[members],
                        s=2,
                        alpha=0.3,
                        color=colors[cluster])
        
        centers = model.cluster_centers_
        for i in range(centers.shape[0]):
            plt.scatter(centers[i,0], 
                        centers[i,1], 
                        marker='x',
                        color=colors[i])
        
        # find the points closest to centroids
        closest, _ = pairwise_distances_argmin_min(model.cluster_centers_, X)
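        # pairwise_distances_argmin_min(A, B) returns, for each row of A, the
        # index of the nearest row in B plus that distance, so `closest` holds
        # one representative data-point index per centroid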
        # annotate text for points closest to centroids
        for i, idx in enumerate(closest):
            data = df.iloc[idx]
            plt.annotate(text=data.keywords,
                         xy=(X1[idx], X2[idx]),
                         color=colors[i],
                         horizontalalignment='center',
                         verticalalignment='top',
                         size=20)
        plt.suptitle('PERPLEXITY %d, PCA-DIM %d + KMEANS' % (PERPLEX, PCA_DIM))
        
time2 = time.time()
print(fmt_time_pretty(time1, time2))