args = parser.parse_args()

# A negative ff_dropout_2 falls back to ff_dropout.
# NOTE(review): this runs before the over_all_dropout override below, so
# ff_dropout_2 receives the *pre-override* ff_dropout value -- confirm intended.
if args.ff_dropout_2 < 0:
    args.ff_dropout_2 = args.ff_dropout

# A positive overall dropout overwrites each individual dropout rate.
# NOTE(review): `over_all_dropout` is a bare name (not an attribute of `args`)
# defined outside this view -- confirm it is bound before this point.
if over_all_dropout > 0:
    args.embed_dropout = over_all_dropout
    args.output_dropout = over_all_dropout
    args.pre_dropout = over_all_dropout
    args.post_dropout = over_all_dropout
    args.ff_dropout = over_all_dropout
    args.attn_dropout = over_all_dropout

# Force train_clip on for the lattice + relative-position, update_every == 1 setup.
if args.lattice and args.use_rel_pos and args.update_every == 1:
    args.train_clip = True

# Name the log file after the run's start time.
now_time = get_peking_time()
logger.add_file('log/{}'.format(now_time), level='info')

# -1 means "derive the test batch size from the training batch size".
# Clamp to at least 1 so a training batch of 1 does not produce an invalid
# test batch size of 0.
if args.test_batch == -1:
    args.test_batch = max(1, args.batch // 2)

# Record the run start time in fitlog under the name 'time'.
fitlog.add_hyper(now_time, 'time')

if args.debug:
    # No debug-specific overrides are currently active.
    pass

# Any value other than 'cpu' must be a CUDA device index such as '0'.
# Validate with an explicit exception: `assert` is stripped under `python -O`.
if args.device != 'cpu':
    if not args.device.isdigit():
        raise ValueError(
            "device must be 'cpu' or a CUDA device index, got {!r}".format(
                args.device))
    device = torch.device('cuda:{}'.format(args.device))
else:
    device = torch.device('cpu')

# NOTE(review): consumed outside this view; always True here.
refresh_data = True
# Example #2
    # Log the sample inputs used to sanity-check the exported module.
    logging.info('left_chars={}'.format(left_chars))
    logging.info('right_chars={}'.format(right_chars))
    logging.info('entity_chars={}'.format(entity_chars))
    logging.info('entity_label={}'.format(entity_label))
    logging.info('target={}'.format(target))

    # Load the checkpoint straight onto CPU. Without map_location, torch.load
    # restores tensors to the device they were saved from, which fails on a
    # CPU-only host for a GPU-trained checkpoint.
    model.load_state_dict(torch.load(save_model_path, map_location='cpu'))
    model = model.cpu()
    model.eval()
    # Freeze all parameters; the exported module is used for inference only.
    for p in model.parameters():
        p.requires_grad_(False)

    with torch.no_grad():
        # torch.jit.script compiles the module from its source code. Its second
        # positional parameter is the deprecated `optimize` flag, not example
        # inputs, so the sample tensors are not passed here.
        traced_script_module = torch.jit.script(model)
        # Run the compiled module once on the sample inputs and log the result.
        logging.info('model output: {}'.format(
            traced_script_module.forward(left_chars, right_chars, entity_chars,
                                         entity_label)))
        traced_script_module.save("ner_clf_script.pt")

if args.status == 'test':
    # Evaluation mode: restore the saved weights onto CPU, then score the
    # 'test' split with fastNLP's Tester, one example per batch.
    device = torch.device('cpu')
    model.load_state_dict(
        torch.load(save_model_path, map_location=device)
    )

    from fastNLP import Tester
    tester = Tester(
        datasets['test'],
        model,
        metrics=metrics,
        batch_size=1,
    )

    # Bracket the evaluation with wall-clock timestamps.
    logging.info('test start={}'.format(get_peking_time()))
    tester.test()
    logging.info('test end={}'.format(get_peking_time()))