示例#1
0
def run_model(task=None, model_name=None, dataset_name=None, config_file=None,
              save_model=True, train=True, other_args=None):
    """Run the pipeline: config -> data -> model -> train or load -> evaluate.

    Args:
        task(str): task name
        model_name(str): model name
        dataset_name(str): dataset name
        config_file(str): config filename used to modify the pipeline's
            settings. the config file should be json.
        save_model(bool): whether to save the model after training. Only
            consulted when the model is actually trained in this call.
        train(bool): whether to train the model. If False but no cached model
            file exists for this model/dataset pair, the model is trained
            anyway; otherwise the cached model is loaded.
        other_args(dict): the remaining parameter args, which will be passed
            to the Config
    """
    # Load config
    config = ConfigParser(task, model_name, dataset_name,
                          config_file, other_args)
    # Logger
    logger = get_logger(config)
    logger.info('Begin pipeline, task={}, model_name={}, dataset_name={}'.
                format(str(task), str(model_name), str(dataset_name)))
    # Load the dataset
    dataset = get_dataset(config)
    # Convert the data and split it into train / valid / test sets
    train_data, valid_data, test_data = dataset.get_data()
    data_feature = dataset.get_data_feature()
    # Build the model and its executor.
    # NOTE(review): the cache path is keyed only on model_name and
    # dataset_name, so changes made via config_file / other_args do not change
    # it -- with train=False a model trained under other settings may be
    # loaded.
    model_cache_file = './trafficdl/cache/model_cache/{}_{}.m'.format(
        model_name, dataset_name)
    model = get_model(config, data_feature)
    executor = get_executor(config, model)
    # Train (and optionally save), unless training is disabled and a cached
    # model file already exists, in which case reuse it.
    if train or not os.path.exists(model_cache_file):
        executor.train(train_data, valid_data)
        if save_model:
            executor.save_model(model_cache_file)
    else:
        executor.load_model(model_cache_file)
    # Evaluate; per the original author's note, results are placed under
    # cache/evaluate_cache
    executor.evaluate(test_data)
示例#2
0
def run_model(task=None,
              model_name=None,
              dataset_name=None,
              config_file=None,
              save_model=True,
              train=True,
              other_args=None):
    """Execute one end-to-end pipeline run and evaluate the result.

    The steps are: build the config, load and split the dataset, build the
    model and its executor, then either train it or load a cached copy, and
    finally evaluate on the test split.

    Args:
        task (str): task name
        model_name (str): model name
        dataset_name (str): dataset name
        config_file (str): name of a JSON config file whose settings override
            the pipeline defaults
        save_model (bool): save the model to the cache file after training;
            ignored when a cached model is loaded instead
        train (bool): train the model; when False and no cache file exists
            for this model/dataset pair, training still happens
        other_args (dict): any remaining arguments, forwarded to the Config
    """
    cfg = ConfigParser(task, model_name, dataset_name, config_file,
                       other_args)
    log = get_logger(cfg)
    banner = 'Begin pipeline, task={}, model_name={}, dataset_name={}'
    log.info(banner.format(str(task), str(model_name), str(dataset_name)))

    # Load the dataset, then split it into train / valid / test sets.
    ds = get_dataset(cfg)
    train_split, valid_split, test_split = ds.get_data()
    features = ds.get_data_feature()

    # Build the model and the executor that drives it.
    cache_path = './trafficdl/cache/model_cache/{}_{}.m'.format(
        model_name, dataset_name)
    net = get_model(cfg, features)
    runner = get_executor(cfg, net)

    # Training is forced when no cached model is available.
    must_train = train or not os.path.exists(cache_path)
    if not must_train:
        runner.load_model(cache_path)
    else:
        runner.train(train_split, valid_split)
        if save_model:
            runner.save_model(cache_path)

    # Evaluate; per the original author's note, results go under
    # cache/evaluate_cache.
    runner.evaluate(test_split)
示例#3
0
    False,
    'use_early_stop':
    True,
    'max_grad_norm':
    5,
    'patience':
    15,
}

# os.environ only accepts str values; coerce so an integer gpu_id in the
# config does not raise TypeError. Set before the CUDA query below.
os.environ["CUDA_VISIBLE_DEVICES"] = str(config['gpu_id'])
config['device'] = torch.device(
    "cuda" if torch.cuda.is_available() and config['gpu'] else "cpu")

logger = get_logger(config)
# Load the dataset
dataset = get_dataset(config)
# Convert the data and split it into train / valid / test sets
train_data, valid_data, test_data = dataset.get_data()
# Sanity-print each split: len(split), len(split.dataset), the shapes of
# elements 0 and 1 of the first sample, and the batch size (presumably these
# are torch DataLoaders, so len(split) is the batch count -- confirm).
for _loader in (train_data, valid_data, test_data):
    print(len(_loader), len(_loader.dataset), _loader.dataset[0][0].shape,
          _loader.dataset[0][1].shape, _loader.batch_size)

data_feature = dataset.get_data_feature()
# Inspect the adjacency matrix (assumed to be an array-like with .shape and
# .sum() -- e.g. a NumPy array; confirm).
print(data_feature['adj_mx'].shape)
print(data_feature['adj_mx'].sum())

model = get_model(config, data_feature)