Ejemplo n.º 1
0
def train(args):
    """Train YOLOv1 on a VOC-style dataset using settings read from ``args``.

    Args:
        args: Namespace-like object. Attributes read here: ``data_path``,
            ``class_path``, ``checkpoint_path``, ``batch_size``,
            ``num_epochs``, ``lr``, ``weight_decay``, ``dropout``,
            ``l_coord``, ``l_noobj``, ``num_gpus``, ``num_class``,
            ``use_augmentation`` and ``use_summary``.

    Side effects:
        Prints a progress line every 100 optimizer steps and, once at the
        end of every 1000th epoch, writes a checkpoint under
        ``checkpoint_path`` through ``save_checkpoint``.
    """
    data_path = args.data_path
    class_path = args.class_path
    checkpoint_path = args.checkpoint_path

    batch_size = args.batch_size
    num_epochs = args.num_epochs
    lr = args.lr
    weight_decay = args.weight_decay
    dropout = args.dropout
    l_coord = args.l_coord
    l_noobj = args.l_noobj
    # DataParallel wants an explicit list of device ids: [0, 1, ..., n-1].
    num_gpus = list(range(args.num_gpus))
    num_class = args.num_class

    USE_AUGMENTATION = args.use_augmentation
    USE_SUMMARY = args.use_summary

    # Augmentation: apply 2 of the listed augmenters, or an empty pipeline
    # when augmentation is disabled.
    if USE_AUGMENTATION:
        seq = iaa.SomeOf(2, [
            iaa.Multiply((1.2, 1.5)),
            iaa.Affine(translate_px={
                "x": 3,
                "y": 10
            }, scale=(0.9, 0.9)),
            iaa.AdditiveGaussianNoise(scale=0.1 * 255),
            iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
            iaa.Affine(rotate=45),
            iaa.Sharpen(alpha=0.5)
        ])
    else:
        seq = iaa.Sequential([])

    composed = transforms.Compose([Augmenter(seq)])

    # DataLoader
    train_dataset = VOC(root=data_path,
                        transform=composed,
                        class_path=class_path)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)

    # Model
    model = models.YOLOv1(num_class, dropout)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model, device_ids=num_gpus).to(device)
    else:
        model = torch.nn.DataParallel(model)

    if USE_SUMMARY:
        summary(model, (3, 448, 448))

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    # Counts optimizer steps across all epochs (drives the logging cadence).
    total_step = 0

    for epoch in range(1, num_epochs + 1):
        # The learning rate is decayed only at these hand-picked epochs,
        # not after every epoch.
        if epoch in (200, 400, 600, 20000, 30000):
            scheduler.step()

        for i, (images, labels, sizes) in enumerate(train_loader):

            total_step += 1
            images = images.to(device)
            labels = labels.to(device)

            pred = model(images)

            loss, losses = detection_loss_4_yolo(pred, labels, l_coord,
                                                 l_noobj, device)

            # Individual loss terms, used for logging only.
            coord_loss = losses[0]
            size_loss = losses[1]
            objness_loss = losses[2]
            noobjness_loss = losses[3]
            class_loss = losses[4]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if total_step % 100 == 0:
                print("epoch: [{}/{}], step:{}, lr:{}, total_loss:{:.4f}, \
                    \ncoord:{:.4f}, size:{:.4f}, objness:{:.4f}, noobjness:{:.4f}, class:{:.4f}"
                      .format(epoch, num_epochs, total_step,
                              optimizer.param_groups[0]['lr'], loss.item(),
                              coord_loss, size_loss, objness_loss,
                              noobjness_loss, class_loss))

        # Checkpoint once per qualifying epoch, after its last batch. Doing
        # this inside the batch loop would write one file per batch.
        if epoch % 1000 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            save_checkpoint(
                {
                    "epoch": epoch,
                    "arch": "YoloV1",
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict()
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    # `{:05d}`: a precision (".05") is not allowed for
                    # integer format specs and raises ValueError.
                    "ckpt_ep{:05d}_loss{:.04f}_lr{}.pth.tar".format(
                        epoch, loss.item(), current_lr)))
Ejemplo n.º 2
0
def train(params):
    """Train YOLOv1 on a VOC-style dataset using settings from ``params``.

    Args:
        params: Mapping of settings. Keys read here: ``data_path``,
            ``class_path``, ``batch_size``, ``num_epochs``, ``lr``,
            ``dropout``, ``num_gpus``, ``checkpoint_path``, ``num_class``
            and the feature switches ``use_visdom``, ``use_wandb``,
            ``use_summary``, ``use_augmentation``, ``use_gtcheck`` and
            ``use_githash``. When ``use_wandb`` is set the whole mapping is
            also forwarded to ``wandb.config.update``.

    Side effects:
        Prints progress (every 10 steps for the first 100 steps, then every
        100 steps), optionally streams losses to visdom / wandb, and writes
        a checkpoint under ``checkpoint_path`` every 1000th epoch.
    """
    data_path = params["data_path"]
    class_path = params["class_path"]
    batch_size = params["batch_size"]
    num_epochs = params["num_epochs"]
    learning_rate = params["lr"]
    dropout = params["dropout"]
    # DataParallel wants an explicit list of device ids: [0, 1, ..., n-1].
    num_gpus = list(range(params["num_gpus"]))
    checkpoint_path = params["checkpoint_path"]

    USE_VISDOM = params["use_visdom"]
    USE_WANDB = params["use_wandb"]
    USE_SUMMARY = params["use_summary"]
    USE_AUGMENTATION = params["use_augmentation"]
    USE_GTCHECKER = params["use_gtcheck"]

    USE_GITHASH = params["use_githash"]
    num_class = params["num_class"]

    if USE_WANDB:
        wandb.init()
        wandb.config.update(
            params)  # adds all of the arguments as config variables

    with open(class_path) as f:
        class_list = f.read().splitlines()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Tag used in checkpoint file names: short git hash of HEAD, or a fixed
    # placeholder when hashing is disabled.
    if USE_GITHASH:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        short_sha = repo.git.rev_parse(sha, short=7)
    else:
        short_sha = 'noHash'

    if USE_VISDOM:
        viz = visdom.Visdom(use_incoming_socket=False)
        vis_title = 'Yolo V1 Deepbaksu_vision (feat. martin, visionNoob) PyTorch on ' + 'VOC'
        vis_legend = ['Train Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Total Loss', vis_title,
                                    vis_legend)
        coord1_plot = create_vis_plot(viz, 'Iteration', 'coord1', vis_title,
                                      vis_legend)
        size1_plot = create_vis_plot(viz, 'Iteration', 'size1', vis_title,
                                     vis_legend)
        noobjectness1_plot = create_vis_plot(viz, 'Iteration', 'noobjectness1',
                                             vis_title, vis_legend)
        objectness1_plot = create_vis_plot(viz, 'Iteration', 'objectness1',
                                           vis_title, vis_legend)
        obj_cls_plot = create_vis_plot(viz, 'Iteration', 'obj_cls', vis_title,
                                       vis_legend)

    # 2. Data augmentation setting: apply 2 of the listed augmenters, or an
    # empty pipeline when augmentation is disabled.
    if USE_AUGMENTATION:
        seq = iaa.SomeOf(
            2,
            [
                iaa.Multiply(
                    (1.2, 1.5)),  # change brightness, doesn't affect BBs
                iaa.Affine(
                    translate_px={
                        "x": 3,
                        "y": 10
                    }, scale=(0.9, 0.9)
                ),  # translate by 3/10 px on x/y axis, scale to 90%, affects BBs
                iaa.AdditiveGaussianNoise(scale=0.1 * 255),
                iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
                iaa.Affine(rotate=45),
                iaa.Sharpen(alpha=0.5)
            ])
    else:
        seq = iaa.Sequential([])

    composed = transforms.Compose([Augmenter(seq)])

    # 3. Load Dataset
    train_dataset = VOC(root=data_path,
                        transform=composed,
                        class_path=class_path)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)

    # 5. Load YOLOv1
    net = yolov1.YOLOv1(params={"dropout": dropout, "num_class": num_class})

    print("device : ", device)
    if device.type == 'cpu':
        model = torch.nn.DataParallel(net)
    else:
        model = torch.nn.DataParallel(net, device_ids=num_gpus).cuda()

    if USE_SUMMARY:
        summary(model, (3, 448, 448))

    # 7. Train the model
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    # Batches per epoch and total number of optimizer steps for the run.
    total_step = len(train_loader)
    total_train_step = num_epochs * total_step

    # Epochs are 1-based: 1 .. num_epochs inclusive.
    for epoch in range(1, num_epochs + 1):

        # The learning rate is decayed only at these hand-picked epochs,
        # not after every epoch.
        if epoch in (200, 400, 600, 20000, 30000):
            scheduler.step()

        for i, (images, labels, sizes) in enumerate(train_loader):

            # Global 1-based step. `epoch - 1` because epochs start at 1;
            # using `epoch` would overshoot total_train_step by one epoch.
            current_train_step = (epoch - 1) * total_step + (i + 1)

            if USE_GTCHECKER:
                visualize_GT(images, labels, class_list)

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Calc Loss
            loss, \
            obj_coord1_loss, \
            obj_size1_loss, \
            obj_class_loss, \
            noobjness1_loss, \
            objness1_loss = detection_loss_4_yolo(outputs, labels, device.type)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log densely (every 10 steps) during the first 100 steps, then
            # every 100 steps.
            if (current_train_step % 100 == 0) or (
                    current_train_step % 10 == 0
                    and current_train_step < 100):
                print(
                    'epoch: [{}/{}], total step: [{}/{}], batch step [{}/{}], lr: {}, total_loss: {:.4f}, coord1: {:.4f}, size1: {:.4f}, noobj_clss: {:.4f}, objness1: {:.4f}, class_loss: {:.4f}'
                    .format(epoch, num_epochs, current_train_step,
                            total_train_step, i + 1, total_step,
                            optimizer.param_groups[0]['lr'], loss.item(),
                            obj_coord1_loss, obj_size1_loss, noobjness1_loss,
                            objness1_loss, obj_class_loss))

                if USE_VISDOM:
                    # Plot against the same global step that is printed.
                    update_vis_plot(viz, current_train_step, loss.item(),
                                    iter_plot, None, 'append')
                    update_vis_plot(viz, current_train_step, obj_coord1_loss,
                                    coord1_plot, None, 'append')
                    update_vis_plot(viz, current_train_step, obj_size1_loss,
                                    size1_plot, None, 'append')
                    update_vis_plot(viz, current_train_step, obj_class_loss,
                                    obj_cls_plot, None, 'append')
                    update_vis_plot(viz, current_train_step, noobjness1_loss,
                                    noobjectness1_plot, None, 'append')
                    update_vis_plot(viz, current_train_step, objness1_loss,
                                    objectness1_plot, None, 'append')

                if USE_WANDB:
                    wandb.log({
                        'total_loss': loss.item(),
                        'obj_coord1_loss': obj_coord1_loss,
                        'obj_size1_loss': obj_size1_loss,
                        'obj_class_loss': obj_class_loss,
                        'noobjness1_loss': noobjness1_loss,
                        'objness1_loss': objness1_loss
                    })

        # Checkpoint once at the end of every 1000th epoch; the file name
        # carries the loss of the epoch's last batch.
        if epoch % 1000 == 0:
            save_checkpoint(
                {
                    # NOTE(review): stores epoch + 1, presumably the epoch a
                    # resumed run should start from - confirm with the loader.
                    'epoch': epoch + 1,
                    'arch': "YOLOv1",
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format(
                        short_sha, epoch, loss.item(),
                        optimizer.param_groups[0]['lr'])))
Ejemplo n.º 3
0
def main(args):
    """Train and validate an SSD detector, checkpointing after every epoch.

    Builds the train/val datasets for ``args.dataset`` ('voc'/'pascal' or
    'coco'), the ``ssd`` model (optionally initialised from
    ``args.finetune`` and/or resumed from ``args.resume``), an SGD
    optimizer, a ``MultiBoxLoss`` criterion and the learning-rate scheduler
    selected by ``args.lr_type``. It then runs ``train`` / ``validate`` for
    each epoch, saves a checkpoint and logs scalars to a ``SummaryWriter``
    rooted at ``args.save``.

    Args:
        args: Parsed command-line namespace. Unsupported values of
            ``im_size``, ``dataset`` or ``lr_type`` print an error and
            terminate the process.
    """
    if args.im_size in [300, 512]:
        from model.detection.ssd_config import get_config
        cfg = get_config(args.im_size)
    else:
        print_error_message('{} image size not supported'.format(args.im_size))
        # Stop here like the other unsupported-option branches do; without
        # this, `cfg` would be undefined below.
        exit()

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainTransform(cfg.image_size)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.center_variance, cfg.size_variance,
        cfg.iou_threshold)
    val_transform = ValTransform(cfg.image_size)

    if args.dataset in ['voc', 'pascal']:
        from data_loader.detection.voc import VOCDataset, VOC_CLASS_LIST
        # Train on the VOC2007 and VOC2012 splits together; validate on the
        # VOC2007 split only.
        train_dataset_2007 = VOCDataset(root_dir=args.data_path,
                                        transform=train_transform,
                                        target_transform=target_transform,
                                        is_training=True,
                                        split="VOC2007")
        train_dataset_2012 = VOCDataset(root_dir=args.data_path,
                                        transform=train_transform,
                                        target_transform=target_transform,
                                        is_training=True,
                                        split="VOC2012")
        train_dataset = torch.utils.data.ConcatDataset(
            [train_dataset_2007, train_dataset_2012])
        val_dataset = VOCDataset(root_dir=args.data_path,
                                 transform=val_transform,
                                 target_transform=target_transform,
                                 is_training=False,
                                 split="VOC2007")
        num_classes = len(VOC_CLASS_LIST)
    elif args.dataset == 'coco':
        from data_loader.detection.coco import COCOObjectDetection, COCO_CLASS_LIST
        train_dataset = COCOObjectDetection(root_dir=args.data_path,
                                            transform=train_transform,
                                            target_transform=target_transform,
                                            is_training=True)
        val_dataset = COCOObjectDetection(root_dir=args.data_path,
                                          transform=val_transform,
                                          target_transform=target_transform,
                                          is_training=False)
        num_classes = len(COCO_CLASS_LIST)
    else:
        print_error_message('{} dataset is not supported yet'.format(
            args.dataset))
        exit()
    cfg.NUM_CLASSES = num_classes

    # -----------------------------------------------------------------------------
    # Dataset loader
    # -----------------------------------------------------------------------------
    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = ssd(args, cfg)
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
    # -----------------------------------------------------------------------------
    # Optimizer and Criterion
    # -----------------------------------------------------------------------------
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)

    criterion = MultiBoxLoss(neg_pos_ratio=cfg.neg_pos_ratio)

    # writer for logs
    writer = SummaryWriter(log_dir=args.save,
                           comment='Training and Validation logs')
    # Graph export is best-effort. Catch Exception rather than everything so
    # Ctrl-C / SystemExit still propagate.
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, cfg.image_size,
                                                     cfg.image_size))
    except Exception:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # model stats
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, cfg.image_size,
                                             cfg.image_size))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            cfg.image_size, cfg.image_size, flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'

    min_val_loss = float('inf')
    start_epoch = 0  # start from epoch 0 or last epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            # Load the file that was just checked and announced
            # (args.resume), not a different attribute.
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['state_dict'])
            min_val_loss = checkpoint['min_loss']
            # NOTE(review): the checkpoint written below stores the epoch
            # that just finished, so resuming appears to repeat that epoch -
            # confirm whether this is intended.
            start_epoch = checkpoint['epoch']
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.to(device)
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    if args.lr_type == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.lr_type == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.lr_type == 'clr':
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=args.cycle_len,
                                steps=args.steps,
                                gamma=args.gamma,
                                step=True)
    elif args.lr_type == 'cosine':
        from utilities.lr_scheduler import CosineLR
        lr_scheduler = CosineLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler not yet supported'.format(
            args.lr_type))
        exit()

    print_info_message(lr_scheduler)

    # -----------------------------------------------------------------------------
    # Training and validation loop
    # -----------------------------------------------------------------------------

    extra_info_ckpt = '{}_{}'.format(args.model, args.s)
    for epoch in range(start_epoch, args.epochs):
        # The project scheduler returns the learning rate for this epoch;
        # it is written into the optimizer by hand.
        curr_lr = lr_scheduler.step(epoch)
        optimizer.param_groups[0]['lr'] = curr_lr

        print_info_message('Running epoch {} at LR {}'.format(epoch, curr_lr))
        train_loss, train_cl_loss, train_loc_loss = train(train_loader,
                                                          model,
                                                          criterion,
                                                          optimizer,
                                                          device,
                                                          epoch=epoch)
        val_loss, val_cl_loss, val_loc_loss = validate(val_loader,
                                                       model,
                                                       criterion,
                                                       device,
                                                       epoch=epoch)
        # Save checkpoint
        is_best = val_loss < min_val_loss
        min_val_loss = min(val_loss, min_val_loss)

        # On CUDA the model is wrapped in DataParallel, so save the inner
        # module's weights.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch,
                'model': args.model,
                'state_dict': weights_dict,
                'min_loss': min_val_loss
            }, is_best, args.save, extra_info_ckpt)

        writer.add_scalar('Detection/LR/learning_rate', round(curr_lr, 6),
                          epoch)
        writer.add_scalar('Detection/Loss/train', train_loss, epoch)
        writer.add_scalar('Detection/Loss/val', val_loss, epoch)
        writer.add_scalar('Detection/Loss/train_cls', train_cl_loss, epoch)
        writer.add_scalar('Detection/Loss/val_cls', val_cl_loss, epoch)
        writer.add_scalar('Detection/Loss/train_loc', train_loc_loss, epoch)
        writer.add_scalar('Detection/Loss/val_loc', val_loc_loss, epoch)
        # For these two scalars the step axis is the model complexity
        # (FLOPs / parameter count), and the value is the best val loss.
        writer.add_scalar('Detection/Complexity/Flops', min_val_loss,
                          math.ceil(flops))
        writer.add_scalar('Detection/Complexity/Params', min_val_loss,
                          math.ceil(num_params))

    writer.close()
Ejemplo n.º 4
0
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseDepth(root=args.data_path,
                                        list_name='train_depth_ae.txt',
                                        train=True,
                                        size=crop_size,
                                        scale=args.scale,
                                        use_filter=True)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path,
                                                 list_name='val_depth_ae.txt',
                                                 train=False,
                                                 size=crop_size,
                                                 scale=args.scale,
                                                 use_depth=True)
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    from model.autoencoder.depth_autoencoder import espnetv2_autoenc
    args.classes = 3
    model = espnetv2_autoenc(args)

    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 1, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    #criterion = SegmentationLoss(n_classes=seg_classes, loss_type=args.loss_type,
    #                             device=device, ignore_idx=args.ignore_idx,
    #                             class_wts=class_wts.to(device))
    criterion = nn.MSELoss()
    # criterion = nn.L1Loss()

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    best_loss = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_seg
        # optimizer.param_groups[1]['lr'] = lr_seg

        # Train
        model.train()
        losses = AverageMeter()
        for i, batch in enumerate(train_loader):
            inputs = batch[1].to(device=device)  # Depth
            target = batch[0].to(device=device)  # RGB

            outputs = model(inputs)

            if device == 'cuda':
                loss = criterion(outputs, target).mean()
                if isinstance(outputs, (list, tuple)):
                    target_dev = outputs[0].device
                    outputs = gather(outputs, target_device=target_dev)
            else:
                loss = criterion(outputs, target)

            losses.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #             if not (i % 10):
            #                 print("Step {}, write images".format(i))
            #                 image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
            #                 writer.add_image('Autoencoder/results/train', image_grid, len(train_loader) * epoch + i)

            writer.add_scalar('Autoencoder/Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            print_info_message('Running batch {}/{} of epoch {}'.format(
                i + 1, len(train_loader), epoch + 1))

        train_loss = losses.avg

        writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch)

        # Val
        if epoch % 5 == 0:
            losses = AverageMeter()
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    inputs = batch[2].to(device=device)  # Depth
                    target = batch[0].to(device=device)  # RGB

                    outputs = model(inputs)

                    if device == 'cuda':
                        loss = criterion(outputs, target)  # .mean()
                        if isinstance(outputs, (list, tuple)):
                            target_dev = outputs[0].device
                            outputs = gather(outputs, target_device=target_dev)
                    else:
                        loss = criterion(outputs, target)

                    losses.update(loss.item(), inputs.size(0))

                    image_grid = torchvision.utils.make_grid(
                        outputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/results/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        inputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/inputs/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        target.data.cpu()).numpy()
                    writer.add_image('Autoencoder/target/val', image_grid,
                                     epoch)

            val_loss = losses.avg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # remember best miou and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch)

    writer.close()
Ejemplo n.º 5
0
    def train(self):
        """Run the step-based training loop and checkpoint periodically.

        Each step pulls one batch from ``self.tr_reader``, converts it to
        tensors, runs ``self.model.forward_context`` and ``self.model.lossfunc``
        and applies ``self.optstep`` to the total loss.  Every ``log_interval``
        steps the mean losses and the accumulated timings of that window are
        printed; every 1000 steps the model is deep-copied, saved to
        ``ckptpath`` and ``self.validation()`` is run.

        NOTE(review): ``maxtrsteps``, ``log_interval``, ``device_id``, ``bs``,
        ``mentype`` and ``ckptpath`` are not defined in this method; they are
        presumably module-level settings -- confirm against the rest of the
        file.

        Returns:
            A tuple ``(bestmodel, bestval, beststep, steps)``.  ``bestmodel``
            is the deep copy taken at the most recent checkpoint (or
            ``self.model`` if no checkpoint was reached), ``beststep`` is the
            step of that checkpoint, ``steps`` is the number of steps run.
            ``bestval`` is never updated here and is always ``0.0``.
        """
        self.model.train()

        # Loss sums accumulated since the last log line.
        avg_loss, avg_elloss, avg_mtypeloss = 0.0, 0.0, 0.0
        steps = 0

        bestmodel, bestval, beststep = self.model, 0.0, 0

        # Wall-clock seconds spent reading / converting / processing batches
        # since the last log line.
        readtime, convtime, processtime = 0, 0, 0

        while steps < maxtrsteps:
            steps += 1

            # ---- read the next batch ----
            rtimestart = time.time()
            b = self.tr_reader.next_train_batch()
            (leftb, leftlens, rightb, rightlens, docb, typesb, wididxsb,
             widprobsb) = (b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7])
            (ind, vals, dvsize) = docb
            readtime += (time.time() - rtimestart)

            # ---- convert the batch to tensors ----
            ctimestart = time.time()
            ind = torch.LongTensor(ind)
            vals = torch.FloatTensor(vals)
            # The document part arrives as (indices, values, size) and is
            # rebuilt as a sparse tensor.
            docb = torch.sparse.FloatTensor(ind.t(), vals, torch.Size(dvsize))
            (leftb, leftlens, rightb, rightlens, typesb,
             wididxsb, widprobsb) = (torch.FloatTensor(leftb),
                                     torch.LongTensor(leftlens),
                                     torch.FloatTensor(rightb),
                                     torch.LongTensor(rightlens),
                                     torch.FloatTensor(typesb),
                                     torch.LongTensor(wididxsb),
                                     torch.FloatTensor(widprobsb))
            (leftb, leftlens, rightb, rightlens, docb, typesb, wididxsb,
             widprobsb) = utils.toCudaVariable(device_id, leftb, leftlens,
                                               rightb, rightlens, docb, typesb,
                                               wididxsb, widprobsb)

            # Target index is 0 for every one of the `bs` rows.
            truewidvec = utils.toCudaVariable(device_id,
                                              torch.LongTensor([0] * bs))[0]
            convtime += (time.time() - ctimestart)

            # ---- forward, loss, optimizer step ----
            ptimestart = time.time()
            rets = self.model.forward_context(leftb=leftb,
                                              leftlens=leftlens,
                                              rightb=rightb,
                                              rightlens=rightlens,
                                              docb=docb,
                                              wididxsb=wididxsb)
            # rets[1] is returned by the model but not needed for the loss.
            wididxscores, mentype_probs = rets[0], rets[2]

            (loss, elloss,
             mentype_loss) = self.model.lossfunc(mentype=mentype,
                                                 predwidscores=wididxscores,
                                                 truewidvec=truewidvec,
                                                 mentype_probs=mentype_probs,
                                                 mentype_trueprobs=typesb)

            self.optstep(loss)
            loss = loss.data.cpu().numpy()[0]
            elloss = elloss.data.cpu().numpy()[0]
            mentype_loss = mentype_loss.data.cpu().numpy()[0]
            avg_loss += loss
            avg_elloss += elloss
            avg_mtypeloss += mentype_loss

            processtime += (time.time() - ptimestart)

            if steps % log_interval == 0:
                totaltime = readtime + processtime + convtime
                print()
                mean_loss = utils.round_all(avg_loss / log_interval, 3)
                mean_elloss = utils.round_all(avg_elloss / log_interval, 3)
                mean_mtypeloss = utils.round_all(avg_mtypeloss / log_interval,
                                                 3)
                print("[{}, {}, rt:{:0.1f} secs ct:{:0.1f} pt:{:0.1f} "
                      "tt:{:0.1f} secs]: L:{} EL:{} MenTypL:{}".format(
                          steps, self.tr_reader.tr_epochs, readtime, convtime,
                          processtime, totaltime, mean_loss, mean_elloss,
                          mean_mtypeloss))
                readtime, convtime, processtime = 0, 0, 0
                # Start a fresh window.  Previously the accumulators kept the
                # rounded mean, which leaked into every later reported mean.
                avg_loss, avg_elloss, avg_mtypeloss = 0.0, 0.0, 0.0

            if steps % 1000 == 0:
                print("Running Validation")
                print("Saving model: {}".format(ckptpath))
                # The latest model is always the one saved; no
                # validation-score-based selection happens here.
                bestmodel = copy.deepcopy(self.model)
                beststep = steps
                utils.save_checkpoint(m=bestmodel,
                                      o=self.optimizer,
                                      steps=steps,
                                      beststeps=beststep,
                                      path=ckptpath)
                self.validation()
                # Switch back to training mode after validation.
                self.model.train()

        return (bestmodel, bestval, beststep, steps)
Ejemplo n.º 6
0
def main(args):
    """Train a segmentation network end to end from parsed arguments.

    Steps, in order: create the log and save directories, build the train and
    validation datasets, build the model and its optimizer, report FLOPs and
    parameter count, wrap model/criterion for the available GPUs, build the
    data loaders and LR scheduler, dump the arguments to
    ``<savedir>/arguments.json``, optionally run ``args.fp_epochs`` warm-up
    epochs, optionally prepare the ``quantized`` sub-module for
    quantization-aware training, optionally load weights from ``args.resume``,
    then run the train/validate/checkpoint/log loop.

    Args:
        args: Namespace-like object.  Attributes read here: ``savedir``,
            ``dataset`` ('pascal' or 'city'), ``data_path``, ``coco_path``
            (pascal only), ``coarse`` (city only), ``model``, ``lr``,
            ``weight_decay``, ``optimizer``, ``loss_type``, ``ignore_idx``,
            ``batch_size``, ``workers``, ``s``, ``fp_epochs``, ``fp_train``,
            ``resume``, ``start_epoch`` and ``epochs``.  Attributes written
            here: ``scale``, ``classes``, ``learning_rate``, and (through the
            ``vars(args)`` dictionary) ``model_params`` and ``flops``.

    Exits with status -1 when the dataset or model name is not supported.
    """
    import importlib

    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    my_logger = Logger(60066, logdir)

    # Fixed crop size per dataset.  An unknown dataset is rejected right here;
    # previously it fell through and crashed with a NameError on `crop_size`
    # before the "not yet supported" message could be shown.
    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    # ---- datasets and per-class loss weights ----
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:
        # Only 'city' can reach this branch; other names exited above.
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Hard-coded weights for class indices 0..19; index 19 gets weight 0.
        city_class_wts = (
            2.8149201869965, 6.9850029945374, 3.7890393733978,
            9.9428062438965, 9.7702074050903, 9.5110931396484,
            10.311357498169, 10.026463508606, 4.6323022842407,
            9.5608062744141, 7.8698215484619, 9.5168733596802,
            10.373730659485, 6.6616044044495, 10.260489463806,
            10.287888526917, 10.289801597595, 10.405355453491,
            10.138095855713, 0.0,
        )
        class_wts = torch.ones(seg_classes)
        for class_idx, class_wt in enumerate(city_class_wts):
            class_wts[class_idx] = class_wt

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ---- model ----
    # Architecture name -> (module path, factory function).  The module is
    # imported only for the architecture actually requested.
    model_factories = {
        'espnetv2': ('model.espnetv2', 'espnetv2_seg'),
        'espnet': ('model.espnet', 'espnet_seg'),
        'mobilenetv2_1_0': ('model.mobilenetv2', 'get_mobilenet_v2_1_0_seg'),
        'mobilenetv2_0_35': ('model.mobilenetv2', 'get_mobilenet_v2_0_35_seg'),
        'mobilenetv2_0_5': ('model.mobilenetv2', 'get_mobilenet_v2_0_5_seg'),
        'mobilenetv3_small': ('model.mobilenetv3',
                              'get_mobilenet_v3_small_seg'),
        'mobilenetv3_large': ('model.mobilenetv3',
                              'get_mobilenet_v3_large_seg'),
        'mobilenetv3_RE_small': ('model.mobilenetv3',
                                 'get_mobilenet_v3_RE_small_seg'),
        'mobilenetv3_RE_large': ('model.mobilenetv3',
                                 'get_mobilenet_v3_RE_large_seg'),
    }
    if args.model not in model_factories:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)
    module_path, factory_name = model_factories[args.model]
    args.classes = seg_classes
    build_model = getattr(importlib.import_module(module_path), factory_name)
    model = build_model(args)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ---- optimizer: one parameter group per tensor ----
    # 4-D tensors whose second dimension is 1 get no weight decay, other 4-D
    # tensors get the full weight decay, everything else gets 1% of it.
    reduced_decay = args.weight_decay * 0.01
    train_params = []
    for _name, value in model.named_parameters():
        shape = value.data.shape
        if len(shape) != 4:
            decay = reduced_decay
        elif shape[1] == 1:
            decay = 0.0
        else:
            decay = args.weight_decay
        train_params.append({
            'params': [value],
            'lr': args.lr,
            'weight_decay': decay
        })

    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1], crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    start_epoch = 0
    best_miou = 0.0

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    # ---- GPU wrapping ----
    if num_gpus >= 1:
        if num_gpus == 1:
            # For a single GPU the criterion needs no parallel wrapper, so
            # only the model goes through torch's own DataParallel.
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # The model is wrapped only on GPU, so `.module` exists only there.
    # Unwrap once and use `net` for QAT preparation, resume and checkpoints
    # (the first two used `model.module` unconditionally and failed on CPU).
    net = model.module if device == 'cuda' else model

    # ---- data loaders ----
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    # Cityscapes is validated one image at a time.
    val_batch_size = 1 if args.dataset == 'city' else args.batch_size
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=val_batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers,
                                             drop_last=True)

    lr_scheduler = get_lr_scheduler(args)

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        # vars(args) is the namespace's own dictionary, so the two keys added
        # below also become attributes of `args`.
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    # ---- optional warm-up epochs (training only, no validation) ----
    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")

        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            train(model,
                  train_loader,
                  optimizer,
                  criterion,
                  seg_classes,
                  epoch,
                  device=device)
    if args.optimizer.startswith('Q'):
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')

    # ---- optional quantization-aware-training preparation ----
    if not args.fp_train:
        net.quantized.fuse_model()
        net.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(net.quantized, inplace=True)

    # ---- optional resume ----
    if args.resume:
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            net.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')

    # ---- main training loop ----
    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # Remember the best validation mIOU and save a checkpoint.
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = net.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            model_file_name = args.savedir + '/model_' + str(epoch +
                                                             1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))

        # Per-epoch scalars, logged against the 1-based epoch number.
        epoch_scalars = (
            ('Segmentation/LR', round(lr, 6)),
            ('Segmentation/Loss/train', train_loss),
            ('Segmentation/Loss/val', val_loss),
            ('Segmentation/mIOU/train', miou_train),
            ('Segmentation/mIOU/val', miou_val),
        )
        for tag, value in epoch_scalars:
            my_logger.scalar_summary(tag, value, epoch + 1)
        # The complexity tags log the best mIOU with FLOPs / parameter count
        # (rounded up) as the step value.
        my_logger.scalar_summary('Segmentation/Complexity/Flops', best_miou,
                                 math.ceil(flops))
        my_logger.scalar_summary('Segmentation/Complexity/Params', best_miou,
                                 math.ceil(num_params))

    print_info_message("========== TRAINING FINISHED ===========")
Ejemplo n.º 7
0
def main(args):
    """Train and validate a semantic-segmentation network.

    Builds the train/val datasets for ``args.dataset`` ('pascal' or 'city'),
    creates the network named by ``args.model`` ('espnetv2' or 'dicenet'),
    optionally loads finetuning weights and/or resumes from a checkpoint,
    then runs the train/validate loop up to ``args.epochs``.  A checkpoint is
    saved after every epoch and scalars are logged to a ``SummaryWriter``
    rooted at ``args.savedir``.

    Args:
        args: argument namespace.  Attributes read here: ``crop_size``
            (must be a tuple), ``batch_size``, ``savedir``, ``dataset``,
            ``data_path``, ``scale``, ``coco_path`` (pascal), ``coarse``
            (city), ``model``, ``finetune``, ``freeze_bn``, ``lr``,
            ``lr_mult``, ``momentum``, ``weight_decay``, ``resume``,
            ``loss_type``, ``ignore_idx``, ``workers``, ``scheduler``,
            ``epochs``, ``s`` and the scheduler-specific ``step_size``,
            ``lr_decay``, ``power``, ``clr_max``, ``cycle_len``.
            ``args.classes`` is set for 'espnetv2', and ``model_params`` /
            ``flops`` are added to the namespace when it is dumped to
            ``arguments.json``.

    Raises:
        AssertionError: if ``args.crop_size`` is not a tuple.

    The process exits with status -1 when the dataset, architecture or
    scheduler name is not supported.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.savedir, exist_ok=True)

    # -------------------------------------------------------------------------
    # Datasets and per-class loss weights
    # -------------------------------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        # Hard-coded weights for class indices 0..19; index 19 gets weight 0
        # (presumably the background/ignore class -- confirm with the dataset).
        cityscapes_wts = (
            2.8149201869965, 6.9850029945374, 3.7890393733978,
            9.9428062438965, 9.7702074050903, 9.5110931396484,
            10.311357498169, 10.026463508606, 4.6323022842407,
            9.5608062744141, 7.8698215484619, 9.5168733596802,
            10.373730659485, 6.6616044044495, 10.260489463806,
            10.287888526917, 10.289801597595, 10.405355453491,
            10.138095855713, 0.0,
        )
        for class_idx, class_wt in enumerate(cityscapes_wts):
            class_wts[class_idx] = class_wt
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # -------------------------------------------------------------------------
    # Model
    # -------------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device('cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # -------------------------------------------------------------------------
    # Optimizer: two parameter groups so the segmentation head can use a
    # learning rate scaled by lr_mult.
    # -------------------------------------------------------------------------
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # Graph export is best-effort.  Catch Exception (not a bare except)
        # so that Ctrl-C / SystemExit are not swallowed here.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # -------------------------------------------------------------------------
    # Optionally resume from a checkpoint
    # -------------------------------------------------------------------------
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # -------------------------------------------------------------------------
    # LR scheduler
    # -------------------------------------------------------------------------
    if args.scheduler in ('fixed', 'clr'):
        # Both step-based schedulers use every multiple of step_size that is
        # below args.epochs as a milestone.
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]

    if args.scheduler == 'fixed':
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        # Non-zero status, consistent with the other unsupported-option exits.
        exit(-1)

    print_info_message(lr_scheduler)

    # Record the run configuration next to the checkpoints.
    import json
    with open(os.path.join(args.savedir, 'arguments.json'), 'w') as outfile:
        # vars() returns the namespace's own dict, so the two extra keys are
        # also added to ``args`` itself.
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # -------------------------------------------------------------------------
    # Training and validation loop
    # -------------------------------------------------------------------------
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    try:
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            # On CUDA the model is always wrapped (see above), so unwrap it
            # to store weights without the wrapper's 'module.' indirection.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # Complexity plots use FLOPs / parameter count as the x-axis
            # ("step") and the best mIOU so far as the value.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()
Ejemplo n.º 8
0
    def run(self, *args, **kwargs):
        """Run optional warm-up, then train and validate every remaining epoch.

        After each epoch the model and optimizer states are checkpointed and
        LR / loss / accuracy scalars are sent to ``self.logger``.  When all
        epochs are done, the per-epoch validation accuracies are written via
        a json ``DictWriter``, ordered from highest to lowest accuracy; if a
        file already exists at the stats path its content is merged in first.

        Args:
            *args: passed through (as the ``args`` keyword) to ``warm_up``,
                ``training`` and ``validation``.
            **kwargs: passed through likewise (as the ``kwargs`` keyword);
                ``need_attn`` is always forced to ``False``.
        """
        kwargs['need_attn'] = False

        if self.opts.warm_up:
            self.warm_up(args=args, kwargs=kwargs)

        # When resuming, replay the scheduler over the epochs that were
        # already completed; the returned LR values are discarded.
        if self.resume is not None:
            for past_epoch in range(self.start_epoch):
                self.lr_scheduler.step(past_epoch)

        val_acc_per_epoch = {}

        for epoch in range(self.start_epoch, self.opts.epochs):
            current_lr = self.lr_scheduler.step(epoch)
            self.optimizer = update_optimizer(optimizer=self.optimizer,
                                              lr_value=current_lr)

            # To double-check that the optimizer picked up the new LR, enable:
            # assert read_lr_from_optimzier(self.optimizer) == current_lr

            train_acc, train_loss = self.training(epoch=epoch,
                                                  lr=current_lr,
                                                  args=args,
                                                  kwargs=kwargs)
            val_acc, val_loss = self.validation(epoch=epoch,
                                                lr=current_lr,
                                                args=args,
                                                kwargs=kwargs)
            val_acc_per_epoch[epoch] = val_acc
            gc.collect()

            # A tie with the best accuracy so far also counts as "best".
            improved = val_acc >= self.best_acc
            self.best_acc = max(val_acc, self.best_acc)

            # Save the underlying module's weights when wrapped in DataParallel.
            if isinstance(self.mi_model, torch.nn.DataParallel):
                net_state = self.mi_model.module.state_dict()
            else:
                net_state = self.mi_model.state_dict()

            save_checkpoint(epoch=epoch,
                            model_state=net_state,
                            optimizer_state=self.optimizer.state_dict(),
                            best_perf=self.best_acc,
                            save_dir=self.opts.savedir,
                            is_best=improved,
                            keep_best_k_models=self.opts.keep_best_k_models)

            epoch_scalars = (
                ('LR', round(current_lr, 6)),
                ('TrainingLoss', train_loss),
                ('TrainingAcc', train_acc),
                ('ValidationLoss', val_loss),
                ('ValidationAcc', val_acc),
            )
            for tag, value in epoch_scalars:
                self.logger.add_scalar(tag, value, epoch)

        # Persist epoch -> validation accuracy (best first) so that it can be
        # used for filtering later on.
        ranked_stats = dict(
            sorted(val_acc_per_epoch.items(),
                   key=lambda pair: pair[1],
                   reverse=True))

        eval_stats_fname = '{}/val_stats_bag_{}_word_{}_{}_{}'.format(
            self.opts.savedir,
            self.opts.bag_size,
            self.opts.word_size,
            self.opts.attn_fn,
            self.opts.attn_type,
        )

        writer = DictWriter(file_name=eval_stats_fname, format='json')
        if os.path.isfile(eval_stats_fname):
            # Merge with the stats already on disk, then re-rank everything.
            with open(eval_stats_fname, 'r') as json_file:
                merged_stats = json.load(json_file)
            merged_stats.update(ranked_stats)
            stats_to_write = dict(
                sorted(merged_stats.items(),
                       key=lambda pair: pair[1],
                       reverse=True))
        else:
            stats_to_write = ranked_stats
        writer.write(data_dict=stats_to_write)

        self.logger.close()
Ejemplo n.º 9
0
def main(args):
    """Train and validate a classification network.

    Creates the model named by ``args.model`` ('dicenet', 'espnetv2' or
    'shufflenetv2'), optionally initialises it from pretrained weights
    and/or resumes from a checkpoint, builds the loss and data loaders for
    ``args.dataset`` ('imagenet', 'coco' or 'Heart'), then trains up to
    ``args.epochs``.  A checkpoint is saved after every epoch and scalars are
    logged to a ``SummaryWriter`` rooted at ``args.savedir``.

    Args:
        args: argument namespace.  Attributes read here: ``model``,
            ``finetune``, ``weights_ft``, ``savedir``, ``inpSize``, ``lr``,
            ``momentum``, ``weight_decay``, ``resume``, ``dataset``,
            ``data``, ``scale``, ``batch_size``, ``workers``, ``scheduler``,
            ``steps``, ``lr_decay``, ``clr_max``, ``epochs``,
            ``start_epoch`` and ``s``.  ``start_epoch`` is overwritten when
            resuming, and ``model_params`` / ``flops`` are added to the
            namespace when it is dumped to ``arguments.json``.

    The process exits with status -1 when the model, dataset or scheduler
    name is not supported.
    """
    # -----------------------------------------------------------------------------
    # Create model
    # -----------------------------------------------------------------------------
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        model = net.CNNModel(args)
    elif args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        model = net.EESPNet(args)
    elif args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        model = net.CNNModel(args)
    else:
        print_error_message('Model {} not yet implemented'.format(args.model))
        exit(-1)

    if args.finetune:
        # load the weights for finetuning
        if os.path.isfile(args.weights_ft):
            pretrained_dict = torch.load(args.weights_ft,
                                         map_location=torch.device('cpu'))
            print_info_message('Loading pretrained basenet model weights')
            model_dict = model.state_dict()

            # Take the VALUES from the pretrained file (taking them from
            # model_dict would just re-load the model's current weights).
            # Entries whose shape differs from the model's are skipped,
            # because load_state_dict would reject them.
            overlap_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }

            total_size_overlap = sum(
                torch.numel(v) for v in overlap_dict.values())
            total_size_pretrain = sum(
                torch.numel(v) for v in pretrained_dict.values())

            if not overlap_dict:
                print_error_message(
                    'No overlaping weights between model file and pretrained weight file. Please check'
                )

            # Guard against an empty pretrained file (division by zero).
            overlap_ratio = ((total_size_overlap * 100.0) /
                             total_size_pretrain
                             if total_size_pretrain else 0.0)
            print_info_message(
                'Overlap ratio of weights: {:.2f} %'.format(overlap_ratio))

            model_dict.update(overlap_dict)
            model.load_state_dict(model_dict, strict=False)
            print_info_message('Pretrained basenet model loaded!!')
        else:
            print_error_message('Unable to find the weights: {}'.format(
                args.weights_ft))

    # -----------------------------------------------------------------------------
    # Writer for logging
    # -----------------------------------------------------------------------------
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.savedir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        # NOTE(review): the dummy input hard-codes 70 channels; presumably it
        # should follow the model's configured input channels -- confirm.
        writer.add_graph(model,
                         input_to_model=torch.randn(1, 70, args.inpSize,
                                                    args.inpSize))
    except Exception:
        # Graph export is best-effort.  Catch Exception (not a bare except)
        # so that Ctrl-C / SystemExit are not swallowed here.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # network properties
    num_params = model_parameters(model)
    flops = compute_flops(model)
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    best_acc = 0.0
    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            # map_location is an argument of torch.load; passing it to
            # Module.load_state_dict raises TypeError.
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device(device))
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    # -----------------------------------------------------------------------------
    # Loss Fn
    # -----------------------------------------------------------------------------
    if args.dataset == 'imagenet':
        criterion = nn.CrossEntropyLoss()
        acc_metric = 'Top-1'
    elif args.dataset == 'coco':
        criterion = nn.BCEWithLogitsLoss()
        acc_metric = 'F1'
    elif args.dataset == 'Heart':
        criterion = nn.L1Loss()
        acc_metric = 'Test'
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))
        # Without a criterion nothing below can run; stop here explicitly.
        exit(-1)

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # -----------------------------------------------------------------------------
    # Data Loaders
    # -----------------------------------------------------------------------------
    # Data loading code
    if args.dataset == 'imagenet':
        train_loader, val_loader = img_loader.data_loaders(args)
        # import the loaders too
        from utilities.train_eval_classification import train, validate
    elif args.dataset == 'coco':
        from data_loader.classification.coco import COCOClassification
        train_dataset = COCOClassification(root=args.data,
                                           split='train',
                                           year='2017',
                                           inp_size=args.inpSize,
                                           scale=args.scale,
                                           is_training=True)
        val_dataset = COCOClassification(root=args.data,
                                         split='val',
                                         year='2017',
                                         inp_size=args.inpSize,
                                         is_training=False)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # import the loaders too
        from utilities.train_eval_classification import train_multi as train
        from utilities.train_eval_classification import validate_multi as validate
    elif args.dataset == 'Heart':
        from utilities.train_eval_classification import train, validate

        def load_npy(npy_path):
            """Load a .npy file, unwrapping a single-element array if possible."""
            data = np.load(npy_path)
            try:
                # A single-element array is unwrapped to the object it holds.
                return data.item()
            except (ValueError, AttributeError):
                # Not a single-element array: return what np.load produced.
                return data

        def loadData(data_path):
            """Return the 'signals' and 'gts' entries stored at data_path."""
            npy_data = load_npy(data_path)
            signals = npy_data['signals']
            gts = npy_data['gts']
            return signals, gts

        ht_batch_size = args.batch_size
        signals_train, gts_train = loadData(
            '../DiCENeT/CardioNet/data_train/fps7_sample10_2D_train.npy')
        signals_val, gts_val = loadData(
            '../DiCENeT/CardioNet/data_train/fps7_sample10_2D_val.npy')
        from data_loader.classification.heart import HeartDataGenerator
        heart_train_data = HeartDataGenerator(signals_train, gts_train,
                                              ht_batch_size)
        # heart_train_data.squeeze
        heart_val_data = HeartDataGenerator(signals_val, gts_val,
                                            ht_batch_size)
        # heart_val_data.squeeze
        train_loader = torch.utils.data.DataLoader(heart_train_data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(heart_val_data,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))
        exit(-1)

    # -----------------------------------------------------------------------------
    # LR schedulers
    # -----------------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_sizes = args.steps
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        from utilities.lr_scheduler import CyclicLR
        step_sizes = args.steps
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max)
    else:
        print_error_message('Scheduler ({}) not yet implemented'.format(
            args.scheduler))
        exit(-1)

    print_info_message(lr_scheduler)

    # set up the epoch variable in case resuming training: replay the
    # scheduler over the epochs that were already completed
    if args.start_epoch != 0:
        for epoch in range(args.start_epoch):
            lr_scheduler.step(epoch)

    # Record the run configuration next to the checkpoints.
    import json
    with open(os.path.join(args.savedir, 'arguments.json'), 'w') as outfile:
        # vars() returns the namespace's own dict, so the two extra keys are
        # also added to ``args`` itself.
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # -----------------------------------------------------------------------------
    # Training and Val Loop
    # -----------------------------------------------------------------------------

    extra_info_ckpt = args.model + '_' + str(args.s)
    try:
        for epoch in range(args.start_epoch, args.epochs):
            lr_log = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_log
            print_info_message("LR for epoch {} = {:.5f}".format(
                epoch, lr_log))
            train_acc, train_loss = train(data_loader=train_loader,
                                          model=model,
                                          criteria=criterion,
                                          optimizer=optimizer,
                                          epoch=epoch,
                                          device=device)
            # evaluate on validation set
            val_acc, val_loss = validate(data_loader=val_loader,
                                         model=model,
                                         criteria=criterion,
                                         device=device)

            # remember best prec@1 and save checkpoint
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)

            # On CUDA the model is wrapped in DataParallel (see above), so
            # unwrap it before taking the state dict.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': weights_dict,
                    'best_prec1': best_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Classification/LR/learning_rate', lr_log,
                              epoch)
            writer.add_scalar('Classification/Loss/Train', train_loss, epoch)
            writer.add_scalar('Classification/Loss/Val', val_loss, epoch)
            writer.add_scalar('Classification/{}/Train'.format(acc_metric),
                              train_acc, epoch)
            writer.add_scalar('Classification/{}/Val'.format(acc_metric),
                              val_acc, epoch)
            # Complexity plots use FLOPs / parameter count as the x-axis
            # ("step") and the best accuracy so far as the value.
            writer.add_scalar('Classification/Complexity/Top1_vs_flops',
                              best_acc, round(flops, 2))
            writer.add_scalar('Classification/Complexity/Top1_vs_params',
                              best_acc, round(num_params, 2))
    finally:
        # Close the event file even when training is interrupted or fails.
        writer.close()
Ejemplo n.º 10
0
def main(args):
    """Train and validate the ``espdnet_mult`` model on the greenhouse dataset.

    The function builds the train/val datasets, the model, an SGD optimizer
    with one parameter group per sub-network, the segmentation and
    classification criteria and an epoch-level learning-rate scheduler, then
    runs the train/validate loop while writing TensorBoard scalars and a
    checkpoint after every epoch.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``crop_size`` (tuple), ``batch_size``, ``savedir``, ``dataset``,
            ``data_path``, ``scale``, ``use_depth``, ``model``, ``finetune``,
            ``freeze_bn``, ``lr``, ``lr_mult``, ``momentum``,
            ``weight_decay``, ``resume``, ``workers``, ``loss_type``,
            ``ignore_idx``, ``scheduler``, ``epochs``, ``s`` and, depending
            on the scheduler, ``step_size``, ``lr_decay``, ``power``,
            ``clr_max`` and ``cycle_len``. ``args.classes`` and
            ``args.cls_classes`` are written before the model is built.

    Side effects:
        Creates ``args.savedir`` if needed, writes ``arguments.json`` and
        TensorBoard logs into it, and calls ``save_checkpoint`` every epoch.
        Terminates the process with status -1 for an unsupported dataset,
        model or scheduler.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    os.makedirs(args.savedir, exist_ok=True)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Class count handed to the model as ``cls_classes`` and to
    # ``calc_cls_class_weight``; kept in one place so both stay in sync.
    num_cls_classes = 5

    # ------------------------------------------------------------------
    # Datasets
    # ------------------------------------------------------------------
    if args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegCls, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseRGBDSegCls(
            root=args.data_path,
            list_name='train_greenhouse_mult.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                           list_name='val_greenhouse_mult.txt',
                                           train=False,
                                           size=crop_size,
                                           scale=args.scale,
                                           use_depth=args.use_depth)
        # Only the first four entries of the stored weights are used.
        # NOTE(review): the file is resolved relative to the current working
        # directory — confirm this is intended.
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
        color_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                      ('other_part_of_plant', (0, 255, 255)),
                                      ('artificial_objects', (255, 0, 0)),
                                      ('ground', (255, 255, 0)),
                                      ('background', (0, 0, 0))])
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if args.model == 'espdnet':
        from model.segmentation.espdnet_mult import espdnet_mult
        args.classes = seg_classes
        args.cls_classes = num_cls_classes
        model = espdnet_mult(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        # NOTE(review): a later ``model.train()`` call would switch these
        # layers back to training mode — confirm ``train`` accounts for it.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # ------------------------------------------------------------------
    # Optimizer: group 0 = base net, group 1 = segmentation part,
    # group 2 (only with depth) = depth encoder. The epoch loop below
    # relies on this ordering when it updates the learning rates.
    # ------------------------------------------------------------------
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    if args.use_depth:
        train_params.append({
            'params': model.get_depth_encoder_params(),
            'lr': args.lr
        })

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # Graph export is best-effort; ``Exception`` (rather than a bare
        # ``except``) lets Ctrl-C and SystemExit propagate.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # ------------------------------------------------------------------
    # Optional resume. This happens before the DataParallel wrapping, which
    # matches the checkpoints written below (they store the unwrapped
    # module's weights when running on CUDA).
    # ------------------------------------------------------------------
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    # ------------------------------------------------------------------
    # Data loaders and criteria
    # ------------------------------------------------------------------
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    cls_class_weight = calc_cls_class_weight(train_loader, num_cls_classes)
    print(cls_class_weight)

    criterion_seg = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))

    criterion_cls = nn.CrossEntropyLoss(
        weight=torch.from_numpy(cls_class_weight).float().to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion_seg = criterion_seg.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion_seg = DataParallelCriteria(criterion_seg)
            criterion_seg = criterion_seg.cuda()

        criterion_cls = criterion_cls.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ------------------------------------------------------------------
    # Learning-rate scheduler
    # ------------------------------------------------------------------
    if args.scheduler in ('fixed', 'clr'):
        # Milestones: every multiple of step_size strictly below args.epochs.
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        if args.scheduler == 'fixed':
            from utilities.lr_scheduler import FixedMultiStepLR
            lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                            steps=step_sizes,
                                            gamma=args.lr_decay)
        else:
            from utilities.lr_scheduler import CyclicLR
            lr_scheduler = CyclicLR(min_lr=args.lr,
                                    cycle_len=5,
                                    steps=step_sizes,
                                    gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        # Non-zero status, consistent with the dataset/model error paths.
        exit(-1)

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        # vars(args) is the namespace's own dict, so the two keys added
        # below also become attributes of ``args``.
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # ------------------------------------------------------------------
    # Epoch loop
    # ------------------------------------------------------------------
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    try:
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # Push the scheduled rates into the optimizer's parameter groups
            # (see the group ordering where ``train_params`` is built).
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg
            if args.use_depth:
                optimizer.param_groups[2]['lr'] = lr_base

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))
            miou_train, train_loss, train_seg_loss, train_cls_loss = train(
                model,
                train_loader,
                optimizer,
                criterion_seg,
                seg_classes,
                epoch,
                criterion_cls,
                device=device,
                use_depth=args.use_depth)
            miou_val, val_loss, val_seg_loss, val_cls_loss = val(
                model,
                val_loader,
                criterion_seg,
                criterion_cls,
                seg_classes,
                device=device,
                use_depth=args.use_depth)

            # Visualise predictions on the first validation batch.
            # ``next(it)`` is the Python 3 protocol; ``it.next()`` is not
            # available on current DataLoader iterators.
            batch = next(iter(val_loader))
            vis_kwargs = dict(images=batch[0].to(device=device),
                              labels=batch[1].to(device=device),
                              class_encoding=color_encoding,
                              writer=writer,
                              epoch=epoch,
                              data='Segmentation',
                              device=device)
            if args.use_depth:
                vis_kwargs['depths'] = batch[2].to(device=device)
            in_training_visualization_2(model, **vis_kwargs)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            # On CUDA the model is wrapped (see above), so store the inner
            # module's weights to keep checkpoints loadable before wrapping.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/SegLoss/train', train_seg_loss,
                              epoch)
            writer.add_scalar('Segmentation/ClsLoss/train', train_cls_loss,
                              epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/SegLoss/val', val_seg_loss, epoch)
            writer.add_scalar('Segmentation/ClsLoss/val', val_cls_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # Best mIOU plotted against model complexity (x-axis is the
            # FLOPs / parameter count rather than the epoch).
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))
    finally:
        # Close the event file even if an epoch raises.
        writer.close()
Ejemplo n.º 11
0
def main(args):
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST, color_encoding
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               coarse=False)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 10 / 2.8149201869965
        class_wts[1] = 10 / 6.9850029945374
        class_wts[2] = 10 / 3.7890393733978
        class_wts[3] = 10 / 9.9428062438965
        class_wts[4] = 10 / 9.7702074050903
        class_wts[5] = 10 / 9.5110931396484
        class_wts[6] = 10 / 10.311357498169
        class_wts[7] = 10 / 10.026463508606
        class_wts[8] = 10 / 4.6323022842407
        class_wts[9] = 10 / 9.5608062744141
        class_wts[10] = 10 / 7.8698215484619
        class_wts[11] = 10 / 9.5168733596802
        class_wts[12] = 10 / 10.373730659485
        class_wts[13] = 10 / 6.6616044044495
        class_wts[14] = 10 / 10.260489463806
        class_wts[15] = 10 / 10.287888526917
        class_wts[16] = 10 / 10.289801597595
        class_wts[17] = 10 / 10.405355453491
        class_wts[18] = 10 / 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GREENHOUSE_CLASS_LIST, color_encoding
        train_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.train_list,
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        val_dataset = GreenhouseRGBDSegmentation(
            root=args.data_path,
            list_name=args.val_list,
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth,
            use_traversable=args.greenhouse_use_trav)
        class_weights = np.load('class_weights.npy')  # [:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        print(GREENHOUSE_CLASS_LIST)
        seg_classes = len(GREENHOUSE_CLASS_LIST)
#        color_encoding = OrderedDict([
#            ('end_of_plant', (0, 255, 0)),
#            ('other_part_of_plant', (0, 255, 255)),
#            ('artificial_objects', (255, 0, 0)),
#            ('ground', (255, 255, 0)),
#            ('background', (0, 0, 0))
#        ])
    elif args.dataset == 'ishihara':
        print(args.use_depth)
        from data_loader.segmentation.ishihara_rgbd import IshiharaRGBDSegmentation, ISHIHARA_RGBD_CLASS_LIST
        train_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_train.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = IshiharaRGBDSegmentation(
            root=args.data_path,
            list_name='ishihara_rgbd_val.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)

        seg_classes = len(ISHIHARA_RGBD_CLASS_LIST)

        class_wts = torch.ones(seg_classes)

        color_encoding = OrderedDict([('Unlabeled', (0, 0, 0)),
                                      ('Building', (70, 70, 70)),
                                      ('Fence', (190, 153, 153)),
                                      ('Others', (72, 0, 90)),
                                      ('Pedestrian', (220, 20, 60)),
                                      ('Pole', (153, 153, 153)),
                                      ('Road ', (157, 234, 50)),
                                      ('Road', (128, 64, 128)),
                                      ('Sidewalk', (244, 35, 232)),
                                      ('Vegetation', (107, 142, 35)),
                                      ('Car', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Traffic ', (220, 220, 0))])
    elif args.dataset == 'sun':
        print(args.use_depth)
        from data_loader.segmentation.sun_rgbd import SUNRGBDSegmentation, SUN_RGBD_CLASS_LIST
        train_dataset = SUNRGBDSegmentation(root=args.data_path,
                                            list_name='sun_rgbd_train.txt',
                                            train=True,
                                            size=crop_size,
                                            ignore_idx=args.ignore_idx,
                                            scale=args.scale,
                                            use_depth=args.use_depth)
        val_dataset = SUNRGBDSegmentation(root=args.data_path,
                                          list_name='sun_rgbd_val.txt',
                                          train=False,
                                          size=crop_size,
                                          ignore_idx=args.ignore_idx,
                                          scale=args.scale,
                                          use_depth=args.use_depth)

        seg_classes = len(SUN_RGBD_CLASS_LIST)

        class_wts = torch.ones(seg_classes)

        color_encoding = OrderedDict([('Background', (0, 0, 0)),
                                      ('Bed', (0, 255, 0)),
                                      ('Books', (70, 70, 70)),
                                      ('Ceiling', (190, 153, 153)),
                                      ('Chair', (72, 0, 90)),
                                      ('Floor', (220, 20, 60)),
                                      ('Furniture', (153, 153, 153)),
                                      ('Objects', (157, 234, 50)),
                                      ('Picture', (128, 64, 128)),
                                      ('Sofa', (244, 35, 232)),
                                      ('Table', (107, 142, 35)),
                                      ('TV', (0, 0, 255)),
                                      ('Wall', (102, 102, 156)),
                                      ('Window', (220, 220, 0))])
    elif args.dataset == 'camvid':
        print(args.use_depth)
        from data_loader.segmentation.camvid import CamVidSegmentation, CAMVID_CLASS_LIST, color_encoding
        train_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='train_camvid.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)
        val_dataset = CamVidSegmentation(
            root=args.data_path,
            list_name='val_camvid.txt',
            train=False,
            size=crop_size,
            scale=args.scale,
            label_conversion=args.label_conversion,
            normalize=args.normalize)

        if args.label_conversion:
            from data_loader.segmentation.greenhouse import GREENHOUSE_CLASS_LIST, color_encoding
            seg_classes = len(GREENHOUSE_CLASS_LIST)
            class_wts = torch.ones(seg_classes)
        else:
            seg_classes = len(CAMVID_CLASS_LIST)
            tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                     batch_size=1,
                                                     shuffle=False)

            class_wts = calc_cls_class_weight(tmp_loader,
                                              seg_classes,
                                              inverted=True)
            class_wts = torch.from_numpy(class_wts).float().to(device)
            #            class_wts = torch.ones(seg_classes)
            print("class weights : {}".format(class_wts))

        args.use_depth = False
    elif args.dataset == 'forest':
        from data_loader.segmentation.freiburg_forest import FreiburgForestDataset, FOREST_CLASS_LIST, color_encoding
        train_dataset = FreiburgForestDataset(train=True,
                                              size=crop_size,
                                              scale=args.scale,
                                              normalize=args.normalize)
        val_dataset = FreiburgForestDataset(train=False,
                                            size=crop_size,
                                            scale=args.scale,
                                            normalize=args.normalize)

        seg_classes = len(FOREST_CLASS_LIST)
        tmp_loader = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=1,
                                                 shuffle=False)

        class_wts = calc_cls_class_weight(tmp_loader,
                                          seg_classes,
                                          inverted=True)
        class_wts = torch.from_numpy(class_wts).float().to(device)
        #        class_wts = torch.ones(seg_classes)
        print("class weights : {}".format(class_wts))

        args.use_depth = False
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnetue_seg2(args, fix_pyr_plane_proj=True)
    elif args.model == 'deeplabv3':
        # from model.segmentation.deeplabv3 import DeepLabV3
        from torchvision.models.segmentation.segmentation import deeplabv3_resnet101

        args.classes = seg_classes
        # model = DeepLabV3(seg_classes)
        model = deeplabv3_resnet101(num_classes=seg_classes, aux_loss=True)
        torch.backends.cudnn.enabled = False
    elif args.model == 'unet':
        from model.segmentation.unet import UNet
        model = UNet(in_channels=3, out_channels=seg_classes)
#        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
#            in_channels=3, out_channels=seg_classes, init_features=32, pretrained=False)

    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    if args.model == 'deeplabv3' or args.model == 'unet':
        train_params = [{'params': model.parameters(), 'lr': args.lr}]

    elif args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr * args.lr_mult
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model, input_to_model=torch.Tensor(1, 3, 288, 480))
    except:
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        if len(optimizer.param_groups) > 1:
            optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        if args.model == 'espdnetue' or (
            (args.model == 'deeplabv3' or args.model == 'unet')
                and args.use_aux):
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        iou_train, train_loss = train(model,
                                      train_loader,
                                      optimizer,
                                      criterion,
                                      seg_classes,
                                      epoch,
                                      device=device,
                                      use_depth=args.use_depth,
                                      add_criterion=nid_loss)
        iou_val, val_loss = val(model,
                                val_loader,
                                criterion,
                                seg_classes,
                                device=device,
                                use_depth=args.use_depth,
                                add_criterion=nid_loss)

        batch_train = iter(train_loader).next()
        batch = iter(val_loader).next()
        if args.use_depth:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                depths=batch_train[2].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          depths=batch[2].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            image_grid = torchvision.utils.make_grid(
                batch[2].to(device=device).data.cpu()).numpy()
            print(type(image_grid))
            writer.add_image('Segmentation/depths', image_grid, epoch)
        else:
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)


#            image_grid = torchvision.utils.make_grid(outputs.data.cpu()).numpy()
#            writer.add_image('Segmentation/results/val', image_grid, epoch)

# remember best miou and save checkpoint
        miou_val = iou_val[[1, 2, 3]].mean()
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train',
                          iou_train[[1, 2, 3]].mean(), epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        writer.add_scalar('Segmentation/plant_IOU/val', iou_val[1], epoch)
        writer.add_scalar('Segmentation/ao_IOU/val', iou_val[2], epoch)
        writer.add_scalar('Segmentation/ground_IOU/val', iou_val[3], epoch)
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
Ejemplo n.º 12
0
def main():
    """Generate pseudo-labels with a pretrained model, then fine-tune on them.

    All configuration is read from the module-level ``args``.  The run

    1. creates a timestamped output directory (with a ``pred`` sub-directory)
       and a summary writer pointing at it,
    2. loads ``espdnetue_seg2`` from ``args.restore_from`` and calls
       ``generate_label`` on the data listed in ``args.data_trav_list``,
    3. builds train/test loaders (the training set is read from the
       ``tgt_train.lst`` path handed to ``generate_label``),
    4. trains for ``args.epoch`` epochs, stepping a cyclic LR schedule once
       per epoch and saving a checkpoint whenever the test mIoU improves.

    The summary writer is closed when the training loop ends, including when
    it ends with an exception.
    """
    device = 'cuda'

    # Timestamp for the run directory, shifted by +9 hours.
    # NOTE(review): presumably converts a UTC host clock to JST — confirm.
    now = datetime.datetime.now() + datetime.timedelta(hours=9)
    timestr = now.strftime("%Y%m%d-%H%M%S")

    save_path = '{}/model_{}_{}/{}'.format(args.save, args.model, args.dataset,
                                           timestr)

    print(save_path)

    tgt_train_lst = osp.join(save_path, 'tgt_train.lst')
    save_pred_path = osp.join(save_path, 'pred')
    # 'pred' lives inside save_path, so creating it creates both directories.
    os.makedirs(save_pred_path, exist_ok=True)
    writer = SummaryWriter(save_path)

    # Dataset
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentationTrav, GREENHOUSE_CLASS_LIST
    args.classes = len(GREENHOUSE_CLASS_LIST)
    travset = GreenhouseRGBDSegmentationTrav(list_name=args.data_trav_list,
                                             use_depth=args.use_depth)

    # Class name -> RGB colour, passed through to train()/test().
    class_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                  ('other_part_of_plant', (0, 255, 255)),
                                  ('artificial_objects', (255, 0, 0)),
                                  ('ground', (255, 255, 0)),
                                  ('background', (0, 0, 0))])

    # Dataloader for generating the pseudo-labels (one sample at a time,
    # in list order).
    travloader = torch.utils.data.DataLoader(travset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Model: the model factory reads the checkpoint path from args.weights.
    from model.segmentation.espdnet_ue import espdnetue_seg2
    args.weights = args.restore_from
    model = espdnetue_seg2(args,
                           load_entire_weights=True,
                           fix_pyr_plane_proj=True)
    model.to(device)

    generate_label(model, travloader, save_pred_path, tgt_train_lst)

    # Dataset for training
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation
    trainset = GreenhouseRGBDSegmentation(list_name=tgt_train_lst,
                                          use_depth=args.use_depth,
                                          use_traversable=True)
    testset = GreenhouseRGBDSegmentation(list_name=args.data_test_list,
                                         use_depth=args.use_depth,
                                         use_traversable=True)

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=args.pin_memory)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=0,
                                             pin_memory=args.pin_memory)

    # Loss
    # One weight per class; the last class gets weight 0.
    # NOTE(review): order presumably follows GREENHOUSE_CLASS_LIST — confirm.
    class_weights = torch.tensor([1.0, 0.2, 1.0, 1.0, 0.0]).to(device)

    if args.use_uncertainty:
        criterion = UncertaintyWeightedSegmentationLoss(
            args.classes, class_weights=class_weights)
    else:
        criterion = SegmentationLoss(n_classes=args.classes,
                                     device=device,
                                     class_weights=class_weights)

    # Evaluation always uses the plain segmentation loss.
    criterion_test = SegmentationLoss(n_classes=args.classes,
                                      device=device,
                                      class_weights=class_weights)

    # Optimizer: base network at 0.1x the learning rate, segmentation head
    # (and depth encoder, when depth is used) at the full learning rate.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.learning_rate * 0.1
    }, {
        'params': model.get_segment_params(),
        'lr': args.learning_rate
    }]
    if args.use_depth:
        train_params.append({
            'params': model.get_depth_encoder_params(),
            'lr': args.learning_rate
        })

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(train_params,
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(train_params,
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)

    # NOTE(review): a scalar base_lr/max_lr is applied to every parameter
    # group, which replaces the 0.1x base-network rate configured above.
    # Pass per-group lists if the reduced rate is meant to survive — confirm.
    scheduler = optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=args.learning_rate,
        max_lr=args.learning_rate * 10,
        step_size_up=10,
        step_size_down=20,
        # Momentum cycling is only enabled for SGD.
        cycle_momentum=(args.optimizer == 'SGD'))

    extra_info_ckpt = '{}'.format(args.model)
    best_miou = 0.0
    try:
        for i in range(0, args.epoch):

            # Run a training epoch
            train(trainloader,
                  model,
                  criterion,
                  device,
                  optimizer,
                  class_encoding,
                  i,
                  writer=writer)

            # Update the learning rate (one scheduler step per epoch)
            scheduler.step()

            new_miou = test(testloader,
                            model,
                            criterion_test,
                            device,
                            optimizer,
                            class_encoding,
                            i,
                            writer=writer)

            # Save the weights if it produces the best IoU
            is_best = new_miou > best_miou
            best_miou = max(new_miou, best_miou)
            model.to(device)
            weights_dict = model.state_dict()
            if is_best:
                save_checkpoint(
                    {
                        'epoch': i + 1,
                        'arch': args.model,
                        'state_dict': weights_dict,
                        'best_miou': best_miou,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, save_path, extra_info_ckpt)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
Ejemplo n.º 13
0
def main():
    """Train a label-probability estimator on top of a segmentation model.

    All configuration is read from the module-level ``args``.  The run

    1. creates a timestamped output directory (with a ``pred`` sub-directory)
       and a summary writer pointing at it,
    2. builds train/test loaders over ``GreenhouseRGBDSegmentationTrav``
       (the test loader yields the whole test set as a single batch),
    3. instantiates ``LabelProbEstimator`` and an ``espdnetue_seg2`` model
       restored from ``args.restore_from``; only the estimator's parameters
       are handed to the optimizer,
    4. for each of ``args.epoch`` epochs: evaluates IoU with the current
       threshold ``c``, trains, steps the LR scheduler, tests, and saves a
       checkpoint whenever the test loss improves.

    The summary writer is closed when the loop ends, including when it ends
    with an exception.
    """
    device = 'cuda'

    # Timestamp for the run directory, shifted by +9 hours.
    # NOTE(review): presumably converts a UTC host clock to JST — confirm.
    now = datetime.datetime.now() + datetime.timedelta(hours=9)
    timestr = now.strftime("%Y%m%d-%H%M%S")

    save_path = '{}/model_{}_{}/{}'.format(args.save, args.model, args.dataset,
                                           timestr)

    print(save_path)

    save_pred_path = osp.join(save_path, 'pred')
    # 'pred' lives inside save_path, so creating it creates both directories.
    os.makedirs(save_pred_path, exist_ok=True)
    writer = SummaryWriter(save_path)

    #
    # Dataset
    #
    from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentationTrav, GREENHOUSE_CLASS_LIST
    args.classes = len(GREENHOUSE_CLASS_LIST)
    trav_train_set = GreenhouseRGBDSegmentationTrav(
        list_name=args.data_train_list, use_depth=args.use_depth)
    trav_test_set = GreenhouseRGBDSegmentationTrav(
        list_name=args.data_test_list, use_depth=args.use_depth)

    #
    # Dataloaders
    #
    trav_train_loader = torch.utils.data.DataLoader(trav_train_set,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=0,
                                                    pin_memory=args.pin_memory)
    # The whole test set is loaded as one batch.
    trav_test_loader = torch.utils.data.DataLoader(
        trav_test_set,
        batch_size=len(trav_test_set),
        shuffle=False,
        num_workers=0,
        pin_memory=args.pin_memory)

    #
    # Models
    #
    # Label probability estimator; input width depends on how the features
    # are constructed (32 channels for 'concat', 16 otherwise).
    from model.classification.label_prob_estimator import LabelProbEstimator
    in_channels = 32 if args.feature_construction == 'concat' else 16
    prob_model = LabelProbEstimator(in_channels=in_channels,
                                    spatial=args.spatial)
    prob_model.to(device)

    # Segmentation model: the factory reads the checkpoint path from
    # args.weights.
    from model.segmentation.espdnet_ue import espdnetue_seg2
    args.weights = args.restore_from
    seg_model = espdnetue_seg2(args,
                               load_entire_weights=True,
                               fix_pyr_plane_proj=True)
    seg_model.to(device)

    criterion = SelectiveBCE()

    # Only the probability estimator is optimised.
    if args.optimizer == 'SGD':
        optimizer = optim.SGD(prob_model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(prob_model.parameters(),
                               lr=args.learning_rate,
                               weight_decay=args.weight_decay)

    if args.lr_scheduling == "cyclic":
        scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=args.learning_rate,
            max_lr=args.learning_rate * 10,
            step_size_up=10,
            step_size_down=20,
            # Momentum cycling is only enabled for SGD.
            cycle_momentum=(args.optimizer == 'SGD'))
    else:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[50, 150],
                                                   gamma=0.5)

    c = 1.0
    # Best test loss so far; infinity guarantees the first epoch is saved
    # regardless of the loss scale.
    loss_old = float('inf')
    try:
        for epoch in range(0, args.epoch):
            # Evaluate with the threshold produced by the previous test pass
            # (1.0 on the first epoch).
            calculate_iou(trav_test_loader,
                          seg_model,
                          prob_model,
                          c,
                          writer,
                          device=device,
                          writer_idx=epoch)
            # Run a training epoch
            train(trav_train_loader, prob_model, seg_model, criterion, device,
                  optimizer, epoch, writer)
            scheduler.step()

            ret_dict = test(trav_test_loader, prob_model, seg_model, criterion,
                            device, epoch, writer)
            c = ret_dict["c"]
            loss = ret_dict["loss"]

            if loss < loss_old:
                print("Save weights")
                extra_info_ckpt = '{}_epoch_{}_c_{}'.format(
                    args.model, epoch, c)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.model,
                        'state_dict': prob_model.state_dict(),
                        'best_miou': 0.0,
                        'optimizer': optimizer.state_dict(),
                    }, True, save_path, extra_info_ckpt)

                loss_old = loss

            print("c = {}".format(c))
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
Ejemplo n.º 14
0
def train(params):
    """Train YOLOv2 on a VOC-style dataset and checkpoint every 10 epochs.

    Args:
        params: Mapping of configuration values.  Keys read here:
            ``dataset``, ``input_height``, ``input_width``, ``data_path``,
            ``val_data_path``, ``val_datalist_path``, ``datalist_path``,
            ``class_path``, ``batch_size``, ``num_epochs``, ``lr``,
            ``checkpoint_path``, ``use_augmentation``, ``use_gtcheck``,
            ``use_visdom``, ``use_githash``, ``num_class``.

    Each epoch runs one pass over the training loader followed by one pass
    over the validation loader, prints the averaged loss components, and,
    when the epoch number is a multiple of 10, writes both a full checkpoint
    (via ``save_checkpoint``) and the bare model weights.
    """

    # future work variable (read so a missing key fails early; unused below)
    dataset = params["dataset"]
    input_height = params["input_height"]
    input_width = params["input_width"]

    data_path = params["data_path"]
    val_data_path = params["val_data_path"]
    val_datalist_path = params["val_datalist_path"]
    datalist_path = params["datalist_path"]
    class_path = params["class_path"]
    batch_size = params["batch_size"]
    num_epochs = params["num_epochs"]
    learning_rate = params["lr"]
    checkpoint_path = params["checkpoint_path"]

    USE_AUGMENTATION = params["use_augmentation"]
    USE_GTCHECKER = params["use_gtcheck"]
    USE_VISDOM = params["use_visdom"]

    USE_GITHASH = params["use_githash"]
    num_class = params["num_class"]
    # Device ids handed to DataParallel: a single GPU, id 0.
    num_gpus = [0]
    with open(class_path) as f:
        class_list = f.read().splitlines()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # Short commit hash embedded in checkpoint file names.
    if USE_GITHASH:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        short_sha = repo.git.rev_parse(sha, short=7)
    else:
        short_sha = 'noHash'

    if USE_VISDOM:
        viz = visdom.Visdom(use_incoming_socket=False)
        vis_title = 'YOLOv2'
        vis_legend_Train = ['Train Loss']
        vis_legend_Val = ['Val Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Total Loss', vis_title,
                                    vis_legend_Train)
        val_plot = create_vis_plot(viz, 'Iteration', 'Validation Loss',
                                   vis_title, vis_legend_Val)

    # 2. Data augmentation setting
    if USE_AUGMENTATION:
        # Two of the following augmenters are applied per sample.
        seq = iaa.SomeOf(
            2,
            [
                iaa.Multiply(
                    (1.2, 1.5)),  # change brightness, doesn't affect BBs
                iaa.Affine(
                    translate_px={
                        "x": 3,
                        "y": 10
                    }, scale=(0.9, 0.9)
                ),  # translate by 3/10px on x/y axis, and scale to 90%, affects BBs
                iaa.AdditiveGaussianNoise(scale=0.1 * 255),
                iaa.CoarseDropout(0.02, size_percent=0.15, per_channel=0.5),
                iaa.Affine(rotate=45),
                iaa.Sharpen(alpha=0.5)
            ])
    else:
        seq = iaa.Sequential([])

    composed = transforms.Compose([Augmenter(seq)])

    # 3. Load Dataset
    # TODO: parse VOC when a datalist is provided
    train_dataset = VOC(root=data_path,
                        transform=composed,
                        class_path=class_path,
                        datalist_path=datalist_path)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=detection_collate)
    # NOTE(review): validation shares the training transform, so it is
    # augmented too when use_augmentation is set — confirm this is intended.
    val_dataset = VOC(root=val_data_path,
                      transform=composed,
                      class_path=class_path,
                      datalist_path=val_datalist_path)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             collate_fn=detection_collate)

    # 5. Load YOLOv2
    # The model is always wrapped in DataParallel (the weight dump below
    # relies on model.module); it is only moved to the GPU when one exists.
    net = yolov2.YOLOv2()

    print("device : ", device)
    if device.type == 'cpu':
        model = torch.nn.DataParallel(net)
    else:
        model = torch.nn.DataParallel(net, device_ids=num_gpus).cuda()

    # 7.Train the model
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    total_step = len(train_loader)
    total_train_step = num_epochs * total_step

    # Epochs are numbered from 1.
    for epoch in range(1, num_epochs + 1):
        train_loss = 0
        total_val_loss = 0

        train_total_conf_loss = 0
        train_total_xy_loss = 0
        train_total_wh_loss = 0
        train_total_c_loss = 0

        val_total_conf_loss = 0
        val_total_xy_loss = 0
        val_total_wh_loss = 0
        val_total_c_loss = 0

        # At epoch 500, cut the learning rate by 10x and rebuild the
        # optimizer and scheduler from scratch (optimizer state is discarded).
        if epoch % 500 == 0 and epoch < 1000:
            learning_rate /= 10
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=learning_rate,
                                         weight_decay=1e-5)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                               gamma=0.9)

        # The exponential decay is only applied at these milestone epochs.
        if epoch in (200, 400, 600, 20000, 30000):
            scheduler.step()

        model.train()
        for i, (images, labels, _sizes) in enumerate(train_loader):

            # Global 1-based step index; epoch is 1-based, hence (epoch - 1).
            current_train_step = (epoch - 1) * total_step + (i + 1)

            if USE_GTCHECKER:
                visualize_GT(images, labels, class_list)

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Calc Loss
            one_loss, conf_loss, xy_loss, wh_loss, class_loss = detection_loss_4_yolo(
                outputs, labels, device.type)

            # Backward and optimize
            optimizer.zero_grad()
            one_loss.backward()
            optimizer.step()

            train_loss += one_loss.item()
            train_total_conf_loss += conf_loss.item()
            train_total_xy_loss += xy_loss.item()
            train_total_wh_loss += wh_loss.item()
            train_total_c_loss += class_loss.item()

        # Per-batch averages for the epoch.
        num_train_batches = len(train_loader)
        train_total_conf_loss = train_total_conf_loss / num_train_batches
        train_total_xy_loss = train_total_xy_loss / num_train_batches
        train_total_wh_loss = train_total_wh_loss / num_train_batches
        train_total_c_loss = train_total_c_loss / num_train_batches
        train_epoch_loss = train_loss / num_train_batches
        if USE_VISDOM:
            # The plot handles only exist when visdom is enabled.
            update_vis_plot(viz, epoch + 1, train_epoch_loss, iter_plot, None,
                            'append')

        model.eval()
        with torch.no_grad():

            for j, (v_images, v_labels, _v_sizes) in enumerate(val_loader):
                v_images = v_images.to(device)
                v_labels = v_labels.to(device)
                # Forward pass
                v_outputs = model(v_images)

                # Calc Loss
                val_loss, conf_loss, xy_loss, wh_loss, class_loss = detection_loss_4_yolo(
                    v_outputs, v_labels, device.type)
                total_val_loss += val_loss.item()
                val_total_conf_loss += conf_loss.item()
                val_total_xy_loss += xy_loss.item()
                val_total_wh_loss += wh_loss.item()
                val_total_c_loss += class_loss.item()

            num_val_batches = len(val_loader)
            val_epoch_loss = total_val_loss / num_val_batches
            val_total_conf_loss = val_total_conf_loss / num_val_batches
            val_total_xy_loss = val_total_xy_loss / num_val_batches
            val_total_wh_loss = val_total_wh_loss / num_val_batches
            val_total_c_loss = val_total_c_loss / num_val_batches
            if USE_VISDOM:
                update_vis_plot(viz, epoch + 1, val_epoch_loss, val_plot, None,
                                'append')

        current_lr = optimizer.param_groups[0]['lr']

        # Progress line: always during the first 300 steps, afterwards only
        # when the epoch ends on a multiple of 100 steps.  one_loss/val_loss
        # are the losses of the last train/val batch of this epoch.
        if current_train_step % 100 == 0 or current_train_step < 300:
            print(
                'epoch: [{}/{}], total step: [{}/{}], batch step [{}/{}], lr: {},one_loss: {:.4f},val_loss: {:.4f}'
                .format(epoch, num_epochs, current_train_step,
                        total_train_step, i + 1, total_step, current_lr,
                        one_loss.item(), val_loss.item()))

        print('train loss', train_epoch_loss, 'val loss', val_epoch_loss)
        print('train conf loss', train_total_conf_loss, 'val conf loss',
              val_total_conf_loss)
        print('train xy loss', train_total_xy_loss, 'val xy loss',
              val_total_xy_loss)
        print('train wh loss', train_total_wh_loss, 'val wh loss',
              val_total_wh_loss)
        print('train class loss', train_total_c_loss, 'val class loss',
              val_total_c_loss)

        if epoch % 10 == 0:
            last_loss = one_loss.item()
            # Full checkpoint (model + optimizer); 'epoch' stores the next
            # epoch number.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': "YOLOv2",
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                False,
                filename=os.path.join(
                    checkpoint_path,
                    'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format(
                        short_sha, epoch, last_loss, current_lr)))
            # Bare weights of the wrapped network, without the DataParallel
            # 'module.' prefix.
            filename = os.path.join(
                checkpoint_path,
                'ckpt_{}_ep{:05d}_loss{:.04f}_lr{}model.pth.tar'.format(
                    short_sha, epoch, last_loss, current_lr))
            torch.save(model.module.state_dict(), filename)
Ejemplo n.º 15
0
    utils.set_seed(seed)

    ckptfilename = getCkptName(args.ckptname)
    ckptpath = os.path.join(ckptroot, modeltype, ckptfilename)
    print("CKPT PATH: {}".format(ckptpath))

    # Initialized reader, model and optimizer
    trainer = Trainer()
    print("Done modelinit")
    print("Mode : {}".format(mode))
    # print(trainer.tr_reader.ans2idx)

    if mode == 'train':
        (bestmodel, bestval, beststeps, steps) = trainer.train()
        print("Saving model: {}".format(ckptpath))
        utils.save_checkpoint(m=bestmodel,
                              o=trainer.optimizer,
                              steps=steps,
                              beststeps=beststeps,
                              path=ckptpath)
        pp.pprint(args)

    elif mode == 'val':
        utils.load_checkpoint(ckptpath, trainer.model, trainer.optimizer)
        trainer.validation()
        # (vt, vc, va) = trainer.validation_performance()
        # print("Total: {}. Validation Acc: {}".format(vt, va))

    sys.exit()
Ejemplo n.º 16
0
        loss.backward()
        optimaizer.step()

        print(
            'epoch: [{}/{}], total step:[{}/{}] , batchstep [{}/{}], lr: {},'
            'total_loss: {:.4f}, objness1: {:.4f}, class_loss: {:.4f}'.format(
                epoch, num_epochs, current_train_step, total_train_step, i + 1,
                total_step, learning_rate, loss.item(), obj_coord1_loss,
                obj_size1_loss, obj_class_loss))

    if (epoch % 2 == 0):
        '''
        torch.save({'test': epoch}, 'cc.zip')
        print("Saved...")

        '''
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': "YOLOv1",
                'state_dict': model.state_dict(),
            },
            False,
            filename=os.path.join(
                check_point_path, 'ep{:05d}_loss{:.04f}_lr{}.pth.tar'.format(
                    epoch,
                    loss.item(),
                    learning_rate,
                )))
        print("The check point is saved")
Ejemplo n.º 17
0
def _build_seg_model(args, seg_classes, ue_flag):
    """Instantiate the segmentation network selected by ``args.model``.

    Sets ``args.classes = seg_classes`` before constructing the network.
    The architecture modules are imported lazily so that only the selected
    one has to be importable.

    Args:
        args: Parsed options; ``model`` is read, ``classes`` is written, and
            ``trainable_fusion`` / ``weights`` are printed for the ESPDNet
            variants.
        seg_classes: Number of segmentation classes.
        ue_flag: Forwarded unchanged as the second positional argument of
            ``espdnetue_seg2`` (only used for ``'espdnetue'``).

    Returns:
        The freshly constructed model.

    Raises:
        SystemExit: With code -1 if ``args.model`` is not supported.
    """
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        return espnetv2_seg(args)

    if args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        return espdnet_seg(args)

    if args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        return espdnetue_seg2(args, ue_flag, fix_pyr_plane_proj=True)

    print_error_message('Arch: {} not yet supported'.format(args.model))
    # Same effect as the interactive-only ``exit(-1)`` helper, but does not
    # depend on the ``site`` module having been loaded.
    raise SystemExit(-1)


def _build_optimizer(args, model):
    """Create an SGD optimizer with two parameter groups.

    Group 0 holds ``model.get_basenet_params()`` at ``args.lr``; group 1
    holds ``model.get_segment_params()`` at ``args.lr * args.lr_mult``.
    Must be called on the bare model (before any DataParallel wrapping).
    """
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    return optim.SGD(train_params,
                     lr=args.lr * args.lr_mult,
                     momentum=args.momentum,
                     weight_decay=args.weight_decay)


def _build_criteria(args, seg_classes, class_wts, device):
    """Build the segmentation loss and the optional NID loss.

    Returns:
        ``(criterion, nid_loss)`` where ``nid_loss`` is ``None`` unless
        ``args.use_nid`` is truthy.
    """
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None
    return criterion, nid_loss


def _move_to_gpus(args, model, criterion, nid_loss, num_gpus):
    """Wrap model and losses for GPU execution when GPUs are present.

    With no GPU everything is returned untouched. With one GPU the model is
    wrapped in ``torch.nn.parallel.DataParallel``; with several, the
    project's ``DataParallelModel`` / ``DataParallelCriteria`` wrappers are
    used. cuDNN ``benchmark`` and ``deterministic`` are both switched on
    whenever cuDNN is available.

    Returns:
        ``(model, criterion, nid_loss)`` after wrapping / moving.
    """
    if num_gpus < 1:
        return model, criterion, nid_loss

    if num_gpus == 1:
        # for a single GPU, we do not need DataParallel wrapper for Criteria.
        # So, falling back to its internal wrapper
        from torch.nn.parallel import DataParallel
        model = DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if args.use_nid:
            nid_loss.cuda()
    else:
        from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
        model = DataParallelModel(model)
        model = model.cuda()
        criterion = DataParallelCriteria(criterion)
        criterion = criterion.cuda()
        if args.use_nid:
            nid_loss = DataParallelCriteria(nid_loss)
            nid_loss = nid_loss.cuda()

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        cudnn.deterministic = True

    return model, criterion, nid_loss


def _make_loaders(args, train_dataset, val_dataset):
    """Return ``(train_loader, val_loader)`` for the given datasets.

    Training uses ``args.batch_size`` with shuffling; validation uses a
    fixed batch size of 20 without shuffling.
    """
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)
    return train_loader, val_loader


def _load_overlapping_weights(model, model_dict):
    """Copy every shape-compatible tensor of ``model_dict`` into ``model``.

    A ``'module.'`` prefix in the source keys is removed before matching.
    Source entries that are absent from ``model`` or whose size differs are
    left out, and their keys are printed for inspection.
    """
    new_model_dict = model.state_dict()
    overlap_dict = {}
    no_overlap_dict = {}
    for name, tensor in model_dict.items():
        key = name.replace('module.', '')
        if key in new_model_dict \
                and new_model_dict[key].size() == tensor.size():
            overlap_dict[key] = tensor
        else:
            no_overlap_dict[key] = tensor
    print(no_overlap_dict.keys())

    new_model_dict.update(overlap_dict)
    model.load_state_dict(new_model_dict)


def _run_training_stage(args, model, optimizer, criterion, nid_loss, loaders,
                        lr_scheduler, seg_classes, color_encoding, writer,
                        device, tag, extra_info_ckpt, num_params, flops):
    """Train and validate for ``args.epochs`` epochs, logging under ``tag``.

    Each epoch: set both optimizer group learning rates from the scheduler,
    run one training and one validation pass, log a sample visualization for
    the first training and validation batch, save a checkpoint, and write
    scalars to ``writer``.

    Args:
        loaders: ``(train_loader, val_loader)`` tuple.
        tag: TensorBoard tag prefix, e.g. ``'Segmentation'``.
        extra_info_ckpt: Suffix string passed on to ``save_checkpoint``.
        num_params, flops: Network statistics; ``ceil`` of each is used as
            the step value of the ``Complexity`` scalars.
    """
    train_loader, val_loader = loaders

    # Use different training functions for espdnetue
    if args.model == 'espdnetue':
        from utilities.train_eval_seg import train_seg_ue as train
        from utilities.train_eval_seg import val_seg_ue as val
    else:
        from utilities.train_eval_seg import train_seg as train
        from utilities.train_eval_seg import val_seg as val

    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # Built-in next() instead of the Python 2 style ``.next()`` method,
        # which iterator objects are not required to provide.
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(
            model,
            images=batch_train[0].to(device=device),
            labels=batch_train[1].to(device=device),
            class_encoding=color_encoding,
            writer=writer,
            epoch=epoch,
            data='{}/train'.format(tag),
            device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='{}/val'.format(tag),
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # On 'cuda' the model is always wrapped (see _move_to_gpus), so the
        # bare network's weights live under ``.module``.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('{}/LR/base'.format(tag), round(lr_base, 6), epoch)
        writer.add_scalar('{}/LR/seg'.format(tag), round(lr_seg, 6), epoch)
        writer.add_scalar('{}/Loss/train'.format(tag), train_loss, epoch)
        writer.add_scalar('{}/Loss/val'.format(tag), val_loss, epoch)
        writer.add_scalar('{}/mIOU/train'.format(tag), miou_train, epoch)
        writer.add_scalar('{}/mIOU/val'.format(tag), miou_val, epoch)
        # The complexity figures are used as the x-axis (step) so that the
        # best mIoU is plotted against FLOPs / parameter count.
        writer.add_scalar('{}/Complexity/Flops'.format(tag), best_miou,
                          math.ceil(flops))
        writer.add_scalar('{}/Complexity/Params'.format(tag), best_miou,
                          math.ceil(num_params))


def main(args):
    """Run the two-stage segmentation training pipeline.

    Stage 1 (skipped when ``args.finetune`` is truthy) trains on the dataset
    returned by ``import_dataset(label_conversion=False)`` and keeps a copy
    of the resulting weights. Stage 2 always runs: it trains on
    ``import_dataset(label_conversion=True)`` with ``args.ignore_idx = 4``,
    the learning rate divided by 100 and ``args.epochs`` forced to 50; when
    stage 1 ran, its shape-compatible weights initialise the stage-2 model.

    Side effects on ``args``: ``classes``, ``ignore_idx``, ``lr`` and
    ``epochs`` are overwritten, and ``use_depth`` is set to ``False`` when
    stage 1 runs. Checkpoints and TensorBoard logs go to ``args.savedir``.

    Args:
        args: Parsed command-line options; ``crop_size`` must be a tuple.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Training the model with 13 classes of CamVid dataset
    # TODO: This process should be done only if specified
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        model = _build_seg_model(args, seg_classes, False)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        optimizer = _build_optimizer(args, model)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        # Graph export is best-effort; only ordinary errors are swallowed so
        # that Ctrl-C / SystemExit still propagate.
        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        except Exception:
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        criterion, nid_loss = _build_criteria(args, seg_classes, class_wts,
                                              device)
        model, criterion, nid_loss = _move_to_gpus(args, model, criterion,
                                                   nid_loss, num_gpus)

        # Get data loaders for training and validation data
        loaders = _make_loaders(args, train_dataset, val_dataset)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)

        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

        # Main training loop of 13 classes
        _run_training_stage(args, model, optimizer, criterion, nid_loss,
                            loaders, lr_scheduler, seg_classes, color_encoding,
                            writer, device, 'Segmentation', extra_info_ckpt,
                            num_params, flops)

        # Save the pretrained weights
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Finetuning with 4 classes
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    model = _build_seg_model(args, seg_classes, args.finetune)

    # Initialise from the stage-1 weights when that stage was run above.
    if not args.finetune:
        _load_overlapping_weights(model, model_dict)

    # Sanity forward pass on a dummy input, printing the first output's size.
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())

    print(seg_classes)
    print(class_wts.size())

    criterion, nid_loss = _build_criteria(args, seg_classes, class_wts, device)

    # Set learning rates
    args.lr /= 100
    optimizer = _build_optimizer(args, model)

    model, criterion, nid_loss = _move_to_gpus(args, model, criterion,
                                               nid_loss, num_gpus)

    # Get data loaders for training and validation data
    loaders = _make_loaders(args, train_dataset, val_dataset)

    # Get a learning rate scheduler
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it.
    # Unlike stage 1, this is called on the (possibly wrapped) model.
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])

    # Main training loop of the fine-tuning stage
    _run_training_stage(args, model, optimizer, criterion, nid_loss, loaders,
                        lr_scheduler, seg_classes, color_encoding, writer,
                        device, 'SegmentationConv', extra_info_ckpt,
                        num_params, flops)

    writer.close()