Example #1
    def __init__(self,
                 network,
                 w_lr=0.01,
                 w_mom=0.9,
                 w_wd=1e-4,
                 t_lr=0.001,
                 t_wd=3e-3,
                 t_beta=(0.5, 0.999),
                 init_temperature=5.0,
                 temperature_decay=0.965,
                 logger=logging,
                 lr_scheduler={'T_max': 200},
                 gpus=[0],
                 save_theta_prefix='',
                 theta_result_path='./theta-result',
                 checkpoints_path='./checkpoints'):
        assert isinstance(network, FBNet)
        network.apply(weights_init)
        network = network.train().cuda()
        if isinstance(gpus, str):
            gpus = [int(i) for i in gpus.strip().split(',')]
        network = DataParallel(network, gpus)
        self.gpus = gpus
        self._mod = network
        theta_params = network.theta
        mod_params = network.parameters()
        self.theta = theta_params
        self.w = mod_params
        self._tem_decay = temperature_decay
        self.temp = init_temperature
        self.logger = logger
        self.save_theta_prefix = save_theta_prefix
        if not os.path.exists(theta_result_path):
            os.makedirs(theta_result_path)
        self.theta_result_path = theta_result_path
        if not os.path.exists(checkpoints_path):
            os.makedirs(checkpoints_path)
        self.checkpoints_path = checkpoints_path

        self._acc_avg = AvgrageMeter('acc')
        self._ce_avg = AvgrageMeter('ce')
        self._lat_avg = AvgrageMeter('lat')
        self._loss_avg = AvgrageMeter('loss')

        self.w_opt = torch.optim.SGD(mod_params,
                                     w_lr,
                                     momentum=w_mom,
                                     weight_decay=w_wd)

        self.w_sche = CosineDecayLR(self.w_opt, **lr_scheduler)

        self.t_opt = torch.optim.Adam(theta_params,
                                      lr=t_lr,
                                      betas=t_beta,
                                      weight_decay=t_wd)
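
The `CosineDecayLR(self.w_opt, **lr_scheduler)` call above takes its scheduler kwargs (here `T_max=200`) straight from the constructor argument. A minimal stand-alone sketch of the same pattern, assuming `CosineDecayLR` behaves like PyTorch's built-in `torch.optim.lr_scheduler.CosineAnnealingLR` (the stand-in parameters below are illustrative, not from the source):

import torch

# Stand-in weight parameters; the source passes the FBNet supernet's parameters.
params = [torch.nn.Parameter(torch.zeros(10))]
w_opt = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)

# Rough equivalent of CosineDecayLR(w_opt, T_max=200) using the stock scheduler.
w_sche = torch.optim.lr_scheduler.CosineAnnealingLR(w_opt, T_max=200)

for step in range(3):
    w_opt.step()      # in real training this follows a forward/backward pass
    w_sche.step()
    print(step, w_sche.get_last_lr())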
Example #2
    def __init__(self, model, optimizer, gpus, chunk_sizes, device):
        self.gpus = gpus
        self.device = device
        self.optimizer = optimizer
        self.loss_stats = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        self.loss = CtdetLoss()
        self.model_with_loss = ModleWithLoss(model, self.loss)

        if len(self.gpus) > 1:
            self.model_with_loss = DataParallel(self.model_with_loss,
                                                device_ids=self.gpus,
                                                output_device=0,
                                                chunk_sizes=chunk_sizes).to(
                                                    self.device)
        else:
            self.model_with_loss = self.model_with_loss.to(self.device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=self.device, non_blocking=True)
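
Note that the `DataParallel` used here accepts a `chunk_sizes` argument, so it is a customized variant (as in the CenterNet codebase) that lets each GPU take a different share of the batch; the stock `torch.nn.DataParallel` has no such parameter. A minimal, hedged sketch of the equivalent wrap with the stock module (toy model, illustrative only):

import torch
import torch.nn as nn

# Toy stand-in for the model-with-loss wrapper; any nn.Module works here.
model_with_loss = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

gpus = [0, 1]
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if torch.cuda.device_count() > 1:
    # Stock DataParallel splits each batch evenly across device_ids and
    # gathers outputs on output_device; there is no chunk_sizes knob.
    model_with_loss = nn.DataParallel(model_with_loss, device_ids=gpus, output_device=0)
model_with_loss = model_with_loss.to(device)

x = torch.randn(8, 16, device=device)
out = model_with_loss(x)
print(out.shape)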
Example #3
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        drop_last=True,
        collate_fn=dataset.collate_fn,
    )

    # Replicate Model if n_GPUs > 1
    if opt.n_gpu > 1:
        opt.batch_size *= opt.n_gpu
        device_ids = list(range(torch.cuda.device_count()))
        model = DataParallel(model, device_ids=device_ids, output_device=None)

    optimizer = Ranger(model.parameters(), lr=1e-3, eps=1e-8)

    ## Main Loop
    for epoch in range(opt.epochs):

        model.train()
        start_time = time.time()

        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            '''
            epoch = 0
            iterator = iter(enumerate(dataloader))
            batch_i, (_, imgs, targets) = next(iterator)
            '''
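
The triple-quoted block above is a common trick for interactive debugging: pull a single batch by hand instead of running the full training loop. A self-contained version of the same idiom with a toy dataset (the real code uses the detection dataset and its collate_fn):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(32, 3, 64, 64), torch.randint(0, 2, (32,)))
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)

# Same idiom as the commented-out debug block: grab one batch manually.
iterator = iter(enumerate(dataloader))
batch_i, (imgs, targets) = next(iterator)
print(batch_i, imgs.shape, targets.shape)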
Example #4
# (tail of a truncated torch.utils.data.DataLoader call)
                                         collate_fn=collate_minibatch)
# for evaluation
eval_dataset_name = 'minival2014'
eval_dataset = COCOJsonDataset(root=root, annFile=eval_dataset_name, cache_dir=cache_dir)

# Model
net = RetinaNet()
net.load_state_dict(torch.load('../pretrained_model/net.pth'))
if args.resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('../checkpoint/ckpt_xavier_{}.pth'.format(args.epoch))
    net.load_state_dict(checkpoint['net'])
    best_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch']

net = DataParallel(net, device_ids=range(torch.cuda.device_count()), minibatch=True)
net.cuda()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)


# training
def train(epoch):
    print('\nEpoch: {}'.format(epoch))
    logger.debug('\nEpoch: {}'.format(epoch))
    net.train()
    net.module.freeze_bn()
    train_loss = 0
    for batch_idx, blobs in enumerate(dataloader):
        inputs, loc_targets, cls_targets = blobs['data'], blobs['loc_targets'], blobs['cls_targets']
        inputs = list(map(Variable, inputs))
        loc_targets = list(map(Variable, loc_targets))
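
`Variable` here comes from the pre-0.4 PyTorch API; on current PyTorch the wrapper is a no-op and plain tensors can be passed to the network directly. A tiny illustration:

import torch
from torch.autograd import Variable

x = torch.randn(2, 3)
v = Variable(x)  # deprecated since PyTorch 0.4: simply returns a Tensor
print(type(v) is torch.Tensor, torch.equal(v, x))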
Example #5
def main(args):
    # torch.backends.cudnn.benchmark = True
    title = args.title
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/%s_%s_bs_%d_ep_%d" % (
            title, args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule:', args.schedule)
    args.vals = args.vals.split(';') if args.vals else []
    print('vals:', args.vals)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    #data_loader = CTW1500Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    #data_loader = IC15Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    data_loader = OcrDataLoader(args,
                                is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    else:
        raise ValueError('Unsupported arch: {}'.format(args.arch))

    if len(args.gpus) > 1:
        model = DataParallel(model,
                             device_ids=args.gpus,
                             chunk_sizes=args.chunk_sizes).cuda()
        optimizer = model.module.optimizer
    else:
        model = model.cuda()
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    # if hasattr(model.module, 'optimizer'):
    #     optimizer = model.module.optimizer
    # else:
    #     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99, weight_decay=5e-4)

    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    best_target = {'epoch': 0, 'val': 0}
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print(('\nEpoch: [%d | %d] LR: %f' %
               (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr'])))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        # validate
        if args.vals:
            target = run_tests(args, model, epoch)
            # save best model
            if target > best_target['val']:
                best_target['val'] = target
                best_target['epoch'] = epoch + 1
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'lr': args.lr,
                        'optimizer': optimizer.state_dict(),
                    },
                    checkpoint=args.checkpoint,
                    filename='best.pth.tar')
            print('best_target: epoch: %d,  val:%.4f' %
                  (best_target['epoch'], best_target['val']))
        # save latest model
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
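
One practical note on the checkpoints saved above: because the model is wrapped in DataParallel (the code accesses it via `model.module`), the saved `state_dict` keys carry a `module.` prefix, so they only load directly into another wrapped model. A small helper for loading such a checkpoint into an unwrapped, single-GPU model (not part of the source; the paths in the usage comment are hypothetical):

def strip_module_prefix(state_dict):
    """Drop the 'module.' prefix that DataParallel adds to parameter names."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# Usage sketch:
# checkpoint = torch.load('best.pth.tar', map_location='cpu')
# single_gpu_model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))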
Example #6
class Trainer(object):
    def __init__(self, model, optimizer, gpus, chunk_sizes, device):
        self.gpus = gpus
        self.device = device
        self.optimizer = optimizer
        self.loss_stats = ['loss', 'hm_loss', 'wh_loss', 'off_loss']
        self.loss = CtdetLoss()
        self.model_with_loss = ModleWithLoss(model, self.loss)

        if len(self.gpus) > 1:
            self.model_with_loss = DataParallel(self.model_with_loss,
                                                device_ids=self.gpus,
                                                output_device=0,
                                                chunk_sizes=chunk_sizes).to(
                                                    self.device)
        else:
            self.model_with_loss = self.model_with_loss.to(self.device)

        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device=self.device, non_blocking=True)

    def run_epoch(self, phase, epoch, data_loader):
        model_with_loss = self.model_with_loss
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.gpus) > 1:
                model_with_loss = self.model_with_loss.module
            model_with_loss.eval()
            torch.cuda.empty_cache()

        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader)
        bar = Bar('{}'.format('ctdet'), max=num_iters)
        end = time.time()
        for iter_id, batch in enumerate(data_loader):
            if iter_id >= num_iters:
                break
            data_time.update(time.time() - end)

            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=self.device,
                                           non_blocking=True)
            output, loss, loss_stats = model_with_loss(batch)
            loss = loss.mean()
            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()

            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            for l in avg_loss_stats:
                avg_loss_stats[l].update(loss_stats[l].mean().item(),
                                         batch['input'].size(0))
                Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(
                    l, avg_loss_stats[l].avg)
            Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) |Net {bt.avg:.3f}s'.format(
                dt=data_time, bt=batch_time)
            bar.next()
            del output, loss, loss_stats

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        return ret, results
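
A hedged sketch of how `run_epoch` is typically driven once a `Trainer` and the ctdet data loaders exist; the `trainer`, `train_loader`, and `val_loader` objects below are assumed to be built elsewhere with the repo's model, optimizer, and dataset, and are not defined here:

num_epochs = 140  # illustrative value, not from the source

for epoch in range(1, num_epochs + 1):
    # run_epoch returns ({loss_name: epoch_average, 'time': minutes_elapsed}, results)
    log_train, _ = trainer.run_epoch('train', epoch, train_loader)
    log_val, _ = trainer.run_epoch('val', epoch, val_loader)
    summary = ' | '.join('{} {:.4f}'.format(k, v) for k, v in log_train.items())
    print('epoch {}: {} (val loss {:.4f})'.format(epoch, summary, log_val['loss']))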