Example #1
# NOTE: the opening of this call was truncated in the original; it is assumed to be
# an evaluation/logging callback (e.g. fastNLP's FitlogCallback) on both dev splits.
callbacks.append(
    FitlogCallback(
        data={
            'dev_matched': data_info.datasets['dev_matched'],
            'dev_mismatched': data_info.datasets['dev_mismatched']
        },
        verbose=1))

# Train with dev accuracy ('acc') as the model-selection metric; the best checkpoint
# is reloaded after training (load_best_model=True).
trainer = Trainer(train_data=data_info.datasets['train'],
                  model=model,
                  optimizer=optimizer,
                  num_workers=0,
                  batch_size=arg.batch_size,
                  n_epochs=arg.n_epochs,
                  print_every=-1,
                  dev_data=data_info.datasets[arg.devset_name],
                  metrics=AccuracyMetric(pred="pred", target="target"),
                  metric_key='acc',
                  device=list(range(torch.cuda.device_count())),
                  check_code_level=-1,
                  callbacks=callbacks,
                  loss=CrossEntropyLoss(pred="pred", target="target"))
trainer.train(load_best_model=True)

# Evaluate the selected model on the held-out test split.
tester = Tester(
    data=data_info.datasets[arg.testset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=arg.batch_size,
    device=list(range(torch.cuda.device_count())),
)
tester.test()
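
The snippet above assumes that arg, model, data_info, optimizer, and callbacks are built earlier in the script. A minimal sketch of that setup, with argument names taken from the snippet and everything else (defaults, placeholder model) purely illustrative:

import argparse
import torch
from torch import optim

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--n_epochs', type=int, default=3)
parser.add_argument('--lr', type=float, default=2e-5)
parser.add_argument('--devset_name', type=str, default='dev_matched')
parser.add_argument('--testset_name', type=str, default='test_matched')
arg = parser.parse_args()

model = torch.nn.Linear(768, 3)   # placeholder; the real script builds a BERT-style matching model
optimizer = optim.Adam(model.parameters(), lr=arg.lr)
callbacks = []                    # evaluation callbacks get appended to this list
# data_info is assumed to come from the project's data pipeline and to expose a
# `datasets` dict with train, dev and test splits.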

Example #2

# Collect the embedding parameters separately from the rest of the model so the
# optimizer can put them in their own parameter group.
bigram_embedding_param = list(model.bigram_embed.parameters())
embedding_param = bigram_embedding_param
if args.lattice:
    gaz_embedding_param = list(model.lattice_embed.parameters())
    embedding_param = embedding_param + gaz_embedding_param
embedding_param_ids = list(map(id, embedding_param))
non_embedding_param = list(
    filter(lambda x: id(x) not in embedding_param_ids, model.parameters()))

# NOTE: the second parameter group below was truncated in the original snippet; it
# is assumed to hold the embedding parameters collected above.
param_ = [{
    'params': non_embedding_param
}, {
    'params': embedding_param
}]
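# The two groups above are typically passed to the optimizer so that the embeddings
# can use a different learning rate from the rest of the model. A self-contained
# sketch of that pattern (toy module and learning rates are illustrative, not from
# the original):
#
#     import torch.nn as nn
#     import torch.optim as optim
#
#     toy = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 2))
#     embed_params = list(toy[0].parameters())
#     embed_ids = set(map(id, embed_params))
#     other_params = [p for p in toy.parameters() if id(p) not in embed_ids]
#     optimizer = optim.AdamW([
#         {'params': other_params, 'lr': 1e-3},
#         {'params': embed_params, 'lr': 1e-4},  # smaller lr for the embeddings
#     ], weight_decay=0.01)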
if args.see_convergence:
    print_info('see_convergence = True')
    print_info('so just test train acc|f1')
    datasets['train'] = datasets['train'][:100]
    if args.optim == 'adam':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    trainer = Trainer(datasets['train'], model, optimizer, loss, args.batch,
                      n_epochs=args.epoch, dev_data=datasets['train'], metrics=metrics,
                      device=device, dev_batch_size=args.test_batch)

    trainer.train()
    exit(1208)

# if args.warmup and args.model == 'transformer':
#     ## warm up start
#     if args.optim == 'adam':
#         warmup_optimizer = optim.AdamW(model.parameters(),lr=args.warmup_lr,weight_decay=args.weight_decay)
#     elif args.optim == 'sgd':
#         warmup_optimizer = optim.SGD(model.parameters(),lr=args.warmup_lr,momentum=args.momentum)
#
#     warmup_lr_schedule = LRScheduler(lr_scheduler=LambdaLR(warmup_optimizer, lambda ep: 1 * (1 + 0.05)**ep))
#     warmup_callbacks = [
#         warmup_lr_schedule,
#     ]
#
#     warmup_trainer = Trainer(datasets['train'],model,warmup_optimizer,loss,args.warmup_batch,
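
The commented-out block above sets up a warm-up phase whose learning-rate multiplier grows by 5% each epoch. A self-contained PyTorch version of that schedule (toy model and values are illustrative; the original wraps the scheduler in fastNLP's LRScheduler callback):

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

toy = nn.Linear(8, 2)                                   # stand-in for the real model
warmup_optimizer = optim.SGD(toy.parameters(), lr=0.01, momentum=0.9)

# Same shape as the commented schedule: scale the base lr by (1 + 0.05) ** epoch.
warmup_schedule = LambdaLR(warmup_optimizer, lambda ep: (1 + 0.05) ** ep)

for epoch in range(5):
    # ... one warm-up training epoch would run here ...
    warmup_schedule.step()
    print(epoch, warmup_optimizer.param_groups[0]['lr'])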