Example #1
def main(args):
    # Build data loader
    tr_loader = build_data_loader(args.train_data,
                                  args.vocab,
                                  args.punc_vocab,
                                  batch_size=args.batch_size,
                                  drop_last=False,
                                  num_workers=args.num_workers)
    cv_loader = build_data_loader(args.valid_data,
                                  args.vocab,
                                  args.punc_vocab,
                                  batch_size=args.batch_size,
                                  drop_last=False)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # Build model
    model = LstmPunctuator(args.num_embeddings, args.embedding_dim,
                           args.hidden_size, args.num_layers,
                           args.bidirectional, args.num_class)
    print(model)
    print("Number of parameters: %d" % num_param(model))
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # Build criterion
    criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_ID)
    # Build optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.l2)
    # Build Solver
    solver = Solver(data, model, criterion, optimizer, args)
    solver.train()
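For reference, a minimal argparse sketch of the command-line arguments this main() consumes; the attribute names are taken from the example above, while the defaults are purely illustrative assumptions.

import argparse

parser = argparse.ArgumentParser(description='LSTM punctuator training (illustrative)')
parser.add_argument('--train_data', required=True)
parser.add_argument('--valid_data', required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--punc_vocab', required=True)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--num_embeddings', type=int, default=100000)
parser.add_argument('--embedding_dim', type=int, default=256)
parser.add_argument('--hidden_size', type=int, default=512)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--bidirectional', type=int, default=1)  # 1 = bidirectional LSTM
parser.add_argument('--num_class', type=int, default=5)      # number of punctuation classes
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--l2', type=float, default=0.0)         # weight decay

if __name__ == '__main__':
    main(parser.parse_args())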
Example #2
def main(args):
    cfg = Config.fromfile(args.config)
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=args.report_speed))
    print(json.dumps(cfg._cfg_dict, indent=4))
    sys.stdout.flush()

    # data loader
    data_loader = build_data_loader(cfg.data.test)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )
    # model
    if hasattr(cfg.model, 'recognition_head'):
        cfg.model.recognition_head.update(
            dict(
                voc=data_loader.voc,
                char2id=data_loader.char2id,
                id2char=data_loader.id2char,
            ))
    model = build_model(cfg.model)
    model = model.cuda()

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.checkpoint))
            sys.stdout.flush()

            checkpoint = torch.load(args.checkpoint)

            # strip the 'module.' prefix that DataParallel adds to parameter names
            d = dict()
            for key, value in checkpoint['state_dict'].items():
                d[key[7:]] = value
            model.load_state_dict(d)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.checkpoint))

    # fuse conv and bn
    model = fuse_module(model)
    model_structure(model)
    # test
    test(test_loader, model, cfg)
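The key-renaming loop in this example strips the 'module.' prefix that torch.nn.DataParallel prepends to parameter names when a wrapped model is saved. A slightly more defensive sketch of the same idea (standalone; the helper name is made up for illustration):

import torch

def strip_data_parallel_prefix(state_dict, prefix='module.'):
    # drop the DataParallel wrapper prefix so the weights load into a bare model
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# usage (illustrative):
# checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
# model.load_state_dict(strip_data_parallel_prefix(checkpoint['state_dict']))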
Example #3
def main(args):
    cfg = Config.fromfile(args.config)
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=args.report_speed))
    print(json.dumps(cfg._cfg_dict, indent=4))
    sys.stdout.flush()

    # data loader
    data_loader = build_data_loader(cfg.data.test)
    test_loader = paddle.io.DataLoader(data_loader,
                                       batch_size=1,
                                       shuffle=False,
                                       num_workers=0,
                                       use_shared_memory=True)

    device = paddle.get_device()
    paddle.set_device(device)

    # model
    model = build_model(cfg.model)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.checkpoint))
            sys.stdout.flush()

            checkpoint = paddle.load(args.checkpoint)
            model.set_state_dict(checkpoint)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.checkpoint))

    # fuse conv and bn
    model = fuse_module(model)

    # test
    test(test_loader, model, cfg)
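The two device lines above simply re-apply whatever device Paddle has already selected. To pin the run to a specific device instead, the string can be passed explicitly (a minimal sketch; the example itself takes no device argument):

import paddle

# paddle.get_device() reports the currently selected device, e.g. 'cpu' or 'gpu:0';
# paddle.set_device() with an explicit string overrides that choice.
print(paddle.get_device())
paddle.set_device('cpu')      # force CPU inference
# paddle.set_device('gpu:0')  # or pin a specific GPU (requires a CUDA build of Paddle)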
Example #4
def main(args):
    # get device
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model = STTR(args).to(device)
    print_param(model)

    # set learning rate
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if
                    "backbone" not in n and "regression" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model.named_parameters() if "regression" in n and p.requires_grad],
            "lr": args.lr_regression,
        },
    ]

    # define optimizer and learning rate scheduler
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay_rate)

    # mixed precision training
    if args.apex:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None

    # load checkpoint if provided
    prev_best = np.inf
    if args.resume != '':
        if not os.path.isfile(args.resume):
            raise RuntimeError(f"=> no checkpoint found at '{args.resume}'")
        checkpoint = torch.load(args.resume)

        pretrained_dict = checkpoint['state_dict']
        model.load_state_dict(pretrained_dict)
        print("Pre-trained model successfully loaded.")

        # if not ft/inference/eval, load states for optimizer, lr_scheduler, amp and prev best
        if not (args.ft or args.inference or args.eval):
            args.start_epoch = checkpoint['epoch'] + 1
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            prev_best = checkpoint['best_pred']
            if args.apex:
                amp.load_state_dict(checkpoint['amp'])
            print("Pre-trained optimizer, lr scheduler and stats successfully loaded.")

    # inference
    if args.inference:
        print("Start inference")
        _, _, data_loader = build_data_loader(args)
        inference(model, data_loader, device, args.downsample)

        return

    # initiate saver and logger
    checkpoint_saver = Saver(args)
    summary_writer = TensorboardSummary(checkpoint_saver.experiment_dir)

    # build dataloader
    data_loader_train, data_loader_val, _ = build_data_loader(args)

    # build loss criterion
    criterion = build_criterion(args)

    # set downsample rate
    set_downsample(args)

    # eval
    if args.eval:
        print("Start evaluation")
        evaluate(model, criterion, data_loader_val, device, 0, summary_writer, True)
        return

    # train
    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        # train
        print("Epoch: %d" % epoch)
        train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, summary_writer,
                        args.clip_max_norm, amp)

        # step lr if not pretraining
        if not args.pre_train:
            lr_scheduler.step()
            print("current learning rate", lr_scheduler.get_lr())

        # empty cache
        torch.cuda.empty_cache()

        # save every epoch during pre-training, otherwise every 50 epochs
        if args.pre_train or epoch % 50 == 0:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

        # validate
        eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False)
        # save if best
        if eval_stats['epe'] < prev_best and eval_stats['px_error_rate'] < 0.5:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, True, amp)

    # save final model
    save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

    return
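The param_dicts list in this example gives the backbone and regression parameters their own learning rates while everything else falls back to the optimizer's default lr. A tiny self-contained sketch of the same per-group pattern (module names here are illustrative, not the STTR ones):

import torch

model = torch.nn.ModuleDict({
    'backbone': torch.nn.Linear(8, 8),
    'head': torch.nn.Linear(8, 2),
})
param_groups = [
    # parameters whose names do not mention 'backbone' use the optimizer's default lr
    {'params': [p for n, p in model.named_parameters() if 'backbone' not in n]},
    # backbone parameters get their own, smaller learning rate
    {'params': [p for n, p in model.named_parameters() if 'backbone' in n],
     'lr': 1e-5},
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)
for group in optimizer.param_groups:
    print(len(group['params']), group['lr'])  # parameter count and lr per group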
Example #5
def main(args):
    cfg = Config.fromfile(args.config)
    print(json.dumps(cfg._cfg_dict, indent=4))

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        cfg_name, _ = osp.splitext(osp.basename(args.config))
        checkpoint_path = osp.join('checkpoints', cfg_name)
    if not osp.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)
    print('Checkpoint path: %s.' % checkpoint_path)
    sys.stdout.flush()

    # data loader
    data_loader = build_data_loader(cfg.data.train)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=cfg.data.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               drop_last=True,
                                               pin_memory=True)

    # model
    model = build_model(cfg.model)
    model = torch.nn.DataParallel(model).cuda()

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        if cfg.train_cfg.optimizer == 'SGD':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=cfg.train_cfg.lr,
                                        momentum=0.99,
                                        weight_decay=5e-4)
        elif cfg.train_cfg.optimizer == 'Adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=cfg.train_cfg.lr)
        else:
            raise ValueError('Unsupported optimizer: %s' %
                             cfg.train_cfg.optimizer)

    start_epoch = 0
    start_iter = 0
    if hasattr(cfg.train_cfg, 'pretrain'):
        assert osp.isfile(
            cfg.train_cfg.pretrain), 'Error: no pretrained weights found!'
        print('Finetuning from pretrained model %s.' % cfg.train_cfg.pretrain)
        checkpoint = torch.load(cfg.train_cfg.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
    if args.resume:
        assert osp.isfile(args.resume), 'Error: no checkpoint directory found!'
        print('Resuming from checkpoint %s.' % args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        start_iter = checkpoint['iter']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    for epoch in range(start_epoch, cfg.train_cfg.epoch):
        print('\nEpoch: [%d | %d]' % (epoch + 1, cfg.train_cfg.epoch))

        if args.record_file is None:
            train(train_loader, model, optimizer, epoch, start_iter, cfg)
        else:
            output_log = train(train_loader, model, optimizer, epoch,
                               start_iter, cfg)
            output_log = 'Epoch: [{cur_epoch:d} | {total_epoch:d}]  '.format(
                cur_epoch=epoch + 1,
                total_epoch=cfg.train_cfg.epoch) + output_log
            with open(args.record_file, 'a') as f:
                f.write(output_log)

        state = dict(epoch=epoch + 1,
                     iter=0,
                     state_dict=model.state_dict(),
                     optimizer=optimizer.state_dict())
        save_checkpoint(state, checkpoint_path, cfg)
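Note that the model is wrapped in torch.nn.DataParallel before its state_dict() is saved, so the checkpoint keys carry the 'module.' prefix, which is exactly what Example #2 strips again at test time. Saving the unwrapped module avoids that round-trip (a minimal sketch):

import torch

wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))
print(list(wrapped.state_dict())[0])         # 'module.weight'
print(list(wrapped.module.state_dict())[0])  # 'weight' -- prefix-free, simpler to reload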
Example #6
def main(args):
    cfg = Config.fromfile(args.config)
    print(json.dumps(cfg._cfg_dict, indent=4))

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        cfg_name, _ = osp.splitext(osp.basename(args.config))
        checkpoint_path = osp.join('checkpoints', cfg_name)
    if not osp.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)
    print('Checkpoint path: %s.' % checkpoint_path)
    sys.stdout.flush()

    # data loader
    data_loader = build_data_loader(cfg.data.train)
    train_loader = paddle.io.DataLoader(data_loader,
                                        batch_size=cfg.data.batch_size,
                                        shuffle=True,
                                        num_workers=0,
                                        drop_last=True,
                                        use_shared_memory=True)

    # device
    device = paddle.get_device()
    paddle.set_device(device)

    # model
    model = build_model(cfg.model)
    model = paddle.DataParallel(model)

    # Check if model has custom optimizer / loss
    optimizer = None
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
    else:
        if cfg.train_cfg.optimizer == 'SGD':
            optimizer = paddle.optimizer.Momentum(
                parameters=model.parameters(),
                learning_rate=cfg.train_cfg.lr,
                weight_decay=5e-4,
                momentum=0.99)

        elif cfg.train_cfg.optimizer == 'Adam':
            optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                              learning_rate=cfg.train_cfg.lr)

    start_epoch = 0
    start_iter = 0
    if hasattr(cfg.train_cfg, 'pretrain'):
        assert osp.isfile(
            cfg.train_cfg.pretrain), 'Error: no pretrained weights found!'
        print('Finetuning from pretrained model %s.' % cfg.train_cfg.pretrain)
        checkpoint = paddle.load(cfg.train_cfg.pretrain)
        model.set_state_dict(checkpoint)
    if args.resume:
        cfg_name, _ = osp.splitext(osp.basename(args.config))
        checkpoint_path = osp.join('checkpoints', cfg_name)
        pdparams_file = osp.join(checkpoint_path, args.resume + ".pdparams")
        pdopt_file = osp.join(checkpoint_path, args.resume + ".pdopt")
        assert osp.isfile(
            pdparams_file), 'Error: no checkpoint pdparams file found!'
        print('Resuming from checkpoint %s.' % pdparams_file)
        assert osp.isfile(
            pdopt_file), 'Error: no checkpoint pdopt file found!'
        print('Resuming from checkpoint %s.' % pdopt_file)
        print('Resuming from checkpoint %s.' % pdopt_file)
        # the resume name is expected to encode epoch and iteration as <prefix>_<epoch>_<iter>
        start_epoch = int(str(args.resume).split("_")[1])
        start_iter = int(str(args.resume).split("_")[2])
        print("start epoch: ", start_epoch, "   start iter: ", start_iter)
        checkpoint = paddle.load(pdparams_file)
        model.set_state_dict(checkpoint)
        checkpoint = paddle.load(pdopt_file)
        optimizer.set_state_dict(checkpoint)

    for epoch in range(start_epoch, cfg.train_cfg.epoch):
        print('\nEpoch: [%d | %d]' % (epoch + 1, cfg.train_cfg.epoch))

        train(train_loader, model, optimizer, epoch, start_iter, cfg)

        state = dict(epoch=epoch + 1,
                     iter=0,
                     state_dict=model.state_dict(),
                     optimizer=optimizer.state_dict())
        # epoch % 1 is always 0, so write a checkpoint every epoch
        save_checkpoint(state, checkpoint_path, cfg)
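The resume branch above expects checkpoints saved as matching .pdparams/.pdopt files whose names encode the epoch and iteration. A sketch of the save side that would satisfy that convention (the example's actual save_checkpoint may differ; the file naming here is an assumption inferred from the resume parsing):

import os.path as osp
import paddle

def save_paddle_checkpoint(model, optimizer, checkpoint_path, epoch, it):
    # the name encodes epoch and iteration so the resume branch can parse them back out
    name = 'checkpoint_{}_{}'.format(epoch, it)
    paddle.save(model.state_dict(), osp.join(checkpoint_path, name + '.pdparams'))
    paddle.save(optimizer.state_dict(), osp.join(checkpoint_path, name + '.pdopt'))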