Exemplo n.º 1
0
def create_cb():
    """Build the list of training callbacks from module-level configuration.

    Reads the module-level names ``args``, ``optimizer``, ``model`` and
    ``root_path`` (and ``bert_embedding`` when BERT is used without a
    freezing period).

    Returns:
        list: callbacks in registration order -- LR scheduler, gradient
        clipping, model saving, optionally ``Unfreeze_Callback``, early
        stopping and optionally ``WarmupCallback``.
    """
    # LambdaLR multiplier: 1 / (1 + 0.05 * ep).
    lrschedule_callback = LRScheduler(
        lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05 * ep)))
    clip_callback = GradientClipCallback(clip_type='value', clip_value=2)

    # Checkpoints go to <root_path>/model/<data_type>/fold<fold>.
    save_dir = os.path.join(root_path, f'model/{args.data_type}',
                            f'fold{args.fold}')
    save_callback = SaveModelCallback(top=1, save_dir=save_dir)

    # The same three base callbacks are used with and without
    # cross-validation, so no branch on ``args.cv`` is needed.
    callbacks = [lrschedule_callback, clip_callback, save_callback]

    if args.use_bert:
        if args.fix_bert_epoch != 0:
            # NOTE(review): presumably unfreezes ``model.lattice_embed`` after
            # ``args.fix_bert_epoch`` epochs -- confirm in Unfreeze_Callback.
            callbacks.append(
                Unfreeze_Callback(model.lattice_embed, args.fix_bert_epoch))
        else:
            # No freezing period requested: enable gradients right away.
            bert_embedding.requires_grad = True

    callbacks.append(EarlyStopCallback(args.early_stop))

    # Warmup is only added for the transformer model.
    if args.warmup > 0 and args.model == 'transformer':
        callbacks.append(WarmupCallback(warmup=args.warmup))
    return callbacks
elif args.optim == 'sgd':
    # optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,
    #                       weight_decay=args.weight_decay)
    optimizer = optim.SGD(param_,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

# For datasets whose name contains 'msra', the test split doubles as the
# dev split.
if 'msra' in args.dataset:
    datasets['dev'] = datasets['test']

# Splits reported through fitlog: always 'test', plus 'train' when
# ``args.test_train`` is set.
fitlog_evaluate_dataset = {
    split: datasets[split]
    for split in (('test', 'train') if args.test_train else ('test',))
}
evaluate_callback = FitlogCallback(fitlog_evaluate_dataset, verbose=1)

# LambdaLR multiplier: 1 / (1 + 0.05 * epoch).
lrschedule_callback = LRScheduler(
    lr_scheduler=LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch),
    )
)
clip_callback = GradientClipCallback(clip_value=5, clip_type='value')


# model.state_dict()
class CheckWeightCallback(Callback):
    """Debug callback that prints one ``state_dict`` entry of a model.

    The entry is printed every time the trainer invokes ``on_step_end``.

    Args:
        model: object exposing ``state_dict()``. It is stored as ``model_``.
        weight_name: key of the ``state_dict`` entry to print. Defaults to
            ``'encoder.layer_0.attn.w_q.weight'``.
    """

    def __init__(self, model, weight_name='encoder.layer_0.attn.w_q.weight'):
        super().__init__()
        self.model_ = model
        self.weight_name_ = weight_name

    def on_step_end(self):
        """Print the watched entry; raises ``KeyError`` if the key is absent."""
        # flush=True so the output is written immediately.
        print('parameter weight:', flush=True)
        print(self.model_.state_dict()[self.weight_name_], flush=True)

Exemplo n.º 3
0
def train_mlt_single(args):
    """Train a single task while iteratively pruning the model.

    Steps, in order:

    1. Load the task selected by ``args.task_id`` from ``args.data_path`` and
       build the model, optionally restoring ``args.init_weights``.
    2. Pick the metrics and the metric key via ``utils.need_acc``.
    3. Collect the names of trainable, non-bias parameters whose name
       contains one of the comma-separated ``args.need_cut`` substrings and
       pass them to ``Pruning`` (optionally loading ``args.init_masks``).
    4. Save the initial weights and ``args`` under ``args.save_path``.
    5. For ``pruner.pruning_iter + 1`` rounds: train with a fresh optimizer,
       test, save the weights, then call ``pruner.pruning_model()`` and
       ``pruner.save(...)``.

    Scalars are written to a ``SummaryWriter`` under
    ``<args.tb_path>/global/<task_name>``; the writer is closed when the
    function exits, including on error.

    Args:
        args: run configuration. When ``args.debug`` is truthy,
            ``args.epochs`` and ``args.pruning_iter`` are overwritten in
            place with 3 and every split is cut to its first 200 items.

    Returns:
        None.
    """
    logger.info(args)
    task_lst, vocabs = utils.get_data(args.data_path)
    task_db = task_lst[args.task_id]
    train_data = task_db.train_set
    dev_data = task_db.dev_set
    test_data = task_db.test_set
    task_name = task_db.task_name

    if args.debug:
        # Debug runs: small data slices and few epochs / pruning rounds.
        train_data = train_data[:200]
        dev_data = dev_data[:200]
        test_data = test_data[:200]
        args.epochs = 3
        args.pruning_iter = 3

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.tb_path, "global/%s" % task_name)
    )
    try:
        logger.info(
            "task name: {}, task id: {}".format(task_db.task_name, task_db.task_id)
        )
        logger.info(
            "train len {}, dev len {}, test len {}".format(
                len(train_data), len(dev_data), len(test_data)
            )
        )

        # init model
        model = get_model(args, task_lst, vocabs)

        logger.info("model: \n{}".format(model))
        if args.init_weights is not None:
            utils.load_model(model, args.init_weights)

        if utils.need_acc(task_name):
            metrics = [AccuracyMetric(target="y"), MetricInForward(val_name="loss")]
            metric_key = "acc"
        else:
            metrics = [
                YangJieSpanMetric(
                    tag_vocab=vocabs[task_name],
                    pred="pred",
                    target="y",
                    seq_len="seq_len",
                    encoding_type="bioes" if task_name == "ner" else "bio",
                ),
                MetricInForward(val_name="loss"),
            ]
            metric_key = "f"
        logger.info(metrics)

        # Parameters to prune: trainable, not a bias, and matching at least
        # one of the requested name fragments.
        need_cut_names = {s.strip() for s in args.need_cut.split(",")}
        prune_names = [
            name
            for name, p in model.named_parameters()
            if p.requires_grad
            and "bias" not in name
            and any(n in name for n in need_cut_names)
        ]

        # get Pruning class
        pruner = Pruning(
            model,
            prune_names,
            final_rate=args.final_rate,
            pruning_iter=args.pruning_iter,
        )
        if args.init_masks is not None:
            pruner.load(args.init_masks)
            pruner.apply_mask(pruner.remain_mask, pruner._model)

        # save checkpoint
        os.makedirs(args.save_path, exist_ok=True)

        logger.info('Saving init-weights to {}'.format(args.save_path))
        # model.cpu() moves the model to CPU before its state is saved; the
        # Trainer/Tester below are given device="cuda".
        torch.save(
            model.cpu().state_dict(), os.path.join(args.save_path, "init_weights.th")
        )
        torch.save(args, os.path.join(args.save_path, "args.th"))

        # start training and pruning
        summary_writer.add_scalar("remain_rate", 100.0, 0)
        summary_writer.add_scalar("cutoff", 0.0, 0)

        if args.init_weights is not None:
            # Evaluate the restored weights once before any training.
            init_tester = Tester(
                test_data,
                model,
                metrics=metrics,
                batch_size=args.batch_size,
                num_workers=4,
                device="cuda",
                use_tqdm=False,
            )
            res = init_tester.test()
            logger.info("No init testing, Result: {}".format(res))
            del res, init_tester

        for prune_step in range(pruner.pruning_iter + 1):
            # reset optimizer every time
            optim_params = [p for p in model.parameters() if p.requires_grad]
            utils.get_logger(__name__).debug(len(optim_params))
            optimizer = get_optim(args.optim, optim_params)

            # Learning-rate scaling is currently fixed at 1.0 (i.e. disabled).
            factor = 1.0
            for pg in optimizer.param_groups:
                pg["lr"] = factor * pg["lr"]
            utils.get_logger(__name__).info(optimizer)

            trainer = Trainer(
                train_data,
                model,
                loss=LossInForward(),
                optimizer=optimizer,
                metric_key=metric_key,
                metrics=metrics,
                print_every=200,
                batch_size=args.batch_size,
                num_workers=4,
                n_epochs=args.epochs,
                dev_data=dev_data,
                save_path=None,
                sampler=fastNLP.BucketSampler(batch_size=args.batch_size),
                callbacks=[
                    pruner,
                    GradientClipCallback(clip_type="norm", clip_value=5),
                    LRScheduler(
                        lr_scheduler=LambdaLR(
                            optimizer, lambda ep: 1 / (1 + 0.05 * ep)
                        )
                    ),
                    LogCallback(path=os.path.join(args.tb_path, "No", str(prune_step))),
                ],
                use_tqdm=False,
                device="cuda",
                check_code_level=-1,
            )
            res = trainer.train()
            logger.info("No #{} training, Result: {}".format(pruner.prune_times, res))
            _, val = get_metric(res)
            # NOTE(review): the tag is spelled "prunning_dev_acc" (double n),
            # unlike "pruning_test_acc" below. It is left unchanged so that
            # existing TensorBoard runs stay comparable.
            summary_writer.add_scalar("prunning_dev_acc", val, prune_step)

            tester = Tester(
                test_data,
                model,
                metrics=metrics,
                batch_size=args.batch_size,
                num_workers=4,
                device="cuda",
                use_tqdm=False,
            )
            res = tester.test()
            logger.info("No #{} testing, Result: {}".format(pruner.prune_times, res))
            _, val = get_metric(res)
            summary_writer.add_scalar("pruning_test_acc", val, prune_step)

            # Save the weights first, then prune, then save via the pruner.
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.save_path,
                    "best_{}_{}.th".format(pruner.prune_times, pruner.cur_rate),
                ),
            )
            pruner.pruning_model()
            summary_writer.add_scalar("remain_rate", pruner.cur_rate, prune_step + 1)
            summary_writer.add_scalar("cutoff", pruner.last_cutoff, prune_step + 1)

            pruner.save(
                os.path.join(
                    args.save_path,
                    "{}_{}.th".format(pruner.prune_times, pruner.cur_rate),
                )
            )
    finally:
        # Close the writer even if training fails part-way through.
        summary_writer.close()
Exemplo n.º 4
0
metric2 = CWSMetric(char_labels_vocab['APP'])
metrics = [metric1, metric2]

# Only parameters with requires_grad set are handed to the optimizer.
optimizer = optim.Adam(
    list(filter(lambda p: p.requires_grad, model.parameters())),
    lr=lr,
    weight_decay=weight_decay,
    betas=[0.9, 0.9],
)

sampler = BucketSampler(seq_len_field_name='seq_lens')

# StepLR with step_size=18 and gamma=0.75.
# Alternative kept for reference:
#   LambdaLR(optimizer, lr_lambda=lambda step: (0.75) ** (step // 5000))
scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
scheduler_callback = LRScheduler(scheduler)
callbacks = [
    scheduler_callback,
    GradientClipCallback(clip_type='value', clip_value=5),
]

# The test-split Tester is wrapped in a DevCallback and added last.
tester = Tester(
    data=data.datasets['test'],
    model=model,
    metrics=metrics,
    batch_size=64,
    device=device,
    verbose=0,
)
dev_callback = DevCallback(tester)
callbacks.append(dev_callback)

trainer = Trainer(data.datasets['train'],
                  model,
                  loss=None,
Exemplo n.º 5
0
acc_metric = AccuracyMetric(pred='pred', target='target', seq_len='seq_len')
metrics = [f1_metric, acc_metric]

# NOTE: ``optimizer`` is only assigned here for 'adam' and 'sgd'.
if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
    )

callbacks = [
    # fitlog receives both the 'test' and the 'train' split.
    FitlogCallback({split: datasets[split] for split in ('test', 'train')}),
    # LambdaLR multiplier: 1 / (1 + 0.03) ** epoch.
    LRScheduler(
        lr_scheduler=LambdaLR(
            optimizer,
            lr_lambda=lambda epoch: 1 / (1 + 0.03)**epoch,
        )
    ),
]

# Print the label vocabulary size followed by its idx2word attribute.
print(
    'label_vocab:{}\n{}'.format(
        len(vocabs['label']),
        vocabs['label'].idx2word,
    )
)

trainer = Trainer(
    datasets['train'],
    model,
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    dev_data=datasets['dev'],
    device=device,
    batch_size=args.batch,
    n_epochs=args.epoch,
    dev_batch_size=args.test_batch,
    callbacks=callbacks,
)