Exemplo n.º 1
0
def main():
    """Parse the command line, assemble the option dict and start training.

    Positional command-line arguments:
        mode       -- run mode (help text: "train"). It is printed and copied
                      into the options; ``train()`` is called regardless.
        conf_file  -- path to the configuration file read by ``Arguments``;
                      its directory becomes ``options['datadir']``.
        dataset    -- dataset name; printed and copied into the options.
    """
    # Configure root logging for the whole run.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
    )
    log = logging.getLogger(__name__)

    cli = argparse.ArgumentParser(description='multi-qa')
    cli.add_argument('mode', help='mode: train')
    cli.add_argument('conf_file', help='path to conf file.')
    cli.add_argument('dataset', help='dataset for train')
    cli_args = cli.parse_args()

    run_mode = cli_args.mode
    conf_path = cli_args.conf_file
    dataset_name = cli_args.dataset

    # Options come from the conf file; runtime entries are layered on top.
    options = Arguments(conf_path).readArguments()
    options['cuda'] = torch.cuda.is_available()
    options['confFile'] = conf_path
    options['datadir'] = os.path.dirname(conf_path)  # data dir = conf file's directory

    # Copy every non-None command-line value into the options.
    # NOTE(review): this parser has no 'command' argument, so 'mode' and
    # 'dataset' end up in the options as well — confirm that is intended.
    skipped = ('command', 'conf_file')
    for name, value in vars(cli_args).items():
        if value is None or name in skipped:
            continue
        options[name] = value

    # To force CPU execution, set options['cuda'] = False before this line.
    device = torch.device('cuda' if options['cuda'] else 'cpu')
    log.info("device %s is using for training!", device)

    trainer = ConvQA_CN_NetTrainer(options)
    print("Select mode----", run_mode)
    print("Using dataset ----", dataset_name)
    trainer.train()
Exemplo n.º 2
0
def main(conf_file, model_file, data_file, output_path, mask, indicator,
         answer_span_in_context, no_ans_bit):
    """Run official-mode prediction with ``SDNetTrainer`` and write the JSON result.

    Args:
        conf_file: path to the configuration file read by ``Arguments``; its
            directory is stored as ``opt['datadir']``.
        model_file: passed to ``SDNetTrainer.official`` (presumably a saved
            model checkpoint — confirm against the trainer).
        data_file: stored as ``opt['OFFICIAL_TEST_FILE']``.
        output_path: destination of the JSON output. Objects that already
            have an ``open`` method (e.g. ``pathlib.Path``) are used as-is;
            a ``str`` or other ``os.PathLike`` is converted to ``Path``.
        mask: stored as ``opt['PREV_ANS_MASK']``.
        indicator: stored as ``opt['PREV_ANS_INDICATOR']``.
        answer_span_in_context: if truthy, ``opt['ANSWER_SPAN_IN_CONTEXT_FEATURE']``
            is set to None.
        no_ans_bit: if truthy, ``opt['NO_PREV_ANS_BIT']`` is set to None.
    """
    from pathlib import Path

    # Previously a plain string path failed with AttributeError on
    # ``output_path.open`` only after preprocessing and inference had run.
    if not hasattr(output_path, 'open') and isinstance(output_path, (str, os.PathLike)):
        output_path = Path(output_path)

    opt = Arguments(conf_file).readArguments()
    opt['cuda'] = torch.cuda.is_available()
    opt['confFile'] = conf_file
    opt['datadir'] = os.path.dirname(conf_file)
    opt['PREV_ANS_MASK'] = mask
    opt['PREV_ANS_INDICATOR'] = indicator

    opt['OFFICIAL'] = True
    opt['OFFICIAL_TEST_FILE'] = data_file
    # NOTE(review): these flags are set to None rather than removed; whether
    # the trainer treats key presence or value as the switch is not visible
    # here — confirm.
    if answer_span_in_context:
        opt['ANSWER_SPAN_IN_CONTEXT_FEATURE'] = None
    if no_ans_bit:
        opt['NO_PREV_ANS_BIT'] = None

    trainer = SDNetTrainer(opt)
    test_data = trainer.preproc.preprocess('test')
    # Only the final JSON is written; predictions and confidences are unused.
    _predictions, _confidence, final_json = trainer.official(
        model_file, test_data)
    with output_path.open(mode='w') as f:
        json.dump(final_json, f)
Exemplo n.º 3
0
from Models.SDNetTrainer import SDNetTrainer
from Utils.Arguments import Arguments
if __name__ == "__main__":
    # Script entry point: read the conf file, layer runtime/command-line
    # values on top of it and construct an SDNetTrainer.
    # NOTE(review): ``command`` is parsed but not used in the lines shown;
    # this block may continue further in the file — confirm.
    # multiprocessing.set_start_method('spawn')
    opt = None

    # Every option has a default, so the script runs with no arguments.
    parser = argparse.ArgumentParser(description='SDNet')
    parser.add_argument('--command', default='train', help='Command: train')
    parser.add_argument('--conf_file', default='conf_stvqa', help='Path to conf file.')
    parser.add_argument('--log_file', default='', help='Path to log file.')

    cmdline_args = parser.parse_args()
    command = cmdline_args.command
    conf_file = cmdline_args.conf_file
    # Options read from the conf file; the entries below add runtime info.
    conf_args = Arguments(conf_file)
    opt = conf_args.readArguments()
    opt['cuda'] = torch.cuda.is_available()
    opt['confFile'] = conf_file
    opt['datadir'] = os.path.dirname(conf_file)  # conf_file specifies where the data folder is

    # Optionally also log to a file: the name is joined under 'myLog' with
    # '.txt' appended, and 'myLog' is created if missing.
    # NOTE(review): ``log`` is not defined in the lines shown; presumably a
    # module-level logger defined elsewhere in the file — confirm.
    if cmdline_args.log_file != '':
        if not os.path.exists('myLog'):
            os.makedirs('myLog')
        file_handle = logging.FileHandler(os.path.join('myLog', cmdline_args.log_file +'.txt'))
        log.addHandler(file_handle)

    # Copy every non-None command-line value except 'command' and 'conf_file'
    # into opt; because all arguments have defaults, 'log_file' is always copied.
    for key,val in cmdline_args.__dict__.items():
        if val is not None and key not in ['command', 'conf_file']:
            opt[key] = val

    model = SDNetTrainer(opt)