import os

# NOTE: the import path below follows the ptranking repository layout and may vary across versions.
from ptranking.ltr_adversarial.eval.ltr_adversarial import AdLTREvaluator

if __name__ == '__main__':
    cuda = None                 # the GPU id, e.g., 0 or 1; set to None to run on the CPU

    debug = True                # in debug mode, we only check whether the model can operate

    config_with_json = False    # whether to specify the configuration with json files

    models_to_run = [
        #'IRGAN_Point',
        #'IRGAN_Pair',
        #'IRGAN_List'
        'IRFGAN_Point',
        #'IRFGAN_Pair',
        #'IRFGAN_List'
    ]

    evaluator = AdLTREvaluator(cuda=cuda)

    if config_with_json:  # specify configuration with json files
        # the directory of json files; the chosen directory must match the model family in models_to_run
        #dir_json = '/home/dl-box/WorkBench/ExperimentBench/ALTR/ecir2021/irgan/mq2008_json/'
        #dir_json = '/home/dl-box/WorkBench/ExperimentBench/ALTR/ecir2021/irgan/ms30k_json/'

        #dir_json = '/home/dl-box/WorkBench/ExperimentBench/ALTR/ecir2021/irfgan/mq2008_json/'
        dir_json = '/home/dl-box/WorkBench/ExperimentBench/ALTR/ecir2021/irfgan/ms30k_json/'

        #dir_json = '/Users/dryuhaitao/WorkBench/Dropbox/CodeBench/GitPool/irgan_ptranking/testing/ltr_adversarial/json/'

        for model_id in models_to_run:
            evaluator.run(debug=debug, model_id=model_id, config_with_json=config_with_json, dir_json=dir_json)

    else:  # specify configuration manually
        ''' selected dataset '''
        data_id = 'MQ2007_Super'

        ''' location of the adopted data; DATASET_DIR is assumed to be defined elsewhere in the script '''
        dir_data = os.path.join(DATASET_DIR, 'MQ2007/')
        #dir_data = '/home/dl-box/WorkBench/Datasets/L2R/LETOR4.0/MQ2008/'
        #dir_data = '/Users/solar/WorkBench/Datasets/L2R/LETOR4.0/MQ2008/'

        ''' output directory; PROJECT_OUTPUT_DIR is assumed to be defined elsewhere in the script '''
        dir_output = os.path.join(PROJECT_OUTPUT_DIR, 'NeuralLTR/Listwise/')
        #dir_output = '/home/dl-box/WorkBench/CodeBench/PyCharmProject/Project_output/Out_L2R/Listwise/'
        #dir_output = '/Users/solar/WorkBench/CodeBench/PyCharmProject/Project_output/Out_L2R/'

        grid_search = False  # with grid_search, we can explore the effects of different hyper-parameters of a model

        to_run_models = ['IRGAN_Pair']

        for model_id in to_run_models:
            evaluator.run(debug=debug, model_id=model_id, data_id=data_id,
                          dir_data=dir_data, dir_output=dir_output, grid_search=grid_search)
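The manual-configuration branch above boils down to a single run call once the paths are fixed. A minimal standalone sketch, assuming the same AdLTREvaluator import as above and locally available MQ2007 data; all paths below are placeholders:

evaluator = AdLTREvaluator(cuda=None)      # None = run on the CPU
evaluator.run(debug=True,                  # quick operability check only
              model_id='IRGAN_Pair',
              data_id='MQ2007_Super',
              dir_data='/path/to/LETOR4.0/MQ2007/',   # placeholder: local dataset directory
              dir_output='/path/to/output/',          # placeholder: output directory
              grid_search=False)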
Example #3
import sys

# NOTE: module paths below follow the ptranking repository layout and may vary across versions.
from ptranking.utils.args.argsUtil import ArgsUtil
from ptranking.ltr_adhoc.eval.ltr import LTREvaluator, LTR_ADHOC_MODEL
from ptranking.ltr_adversarial.eval.ltr_adversarial import AdLTREvaluator, LTR_ADVERSARIAL_MODEL
from ptranking.ltr_tree.eval.ltr_tree import TreeLTREvaluator, LTR_TREE_MODEL

if __name__ == '__main__':
    """
    >>> Supported Datasets <<<
    -----------------------------------------------------------------------------------------
    | LETOR     | MQ2007_Super %  MQ2008_Super %  MQ2007_Semi %  MQ2008_Semi                |
    -----------------------------------------------------------------------------------------
    | MSLRWEB   | MSLRWEB10K %  MSLRWEB30K                                                  |
    -----------------------------------------------------------------------------------------
    | Yahoo_LTR | Set1 % Set2                                                               |
    -----------------------------------------------------------------------------------------
    | ISTELLA_LTR | Istella_S | Istella | Istella_X                                         |
    -----------------------------------------------------------------------------------------
    """

    args_obj = ArgsUtil(given_root='./')
    l2r_args = args_obj.get_l2r_args()

    if l2r_args.model in LTR_ADHOC_MODEL:
        evaluator = LTREvaluator(cuda=l2r_args.cuda)

    elif l2r_args.model in LTR_ADVERSARIAL_MODEL:
        evaluator = AdLTREvaluator(cuda=l2r_args.cuda)

    elif l2r_args.model in LTR_TREE_MODEL:
        evaluator = TreeLTREvaluator()

    else:
        args_obj.args_parser.print_help()
        sys.exit()

    print('Started evaluation with pt_ranking!')
    evaluator.run(model_id=l2r_args.model,
                  dir_json=l2r_args.dir_json,
                  debug=l2r_args.debug,
                  config_with_json=True)
    print('Finished evaluation with pt_ranking!')
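The exact command-line flags accepted by ArgsUtil are not visible in this snippet, but the snippet itself shows that args_obj.args_parser exposes a standard print_help(), so the parser's help output is the reliable way to discover the flag names before launching a run. A minimal sketch using only calls that appear above:

args_obj = ArgsUtil(given_root='./')
args_obj.args_parser.print_help()  # lists the supported flags (the code above reads model, dir_json, cuda, debug)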
Example #4
# NOTE: requires the same ptranking imports as Example #3 (ArgsUtil, the three evaluators, and the model lists).

if __name__ == '__main__':
    """
    >>> Supported Datasets <<<
    -----------------------------------------------------------------------------------------
    | LETOR     | MQ2007_Super %  MQ2008_Super %  MQ2007_Semi %  MQ2008_Semi                |
    -----------------------------------------------------------------------------------------
    | MSLRWEB   | MSLRWEB10K %  MSLRWEB30K                                                  |
    -----------------------------------------------------------------------------------------
    | Yahoo_LTR | Set1 % Set2                                                               |
    -----------------------------------------------------------------------------------------
    | ISTELLA_LTR | Istella_S | Istella | Istella_X                                         |
    -----------------------------------------------------------------------------------------
    """

    print('Started PT_Ranking ...')

    args_obj = ArgsUtil(given_root='./')
    l2r_args = args_obj.get_l2r_args()

    if l2r_args.model in LTR_ADHOC_MODEL:
        evaluator = LTREvaluator()

    elif l2r_args.model in LTR_ADVERSARIAL_MODEL:
        evaluator = AdLTREvaluator()

    elif l2r_args.model in LTR_TREE_MODEL:
        evaluator = TreeLTREvaluator()
    else:
        raise NotImplementedError

    evaluator.run(debug=True, model_id=l2r_args.model, data_id=l2r_args.data_id,
                  dir_data=l2r_args.dir_data, dir_output=l2r_args.dir_output, grid_search=False)
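Example #4 hard-codes debug=True and grid_search=False. Per the comment in the first example, flipping grid_search makes the same entry point explore hyper-parameter combinations instead of performing a single quick run; a minimal variant sketch with only those two flags changed:

    evaluator.run(debug=False,             # full run rather than a quick operability check
                  model_id=l2r_args.model, data_id=l2r_args.data_id,
                  dir_data=l2r_args.dir_data, dir_output=l2r_args.dir_output,
                  grid_search=True)        # explore the effects of different hyper-parameters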