Exemple #1
0
def train_model(args):
    #
    ds = VOCDataset(root_dir=args.train_dir)
    print('Number of Training Images is: {}'.format(len(ds)))
    scales = args.training_size + 32 * np.array([x for x in range(-5, 6)])
    collater = Collater(scales=scales, keep_ratio=False, multiple=32)
    loader = data.DataLoader(dataset=ds,
                             batch_size=args.batch_size,
                             num_workers=8,
                             collate_fn=collater,
                             shuffle=True,
                             drop_last=True)
    #
    model = STELA(backbone=args.backbone, num_classes=2)
    if os.path.exists(args.pretrained):
        model.load_state_dict(torch.load(args.pretrained))
        print('Load pretrained model from {}.'.format(args.pretrained))
    if torch.cuda.is_available():
        model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=0.1)
    iters_per_epoch = np.floor((len(ds) / float(args.batch_size)))
    num_epochs = int(np.ceil(args.max_iter / iters_per_epoch))
    iter_idx = 0
    for _ in range(num_epochs):
        for _, batch in enumerate(loader):
            iter_idx += 1
            if iter_idx > args.max_iter:
                break
            _t.tic()
            scheduler.step(epoch=iter_idx)
            model.train()

            if args.freeze_bn:
                if torch.cuda.device_count() > 1:
                    model.module.freeze_bn()
                else:
                    model.freeze_bn()

            optimizer.zero_grad()
            ims, gt_boxes = batch['image'], batch['boxes']
            if torch.cuda.is_available():
                ims, gt_boxes = ims.cuda(), gt_boxes.cuda()
            losses = model(ims, gt_boxes)
            loss_cls, loss_reg = losses['loss_cls'].mean(
            ), losses['loss_reg'].mean()
            if losses.__contains__('loss_ref'):
                loss_ref = losses['loss_ref'].mean()
                loss = loss_cls + (loss_reg + loss_ref) * 0.5
            else:
                loss = loss_cls + loss_reg
            if bool(loss == 0):
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            if iter_idx % args.display == 0:
                info = 'iter: [{}/{}], time: {:1.3f}'.format(
                    iter_idx, args.max_iter, _t.toc())
                if losses.__contains__('loss_ref'):
                    info = info + ', ref: {:1.3f}'.format(loss_ref.item())
                info = info + ', cls: {:1.3f}, reg: {:1.3f}'.format(
                    loss_cls.item(), loss_reg.item())
                print(info)
            #
            if (arg.eval_iter > 0) and (iter_idx % arg.eval_iter) == 0:
                model.eval()
                if torch.cuda.device_count() > 1:
                    evaluate(model.module, args)
                else:
                    evaluate(model, args)
    #
    if not os.path.exists('./weights'):
        os.mkdir('./weights')
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), './weights/deploy.pth')
    else:
        torch.save(model.state_dict(), './weights/deploy.pth')
def train_model(args):
    # train
    train_dataset = CustomDataset(args.train_img, args.train_gt, args.gt_type_train)
    print('Number of Training Images is: {}'.format(len(train_dataset)))
    scales = args.training_size + 32 * np.array([x for x in range(-5, 6)])
    collater = Collater(scales=scales, keep_ratio=False, multiple=32)
    train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=8,
        collate_fn=collater,
        shuffle=True,
        drop_last=True
    )
    
    os.makedirs('./weights', exist_ok=True)
    
    if args.gt_type_test == 'json':
        parse_gt_json(args)
    elif args.gt_type_test == 'txt':
        parse_gt_txt(args)
        
    model = STELA(backbone=args.backbone, num_classes=2)
    if os.path.exists(args.pretrained):
        model.load_state_dict(torch.load(args.pretrained))
        print('Load pretrained model from {}.'.format(args.pretrained))
    if torch.cuda.is_available():
        model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
    iters_per_epoch = np.floor((len(train_dataset) / float(args.batch_size)))
    num_epochs = int(np.ceil(args.max_iter / iters_per_epoch))
    iter_idx = 0
    best_loss = sys.maxsize
    for _ in range(num_epochs):
        for _, batch in enumerate(train_loader):
            iter_idx += 1
            if iter_idx > args.max_iter:
                break
            _t.tic()
            model.train()
            
            if args.freeze_bn:
                if torch.cuda.device_count() > 1:
                    model.module.freeze_bn()
                else:
                    model.freeze_bn()

            optimizer.zero_grad()
            ims, gt_boxes = batch['image'], batch['boxes']
            if torch.cuda.is_available():
                ims, gt_boxes = ims.cuda(), gt_boxes.cuda()
            losses = model(ims, gt_boxes)
            loss_cls, loss_reg = losses['loss_cls'].mean(), losses['loss_reg'].mean()
            if losses.__contains__('loss_ref'):
                loss_ref = losses['loss_ref'].mean()
                loss = loss_cls + (loss_reg + loss_ref) * 0.5
            else:
                loss = loss_cls + loss_reg
            if bool(loss == 0):
                continue
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()
            
            if iter_idx % args.display == 0:
                info = 'iter: [{}/{}], time: {:1.3f}'.format(iter_idx, args.max_iter, _t.toc())
                if losses.__contains__('loss_ref'):
                    info = info + ', ref: {:1.3f}'.format(loss_ref.item())
                    mlflow_rest.log_metric(metric = {'key': 'loss_ref', 'value': loss_ref.item(), 'step': iter_idx})
                print(info + ', loss_cls: {:1.3f}, loss_reg: {:1.3f}, total_loss: {:1.3f}'.format(loss_cls.item(), loss_reg.item(), loss.item()))                
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    if torch.cuda.device_count() > 1:
                        torch.save(model.module.state_dict(), 'weights/weight_{}_{:1.3f}.pth'.format(iter_idx, loss.item()))
                    else:
                        torch.save(model.state_dict(), 'weights/weight_{}_{:1.3f}.pth'.format(iter_idx, loss.item()))
                mlflow_rest.log_metric(metric = {'key': 'loss_cls', 'value': loss_cls.item(), 'step': iter_idx})
                mlflow_rest.log_metric(metric = {'key': 'loss_reg', 'value': loss_reg.item(), 'step': iter_idx})
                mlflow_rest.log_metric(metric = {'key': 'total_loss', 'value': loss.item(), 'step': iter_idx})
            
#             if (arg.eval_iter > 0) and (iter_idx % arg.eval_iter) == 0:
                
                ## mlflow_rest.log_metric(metric = {'key': 'accuracy', 'value': accuracy, 'step': iter_idx})
                ## mlflow_rest.log_metric(metric = {'key': 'IOU', 'value': avg_iou, 'step': iter_idx})
                ## mlflow_rest.log_metric(metric = {'key': 'confidence', 'value': confidence, 'step': iter_idx})
                ## print('IOU: {}, Score: {}'.format(avg_iou, confidence))
                ## print(f"precision: {precision*100}, recall: {recall*100}, f1: {f1*100}, accuracy: {accuracy*100}")
                
            if iter_idx % args.save_interval == 0:
                if torch.cuda.device_count() > 1:
                    torch.save(model.module.state_dict(), f'weights/check_{iter_idx}.pth')
                else:
                    torch.save(model.state_dict(), f'weights/check_{iter_idx}.pth')
    
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), f'weights/final_{args.max_iter}.pth')
    else:
        torch.save(model.state_dict(), f'weights/final_{args.max_iter}.pth')

    model.eval()
    if torch.cuda.device_count() > 1:
        result = evaluate(model.module, args)
    else:
        result = evaluate(model, args)

    mlflow_rest.log_metric(metric = {'key': 'precision', 'value': result['precision'], 'step': iter_idx})
    mlflow_rest.log_metric(metric = {'key': 'recall', 'value': result['recall'], 'step': iter_idx})
    mlflow_rest.log_metric(metric = {'key': 'hmean', 'value': result['hmean'], 'step': iter_idx})