Example #1
    def __init__(self,
                 network,
                 w_lr=0.01,
                 w_mom=0.9,
                 w_wd=1e-4,
                 t_lr=0.001,
                 t_wd=3e-3,
                 t_beta=(0.5, 0.999),
                 init_temperature=5.0,
                 temperature_decay=0.965,
                 logger=logging,
                 lr_scheduler=None,
                 gpus=None,
                 save_theta_prefix='',
                 theta_result_path='./theta-result',
                 checkpoints_path='./checkpoints'):
        # Avoid mutable default arguments; fall back to the intended defaults.
        if lr_scheduler is None:
            lr_scheduler = {'T_max': 200}
        if gpus is None:
            gpus = [0]
        assert isinstance(network, FBNet)
        network.apply(weights_init)
        network = network.train().cuda()
        if isinstance(gpus, str):
            # Accept a comma-separated device string such as '0,1'.
            gpus = [int(i) for i in gpus.strip().split(',')]
        network = DataParallel(network, gpus)
        self.gpus = gpus
        self._mod = network
        # Plain nn.DataParallel does not forward custom attributes, so reach
        # through .module for the architecture parameters.
        theta_params = network.module.theta
        mod_params = network.parameters()
        self.theta = theta_params
        self.w = mod_params
        # Gumbel-softmax temperature, annealed by temperature_decay each epoch.
        self._tem_decay = temperature_decay
        self.temp = init_temperature
        self.logger = logger
        self.save_theta_prefix = save_theta_prefix
        os.makedirs(theta_result_path, exist_ok=True)
        self.theta_result_path = theta_result_path
        os.makedirs(checkpoints_path, exist_ok=True)
        self.checkpoints_path = checkpoints_path

        # Running averages for accuracy, cross-entropy, latency, and total loss.
        self._acc_avg = AvgrageMeter('acc')
        self._ce_avg = AvgrageMeter('ce')
        self._lat_avg = AvgrageMeter('lat')
        self._loss_avg = AvgrageMeter('loss')

        # SGD with momentum for the operator weights w.
        self.w_opt = torch.optim.SGD(mod_params,
                                     w_lr,
                                     momentum=w_mom,
                                     weight_decay=w_wd)

        # Cosine-annealing schedule for the weight optimizer.
        self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)

        # Adam for the architecture parameters theta.
        self.t_opt = torch.optim.Adam(theta_params,
                                      lr=t_lr,
                                      betas=t_beta,
                                      weight_decay=t_wd)
Example #2
def main(args):
    # torch.backends.cudnn.benchmark = True
    title = args.title
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/%s_%s_bs_%d_ep_%d" % (
            title, args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule:', args.schedule)
    args.vals = args.vals.split(';') if args.vals else []
    print('vals:', args.vals)
    sys.stdout.flush()

    os.makedirs(args.checkpoint, exist_ok=True)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    #data_loader = CTW1500Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    #data_loader = IC15Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    data_loader = OcrDataLoader(args,
                                is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    if len(args.gpus) > 1:
        # Custom DataParallel variant that balances uneven per-GPU batches
        # via chunk_sizes; the wrapped module is expected to own its optimizer.
        model = DataParallel(model,
                             device_ids=args.gpus,
                             chunk_sizes=args.chunk_sizes).cuda()
        optimizer = model.module.optimizer
    else:
        model = model.cuda()
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    # if hasattr(model.module, 'optimizer'):
    #     optimizer = model.module.optimizer
    # else:
    #     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99, weight_decay=5e-4)

    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrain checkpoint file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no resume checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    best_target = {'epoch': 0, 'val': 0}
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        # Only the text-level (te) metrics are logged; the kernel-level (ke)
        # metrics are computed but unused below.
        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        # validate
        if args.vals:
            target = run_tests(args, model, epoch)
            # save best model
            if target > best_target['val']:
                best_target['val'] = target
                best_target['epoch'] = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': args.lr,
                        'optimizer': optimizer.state_dict(),
                    },
                    checkpoint=args.checkpoint,
                    filename='best.pth.tar')
            print('best_target: epoch: %d,  val:%.4f' %
                  (best_target['epoch'], best_target['val']))
        # save latest model
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
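
For reference, a minimal sketch of the adjust_learning_rate helper called above. The step-decay behavior (multiply the base lr by 0.1 at each milestone in args.schedule) is an assumption based on the common pattern; the snippet does not define it.

# Assumed implementation: step decay at the milestones in args.schedule.
def adjust_learning_rate(args, optimizer, epoch):
    # Decay the base lr by 10x for every milestone the epoch has passed.
    lr = args.lr * (0.1 ** sum(1 for m in args.schedule if epoch >= m))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr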