def main(conf_file, model_file, data_file, output_path, mask, indicator,
         answer_span_in_context, no_ans_bit):
    """Run an SDNet model in official mode on a test file and write the JSON result.

    Args:
        conf_file: Path to the SDNet configuration file. Its directory is
            stored as ``opt['datadir']``.
        model_file: Model checkpoint handed to ``SDNetTrainer.official``.
        data_file: Test data file, stored as ``opt['OFFICIAL_TEST_FILE']``.
        output_path: Destination of the JSON output. A ``str`` is converted
            to ``pathlib.Path``. Any other object must provide
            ``.open(mode='w')`` (e.g. a ``pathlib.Path``).
        mask: Stored as ``opt['PREV_ANS_MASK']``.
        indicator: Stored as ``opt['PREV_ANS_INDICATOR']``.
        answer_span_in_context: If truthy, adds the
            ``ANSWER_SPAN_IN_CONTEXT_FEATURE`` key to ``opt``.
        no_ans_bit: If truthy, adds the ``NO_PREV_ANS_BIT`` key to ``opt``.
    """
    from pathlib import Path

    conf_args = Arguments(conf_file)
    opt = conf_args.readArguments()
    opt['cuda'] = torch.cuda.is_available()
    opt['confFile'] = conf_file
    opt['datadir'] = os.path.dirname(conf_file)
    opt['PREV_ANS_MASK'] = mask
    opt['PREV_ANS_INDICATOR'] = indicator
    opt['OFFICIAL'] = True
    opt['OFFICIAL_TEST_FILE'] = data_file
    # These options are enabled by adding the key with value None; presumably
    # the trainer tests `key in opt` rather than the value -- TODO confirm.
    if answer_span_in_context:
        opt['ANSWER_SPAN_IN_CONTEXT_FEATURE'] = None
    if no_ans_bit:
        opt['NO_PREV_ANS_BIT'] = None

    trainer = SDNetTrainer(opt)
    test_data = trainer.preproc.preprocess('test')
    # Only the JSON payload is written out; the other two results are unused.
    _predictions, _confidence, final_json = trainer.official(
        model_file, test_data)

    # Accept a plain string path as well as a Path-like object with .open().
    if isinstance(output_path, str):
        output_path = Path(output_path)
    with output_path.open(mode='w') as f:
        json.dump(final_json, f)
# Training entry point: build the option dict from the conf file plus the
# command line, optionally add a file log handler under myLog/, then train.
parser = argparse.ArgumentParser(description='SDNet')
parser.add_argument('--command', default='train', help='Command: train')
parser.add_argument('--conf_file', default='conf_stvqa', help='Path to conf file.')
parser.add_argument('--log_file', default='', help='Path to log file.')

cmdline_args = parser.parse_args()
command = cmdline_args.command
conf_file = cmdline_args.conf_file
conf_args = Arguments(conf_file)
opt = conf_args.readArguments()
opt['cuda'] = torch.cuda.is_available()
opt['confFile'] = conf_file
opt['datadir'] = os.path.dirname(conf_file)  # conf_file specifies where the data folder is

if cmdline_args.log_file != '':
    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs('myLog', exist_ok=True)
    file_handle = logging.FileHandler(
        os.path.join('myLog', cmdline_args.log_file + '.txt'))
    log.addHandler(file_handle)

# Copy the remaining command-line values into opt, overriding the conf file.
# Note: log_file is copied as well (as '' when not supplied).
for key, val in vars(cmdline_args).items():
    if val is not None and key not in ['command', 'conf_file']:
        opt[key] = val

model = SDNetTrainer(opt)
print('Select command: ' + command)
# NOTE(review): `command` is only echoed; training runs regardless of its value.
model.train()
from Models.SDNetTrainer import SDNetTrainer
from Utils.Arguments import Arguments
import logging

# Root logging is configured after the project imports, as in the original
# ordering, so any handlers those imports install are treated the same way.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    level=logging.DEBUG,
    datefmt='%m/%d/%Y %I:%M:%S',
)
log = logging.getLogger(__name__)

opt = None  # placeholder; rebuilt from the conf file below

# Prediction entry point: read options from the conf file, overlay extra
# command-line values, then run SDNet's test-set prediction.
parser = argparse.ArgumentParser(description='SDNet')
parser.add_argument('--command', default='train', help='Command: train')
parser.add_argument('--conf_file', default='conf', help='Path to conf file.')
cmdline_args = parser.parse_args()

command = cmdline_args.command
conf_file = cmdline_args.conf_file

conf_args = Arguments(conf_file)
opt = conf_args.readArguments()
opt['cuda'] = torch.cuda.is_available()
opt['confFile'] = conf_file
# The data folder is located relative to the conf file.
opt['datadir'] = os.path.dirname(conf_file)

for name, value in vars(cmdline_args).items():
    if value is None or name in ('command', 'conf_file'):
        continue
    opt[name] = value

model = SDNetTrainer(opt)
print('Select command: {}'.format(command))
model.predict_for_test()