Exemplo n.º 1
0
def modelDeploy(args, model, optimizer, scheduler, logger):
    """Prepare a model for training and optionally restore training state.

    Wraps the model in DataParallel when GPUs are requested, enables cuDNN
    autotuning, and handles ``--resume`` (full training state) and
    ``--finetune`` (weights only) checkpoints.

    Args:
        args: namespace with ``num_gpus``, ``resume``, ``finetune`` and
            ``freeze_bn`` attributes.
        model: the network to deploy.
        optimizer: optimizer whose state is restored on resume.
        scheduler: LR scheduler, fast-forwarded to the resumed epoch.
        logger: logger used for progress / error messages.

    Returns:
        (model, trainData): the (possibly DataParallel-wrapped) model and the
        training bookkeeping dict {'epoch', 'loss', 'miou', 'val', 'bestMiou'}.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    if args.num_gpus >= 1:
        from torch.nn.parallel import DataParallel
        model = DataParallel(model)
        model = model.cuda()

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        # NOTE(review): benchmark=True and deterministic=True normally pull
        # in opposite directions; kept exactly as originally configured.
        cudnn.benchmark = True
        cudnn.deterministic = True

    # Fresh bookkeeping; replaced below when resuming from a checkpoint.
    trainData = {'epoch': 0, 'loss': [], 'miou': [], 'val': [], 'bestMiou': 0}

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))

            # model & optimizer state
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            # Fast-forward the scheduler to the epoch training stopped at.
            trainData = checkpoint['trainData']
            for _ in range(trainData['epoch']):
                scheduler.step()

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, trainData['epoch']))

        else:
            logger.error("=> no checkpoint found at '{}'".format(args.resume))
            # Raise a real exception instead of `assert False`: asserts are
            # stripped when Python runs with -O, silently continuing here.
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))

    if args.finetune:
        if os.path.isfile(args.finetune):
            logger.info("=> finetuning checkpoint '{}'".format(args.finetune))
            state_all = torch.load(args.finetune, map_location='cpu')['model']
            # NOTE(review): the original comment said "only use backbone
            # parameters" but every key was copied; behavior is kept as-is,
            # with strict=False tolerating any key mismatch.
            state_clip = dict(state_all)
            model.load_state_dict(state_clip, strict=False)
        else:
            logger.warning("finetune is not a file.")

    if args.freeze_bn:
        logger.warning('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Freeze both the running statistics (eval mode) and the
                # affine parameters.
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    return model, trainData
Exemplo n.º 2
0
def load_reid_model(ckpt='/home/honglongcai/Github/PretrainedModel/model_410.pt'):
    """Load the pretrained ReID model onto the GPU in eval mode.

    Args:
        ckpt: path to the checkpoint file. Defaults to the previously
            hard-coded location so existing callers are unaffected.

    Returns:
        The DataParallel-wrapped model, on CUDA, in eval mode.
    """
    model = DataParallel(Model())
    # NOTE(review): map_location='cuda' requires a CUDA-enabled build; the
    # checkpoint keys must match the DataParallel 'module.' prefix — confirm
    # how the checkpoint was saved.
    model.load_state_dict(torch.load(ckpt, map_location='cuda'))
    logger.info('Load ReID model from {}'.format(ckpt))

    model = model.cuda()
    model.eval()
    return model
Exemplo n.º 3
0
def create_model(model_name, num_classes):
    """Build a named backbone wrapped in DataParallel.

    Args:
        model_name: one of 'resnet34', 'resnet50', 'C1'.
        num_classes: number of output classes passed to the constructor.

    Returns:
        (model, params): the (optionally CUDA) DataParallel model and a dict
        of its named parameters.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    create_model_fn = {'resnet34': resnet34, 'resnet50': resnet50, 'C1': C1}
    # Raise instead of `assert`: asserts are stripped when running under -O.
    if model_name not in create_model_fn:
        raise ValueError("must be one of {}".format(
            list(create_model_fn.keys())))
    logging.debug('\tCreating model {}'.format(model_name))
    model = DataParallel(create_model_fn[model_name](num_classes=num_classes))
    if CONFIG['general'].use_gpu:
        model = model.cuda()
    return model, dict(model.named_parameters())
Exemplo n.º 4
0
class TestProcess:
    """Run ET-Net inference over a test dataset, stitching per-patch
    predictions back into full-size images and saving them to disk.

    All configuration (GPU use, weights path, dataset name, crop/stride
    sizes, output folder) comes from the global ``ARGS`` dict.
    """

    def __init__(self):
        self.net = ET_Net()

        # Move to GPU and wrap for (multi-)GPU inference when requested.
        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        # NOTE(review): a state_dict saved from a DataParallel model carries
        # 'module.'-prefixed keys, so save/load GPU settings must match —
        # confirm against the training script.
        self.net.load_state_dict(torch.load(ARGS['weight']))

        self.test_dataset = get_dataset(dataset_name=ARGS['dataset'], part='test')

    def predict(self):
        """Predict a segmentation map for every test image and save each
        result as an image under ARGS['prediction_save_folder']."""

        start = time.time()
        self.net.eval()
        test_dataloader = DataLoader(self.test_dataset, batch_size=1) # only support batch size = 1
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        for items in test_dataloader:
            images, mask, filename = items['image'], items['mask'], items['filename']
            images = images.float()
            mask = mask.long()
            print('image shape:', images.size())

            # Split the (possibly large) image into overlapping crops so the
            # network always sees a fixed input size.
            image_patches, big_h, big_w = get_test_patches(images, ARGS['crop_size'], ARGS['stride_size'])
            test_patch_dataloader = DataLoader(image_patches, batch_size=ARGS['batch_size'], shuffle=False, drop_last=False)
            test_results = []
            print('Number of batches for testing:', len(test_patch_dataloader))

            for patches in test_patch_dataloader:

                if ARGS['gpu']:
                    patches = patches.cuda()

                # The network returns (edge output, segmentation output);
                # only the segmentation output is used here.
                with torch.no_grad():
                    result_patches_edge, result_patches = self.net(patches)

                test_results.append(result_patches.cpu())

            test_results = torch.cat(test_results, dim=0)
            # Merge the overlapping patch predictions back to full resolution.
            test_results = recompone_overlap(test_results, ARGS['crop_size'], ARGS['stride_size'], big_h, big_w)
            # Keep channel 1 (presumably the foreground probability — TODO
            # confirm), crop any padding, and zero out pixels outside the mask.
            test_results = test_results[:, 1, :images.size(2), :images.size(3)] * mask
            test_results = Image.fromarray(test_results[0].numpy())
            test_results.save(os.path.join(ARGS['prediction_save_folder'], filename[0]))
            print(f'Finish prediction for {filename[0]}')

        finish = time.time()

        print('Predicting time consumed: {:.2f}s'.format(finish - start))
Exemplo n.º 5
0
def main(args):
    """Train a segmentation network (ESPNetv2 / DiCENet) on Pascal VOC or
    Cityscapes.

    Builds the datasets, model, optimizer, criterion and LR scheduler from
    ``args``, optionally loads finetune/resume checkpoints, then runs the
    epoch loop: train, validate, checkpoint the best-mIoU weights and log
    scalars to TensorBoard.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    # ---------------------------------------------------------------- datasets
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Precomputed per-class weights to counter class imbalance; the last
        # class (index 19, presumably the ignore label — TODO confirm) gets
        # weight 0 so it does not contribute to the loss.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------- model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Weights-only initialization from a previous run (no optimizer state).
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Freeze running statistics (eval mode) and affine params.
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Base network and segmentation head get different learning rates.
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    # Graph tracing may fail for unsupported ops; that is non-fatal.
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:  # narrowed from a bare `except:` that also swallowed
        # KeyboardInterrupt / SystemExit
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # Full training-state restore (epoch, best mIoU, model, optimizer).
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ---------------------------------------------------------- LR scheduler
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # Persist the full run configuration alongside the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # ------------------------------------------------------------ train loop
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step look transposed here (best_miou is logged
        # as the scalar value at step=flops/params); kept as-is to preserve
        # existing TensorBoard histories — confirm intent.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
Exemplo n.º 6
0
def main(args):
    """Train the multi-task ESPDNet (segmentation + classification) on the
    greenhouse RGB-D dataset.

    Builds the datasets, model, optimizers, criteria and LR scheduler from
    ``args``, optionally loads finetune/resume checkpoints, then runs the
    epoch loop with TensorBoard logging and in-training visualization.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ---------------------------------------------------------------- dataset
    if args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegCls, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseRGBDSegCls(
            root=args.data_path,
            list_name='train_greenhouse_mult.txt',
            train=True,
            size=crop_size,
            scale=args.scale,
            use_depth=args.use_depth)
        val_dataset = GreenhouseRGBDSegCls(root=args.data_path,
                                           list_name='val_greenhouse_mult.txt',
                                           train=False,
                                           size=crop_size,
                                           scale=args.scale,
                                           use_depth=args.use_depth)
        # Precomputed class weights; only the first 4 entries are used —
        # presumably excluding the background class (TODO confirm).
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
        color_encoding = OrderedDict([('end_of_plant', (0, 255, 0)),
                                      ('other_part_of_plant', (0, 255, 255)),
                                      ('artificial_objects', (255, 0, 0)),
                                      ('ground', (255, 255, 0)),
                                      ('background', (0, 0, 0))])
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------- model
    if args.model == 'espdnet':
        from model.segmentation.espdnet_mult import espdnet_mult
        args.classes = seg_classes
        args.cls_classes = 5
        model = espdnet_mult(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    # Weights-only initialization from a previous run (no optimizer state).
    if args.finetune:
        if os.path.isfile(args.finetune):
            print_info_message('Loading weights for finetuning from {}'.format(
                args.finetune))
            weight_dict = torch.load(args.finetune,
                                     map_location=torch.device(device='cpu'))
            model.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for finetuning. Please check.')

    if args.freeze_bn:
        print_info_message('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Freeze running statistics (eval mode) and affine params.
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    # Separate LR groups: base net, segmentation head and (optionally) the
    # depth encoder.
    if args.use_depth:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }, {
            'params': model.get_depth_encoder_params(),
            'lr': args.lr
        }]
    else:
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    # Graph tracing may fail for unsupported ops; that is non-fatal.
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:  # narrowed from a bare `except:` that also swallowed
        # KeyboardInterrupt / SystemExit
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    # Full training-state restore (epoch, best mIoU, model, optimizer).
    start_epoch = 0
    best_miou = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            print_info_message("=> loading checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            start_epoch = checkpoint['epoch']
            best_miou = checkpoint['best_miou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_info_message("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_warning_message("=> no checkpoint found at '{}'".format(
                args.resume))

    print('device : ' + device)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Class weights for the auxiliary classification head.
    cls_class_weight = calc_cls_class_weight(train_loader, 5)
    print(cls_class_weight)

    criterion_seg = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))

    criterion_cls = nn.CrossEntropyLoss(
        weight=torch.from_numpy(cls_class_weight).float().to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion_seg = criterion_seg.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion_seg = DataParallelCriteria(criterion_seg)
            criterion_seg = criterion_seg.cuda()

        criterion_cls = criterion_cls.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # ---------------------------------------------------------- LR scheduler
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # Persist the full run configuration alongside the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    # ------------------------------------------------------------ train loop
    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg
        if args.use_depth:
            optimizer.param_groups[2]['lr'] = lr_base

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))
        miou_train, train_loss, train_seg_loss, train_cls_loss = train(
            model,
            train_loader,
            optimizer,
            criterion_seg,
            seg_classes,
            epoch,
            criterion_cls,
            device=device,
            use_depth=args.use_depth)
        miou_val, val_loss, val_seg_loss, val_cls_loss = val(
            model,
            val_loader,
            criterion_seg,
            criterion_cls,
            seg_classes,
            device=device,
            use_depth=args.use_depth)

        # Visualize predictions on one validation batch.
        # Fixed: `iter(val_loader).next()` is Python 2 only; Python 3 removed
        # the .next() method in favor of the next() builtin.
        batch = next(iter(val_loader))
        if args.use_depth:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        depths=batch[2].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)
        else:
            in_training_visualization_2(model,
                                        images=batch[0].to(device=device),
                                        labels=batch[1].to(device=device),
                                        class_encoding=color_encoding,
                                        writer=writer,
                                        epoch=epoch,
                                        data='Segmentation',
                                        device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/train', train_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/train', train_cls_loss, epoch)
        writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
        writer.add_scalar('Segmentation/SegLoss/val', val_seg_loss, epoch)
        writer.add_scalar('Segmentation/ClsLoss/val', val_cls_loss, epoch)
        writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
        writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
        # NOTE(review): value/step look transposed here (best_miou is logged
        # as the scalar value at step=flops/params); kept as-is to preserve
        # existing TensorBoard histories — confirm intent.
        writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
Exemplo n.º 7
0
else:
    primary_dataset = get_model('test')
    primary_data_loader = DataLoader(primary_dataset, shuffle=False, batch_size=1)

# Book-keeping

types = ', '.join([data_strs[i] for i in choice])
prep = lambda x: 'Number of %s in MRI study %s: %d\n'%('patients' if x.flag_3d else 'slices', types, len(x))
log_str = add_to_log(prep(primary_dataset))
log_str = add_to_log(prep(val_dataset))
# Define 3D UNet and train, val, test scripts

net = UNet3D()
net.train()

net = DataParallel(net.cuda())
bce_criterion = BCELoss()

def get_optimizer(st, lr, momentum=0.9):
    """Build an optimizer over the module-level ``net``'s parameters.

    Args:
        st: Optimizer selector; either ``'sgd'`` or ``'adam'``.
        lr: Learning rate.
        momentum: Momentum for SGD (ignored by Adam).

    Returns:
        A configured ``SGD`` or ``Adam`` optimizer.

    Raises:
        ValueError: If ``st`` is not a supported optimizer name.
    """
    if st == 'sgd':
        return SGD(net.parameters(), lr=lr, momentum=momentum)
    elif st == 'adam':
        return Adam(net.parameters(), lr=lr)
    # Previously an unknown name fell through and returned None, which only
    # failed later with a confusing AttributeError; fail fast instead.
    raise ValueError("Unsupported optimizer type: {!r}".format(st))

optimizer = get_optimizer(args.optimizer, args.lr)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2)

def dice_loss(y, pred):
    smooth = 1.

    yflat = y.view(-1)
Exemplo n.º 8
0
class CalculateMetricProcess:
    """Runs patch-based inference with a trained ET_Net and reports metrics.

    Configuration (weights path, dataset name, crop/stride sizes, GPU flag)
    is read from the global ``ARGS`` dictionary.
    """

    def __init__(self):
        self.net = ET_Net()

        if (ARGS['gpu']):
            # Move to GPU and wrap in DataParallel before loading weights so
            # the state_dict keys (with 'module.' prefix) line up.
            self.net = DataParallel(module=self.net.cuda())

        # NOTE(review): loading happens after the optional DataParallel wrap,
        # so the checkpoint is presumably saved from a wrapped model when
        # ARGS['gpu'] is set — confirm against the training script.
        self.net.load_state_dict(torch.load(ARGS['weight']))

        # Dedicated split used purely for metric computation.
        self.metric_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                          part='metric')

    def predict(self):
        """Run inference image-by-image and print evaluation metrics.

        For each image: split into overlapping crops, run the network on
        batches of crops, stitch the crop predictions back together, then
        accumulate masked predictions/labels and hand them to
        ``calc_metrics``.
        """
        start = time.time()
        self.net.eval()
        metric_dataloader = DataLoader(
            self.metric_dataset, batch_size=1)  # only support batch size = 1
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        y_true = []
        y_pred = []
        for items in metric_dataloader:
            images, labels, mask = items['image'], items['label'], items[
                'mask']
            images = images.float()
            print('image shape:', images.size())

            # Crop the image into overlapping patches of ARGS['crop_size']
            # with stride ARGS['stride_size']; big_h/big_w is the padded
            # size needed to reassemble them.
            image_patches, big_h, big_w = get_test_patches(
                images, ARGS['crop_size'], ARGS['stride_size'])
            test_patch_dataloader = DataLoader(image_patches,
                                               batch_size=ARGS['batch_size'],
                                               shuffle=False,
                                               drop_last=False)
            test_results = []
            print('Number of batches for testing:', len(test_patch_dataloader))

            for patches in test_patch_dataloader:

                if ARGS['gpu']:
                    patches = patches.cuda()

                # The network returns (edge map, segmentation map); only the
                # segmentation output is used for metrics here.
                with torch.no_grad():
                    result_patches_edge, result_patches = self.net(patches)

                test_results.append(result_patches.cpu())

            test_results = torch.cat(test_results, dim=0)
            # merge
            test_results = recompone_overlap(test_results, ARGS['crop_size'],
                                             ARGS['stride_size'], big_h, big_w)
            # Keep index 1 along the class dimension (presumably the
            # foreground channel — confirm) and crop padding back to the
            # original height/width.
            test_results = test_results[:, 1, :images.size(2), :images.size(3)]
            # Evaluate only inside the region-of-interest mask.
            y_pred.append(test_results[mask == 1].reshape(-1))
            y_true.append(labels[mask == 1].reshape(-1))

        y_pred = torch.cat(y_pred).numpy()
        y_true = torch.cat(y_true).numpy()
        calc_metrics(y_pred, y_true)
        finish = time.time()

        print('Calculating metric time consumed: {:.2f}s'.format(finish -
                                                                 start))
Exemplo n.º 9
0
def main(args):
    """Train an ESPNetv2-based depth-to-RGB autoencoder.

    Builds the train/val datasets selected by ``args.dataset``, the
    autoencoder model, an SGD optimizer and the LR scheduler selected by
    ``args.scheduler``, then trains for ``args.epochs`` epochs.  Every 5th
    epoch the model is validated, reconstruction images are written to
    TensorBoard and the checkpoint with the lowest validation loss so far
    is flagged as best.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # ------------------------------------------------------------------
    # Dataset selection
    # ------------------------------------------------------------------
    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Pre-computed per-class weights for Cityscapes; the last index
        # (void/ignore) gets zero weight.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0

    elif args.dataset == 'greenhouse':
        print(args.use_depth)
        from data_loader.segmentation.greenhouse import GreenhouseRGBDSegmentation, GreenhouseDepth, GREENHOUSE_CLASS_LIST
        train_dataset = GreenhouseDepth(root=args.data_path,
                                        list_name='train_depth_ae.txt',
                                        train=True,
                                        size=crop_size,
                                        scale=args.scale,
                                        use_filter=True)
        val_dataset = GreenhouseRGBDSegmentation(root=args.data_path,
                                                 list_name='val_depth_ae.txt',
                                                 train=False,
                                                 size=crop_size,
                                                 scale=args.scale,
                                                 use_depth=True)
        class_weights = np.load('class_weights.npy')[:4]
        print(class_weights)
        class_wts = torch.from_numpy(class_weights).float().to(device)

        seg_classes = len(GREENHOUSE_CLASS_LIST)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model, optimizer and complexity report
    # ------------------------------------------------------------------
    from model.autoencoder.depth_autoencoder import espnetv2_autoenc
    args.classes = 3  # 3-channel RGB reconstruction target
    model = espnetv2_autoenc(args)

    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr * args.lr_mult
    }]

    optimizer = optim.SGD(train_params,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 1, crop_size[0], crop_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[0], crop_size[1], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')
    try:
        writer.add_graph(model,
                         input_to_model=torch.Tensor(1, 3, crop_size[0],
                                                     crop_size[1]))
    except Exception:
        # FIX: was a bare 'except:', which would also swallow
        # KeyboardInterrupt/SystemExit.
        print_log_message(
            "Not able to generate the graph. Likely because your model is not supported by ONNX"
        )

    start_epoch = 0

    print('device : ' + device)

    # Plain per-pixel reconstruction loss.
    criterion = nn.MSELoss()

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # ------------------------------------------------------------------
    # Learning-rate scheduler selection
    # ------------------------------------------------------------------
    if args.scheduler == 'fixed':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import FixedMultiStepLR
        lr_scheduler = FixedMultiStepLR(base_lr=args.lr,
                                        steps=step_sizes,
                                        gamma=args.lr_decay)
    elif args.scheduler == 'clr':
        step_size = args.step_size
        step_sizes = [
            step_size * i
            for i in range(1, int(math.ceil(args.epochs / step_size)))
        ]
        from utilities.lr_scheduler import CyclicLR
        lr_scheduler = CyclicLR(min_lr=args.lr,
                                cycle_len=5,
                                steps=step_sizes,
                                gamma=args.lr_decay)
    elif args.scheduler == 'poly':
        from utilities.lr_scheduler import PolyLR
        lr_scheduler = PolyLR(base_lr=args.lr,
                              max_epochs=args.epochs,
                              power=args.power)
    elif args.scheduler == 'hybrid':
        from utilities.lr_scheduler import HybirdLR
        lr_scheduler = HybirdLR(base_lr=args.lr,
                                max_epochs=args.epochs,
                                clr_max=args.clr_max,
                                cycle_len=args.cycle_len)
    elif args.scheduler == 'linear':
        from utilities.lr_scheduler import LinearLR
        lr_scheduler = LinearLR(base_lr=args.lr, max_epochs=args.epochs)
    else:
        print_error_message('{} scheduler Not supported'.format(
            args.scheduler))
        exit()

    print_info_message(lr_scheduler)

    # Record the full run configuration next to the logs.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
    # BUG FIX: best_loss was initialised to 0.0; since the loss is
    # non-negative, 'val_loss < best_loss' could never be true and no
    # checkpoint was ever flagged as best.
    best_loss = float('inf')
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # Apply the scheduled learning rate to the single parameter group.
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_seg

        # --------------------------- train ----------------------------
        model.train()
        losses = AverageMeter()
        for i, batch in enumerate(train_loader):
            inputs = batch[1].to(device=device)  # Depth
            target = batch[0].to(device=device)  # RGB

            outputs = model(inputs)

            if device == 'cuda':
                loss = criterion(outputs, target).mean()
                if isinstance(outputs, (list, tuple)):
                    # Gather per-GPU outputs from DataParallelModel.
                    target_dev = outputs[0].device
                    outputs = gather(outputs, target_device=target_dev)
            else:
                loss = criterion(outputs, target)

            losses.update(loss.item(), inputs.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            writer.add_scalar('Autoencoder/Loss/train', loss.item(),
                              len(train_loader) * epoch + i)

            print_info_message('Running batch {}/{} of epoch {}'.format(
                i + 1, len(train_loader), epoch + 1))

        train_loss = losses.avg

        writer.add_scalar('Autoencoder/LR/seg', round(lr_seg, 6), epoch)

        # --------------------- validate every 5th epoch ----------------
        if epoch % 5 == 0:
            # FIX: switch to eval mode for validation (dropout off, running
            # BN stats); model.train() at the top of the loop restores
            # training mode for the next epoch.
            model.eval()
            losses = AverageMeter()
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    # NOTE(review): depth is batch[2] here but batch[1] in
                    # training — presumably the val dataset yields an extra
                    # label tensor; confirm against the data loaders.
                    inputs = batch[2].to(device=device)  # Depth
                    target = batch[0].to(device=device)  # RGB

                    outputs = model(inputs)

                    if device == 'cuda':
                        loss = criterion(outputs, target)  # .mean()
                        if isinstance(outputs, (list, tuple)):
                            target_dev = outputs[0].device
                            outputs = gather(outputs, target_device=target_dev)
                    else:
                        loss = criterion(outputs, target)

                    losses.update(loss.item(), inputs.size(0))

                    image_grid = torchvision.utils.make_grid(
                        outputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/results/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        inputs.data.cpu()).numpy()
                    writer.add_image('Autoencoder/inputs/val', image_grid,
                                     epoch)
                    image_grid = torchvision.utils.make_grid(
                        target.data.cpu()).numpy()
                    writer.add_image('Autoencoder/target/val', image_grid,
                                     epoch)

            val_loss = losses.avg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Remember the best (lowest) validation loss and checkpoint.
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)

            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Autoencoder/Loss/val', val_loss, epoch)

    writer.close()
Exemplo n.º 10
0
def main(args):
    """Run EDVR inference over a low-resolution eval set and save outputs.

    Builds ``EvalDataset`` from ``args.test_lr``, constructs the EDVR model
    (optionally resuming weights from ``args.resume``), runs it batch by
    batch and writes each super-resolved frame as an image under
    ``args.outputs_dir``, mirroring the source directory layout.
    """
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    print("===> Loading datasets")
    data_set = EvalDataset(
        args.test_lr,
        n_frames=args.n_frames,
        interval_list=args.interval_list,
    )
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True

    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    print("===> Setting GPU")
    # args.gpus == 0 means "use every visible GPU".
    gups = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gups))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()

    # print(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # Pick the latest (lexicographically last) .pth in the directory.
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']

            # The checkpoint was saved from an unwrapped model, so prepend
            # 'module.' to each key to match the DataParallel wrapper.
            # (The original comment said "remove `module.`", which is the
            # opposite of what this loop does.)
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

    #### inference
    print("===> Eval")
    model.eval()
    with tqdm(total=(len(data_set) - len(data_set) % args.batch_size)) as t:
        for data in eval_loader:
            data_x = data['LRs'].cuda()
            names = data['files']

            with torch.no_grad():
                outputs = model(data_x).data.float().cpu()
            # Scale float output into the 8-bit pixel range.
            outputs = outputs * 255.
            outputs = outputs.clamp_(0, 255).numpy()
            for img, file in zip(outputs, names):
                # CHW -> HWC with channel order reversed for cv2.imwrite
                # (presumably RGB -> BGR — confirm the model's output order).
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round()

                # Recreate the source's parent-directory name under outputs_dir.
                arr = file.split('/')
                dst_dir = os.path.join(args.outputs_dir, arr[-2])
                if not os.path.exists(dst_dir):
                    os.makedirs(dst_dir)
                dst_name = os.path.join(dst_dir, arr[-1])

                cv2.imwrite(dst_name, img)
            t.update(len(names))
Exemplo n.º 11
0
def main(args):
    """Evaluate a segmentation model on the validation split of a dataset.

    Builds the validation dataset named by ``args.dataset``, constructs the
    network named by ``args.model``, loads weights from ``args.weights_test``
    and reports the validation mIOU.
    """
    # ------------------------------------------------------------------
    # Dataset selection
    # ------------------------------------------------------------------
    if args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=(256, 256),
                                             scale=args.s,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Pre-computed per-class weights for Cityscapes; the last index
        # (void/ignore) gets zero weight.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    elif args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=(256, 256),
                                      scale=args.s)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey':
        # FIX: the unused training dataset is no longer constructed here;
        # this function only evaluates on the validation split.
        from data_loader.segmentation.hockey import HockeySegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        val_dataset = HockeySegmentationDataset(root=args.data_path,
                                                train=False,
                                                crop_size=(256, 256),
                                                scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'hockey_rink_seg':
        from data_loader.segmentation.hockey_rink_seg import HockeyRinkSegmentationDataset, HOCKEY_DATASET_CLASS_LIST
        val_dataset = HockeyRinkSegmentationDataset(root=args.data_path,
                                                    train=False,
                                                    crop_size=(256, 256),
                                                    scale=args.s)
        seg_classes = len(HOCKEY_DATASET_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))
        # FIX: without exiting, val_dataset would be undefined below.
        exit(-1)

    if len(val_dataset) == 0:
        # BUG FIX: the original referenced an undefined 'image_path' here,
        # raising NameError instead of the intended message.
        print_error_message('No files found for dataset at: {}'.format(
            args.data_path))

    print_info_message('# of images for testing: {}'.format(len(val_dataset)))

    # ------------------------------------------------------------------
    # Model construction and weight loading
    # ------------------------------------------------------------------
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'dicenet':
        from model.segmentation.dicenet import dicenet_seg
        model = dicenet_seg(args, classes=seg_classes)
    else:
        print_error_message('{} network not yet supported'.format(args.model))
        exit(-1)

    # model information
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, args.im_size[0],
                                             args.im_size[1]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            args.im_size[0], args.im_size[1], flops))
    print_info_message('# of parameters: {}'.format(num_params))

    if args.weights_test:
        print_info_message('Loading model weights')
        weight_dict = torch.load(args.weights_test,
                                 map_location=torch.device('cpu'))

        # Accept both a full checkpoint dict and a bare state_dict.
        if isinstance(weight_dict, dict) and 'state_dict' in weight_dict:
            model.load_state_dict(weight_dict['state_dict'])
        else:
            model.load_state_dict(weight_dict)

        print_info_message('Weight loaded successfully')
    else:
        # BUG FIX: the original passed a dangling ', format(...)' second
        # argument instead of calling str.format, and then fell through to
        # evaluate an untrained model; exit instead.
        print_error_message(
            'weight file does not exist or not specified. Please check: {}'.format(
                args.weights_test))
        exit(-1)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    model = model.to(device=device)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=40,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=4)

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type='ce',
                                 device=device,
                                 ignore_idx=255,
                                 class_wts=class_wts.to(device))
    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Run validation and report the mean intersection-over-union.
    miou_val, val_loss = val(model,
                             val_loader,
                             criterion,
                             seg_classes,
                             device=device)
    print_info_message('mIOU: {}'.format(miou_val))
Exemplo n.º 12
0
class VisualizeProcess:
    """Visualizes ET_Net predictions on the validation split with matplotlib.

    Configuration (weights path, dataset name, batch size, GPU flag) is read
    from the global ``ARGS`` dictionary.
    """

    def __init__(self):
        self.net = ET_Net()

        if (ARGS['gpu']):
            # Move to GPU and wrap in DataParallel before loading weights so
            # the state_dict keys (with 'module.' prefix) line up.
            self.net = DataParallel(module=self.net.cuda())

        self.net.load_state_dict(torch.load(ARGS['weight']))

        # Both splits are created, but only the val split is visualized below.
        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

    def visualize(self):
        """Plot image / label / prediction / edge map for each val batch.

        Also prints per-sample value ranges and a per-image foreground IoU
        for the first sample in each batch.
        """
        start = time.time()
        self.net.eval()
        # Don't request a batch larger than the dataset itself.
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            print('image shape:', images.size())

            # The network returns (edge map, segmentation map).
            with torch.no_grad():
                outputs_edge, outputs = self.net(images)

            # Hard class prediction per pixel (argmax over the class dim).
            pred = torch.max(outputs, dim=1)[1]
            # Bitwise &/| as intersection/union — assumes binary 0/1 labels
            # and predictions (TODO confirm); epsilon guards empty unions.
            iou = torch.sum(pred[0] & labels[0]) / (
                torch.sum(pred[0] | labels[0]) + 1e-6)

            # Undo mean-subtraction normalization for display (presumably
            # the ImageNet RGB mean scaled to [0, 1] — confirm against the
            # dataset transform).
            mean = torch.FloatTensor([123.68, 116.779, 103.939]).reshape(
                (3, 1, 1)) / 255.
            images = images + mean.cuda()

            # images *= 255.
            print('pred min: ', pred[0].min(), ' max: ', pred[0].max())
            print('label min:', labels[0].min(), ' max: ', labels[0].max())
            print('edge min:', edges[0].min(), ' max: ', edges[0].max())
            print('output edge min:', outputs_edge[0].min(), ' max: ',
                  outputs_edge[0].max())
            print('IoU:', iou)
            print('Intersect num:', torch.sum(pred[0] & labels[0]))
            print('Union num:', torch.sum(pred[0] | labels[0]))

            # 2x2 figure: input image, ground-truth label, class-1 score
            # map, class-1 edge score map (for the first sample only).
            plt.subplot(221)
            plt.imshow(images[0].cpu().numpy().transpose(
                (1, 2, 0))), plt.axis('off')
            plt.subplot(222)
            plt.imshow(labels[0].cpu().numpy(), cmap='gray'), plt.axis('off')
            plt.subplot(223)
            # plt.imshow(pred[0].cpu().numpy(), cmap='gray'), plt.axis('off')
            plt.imshow(outputs[0, 1].cpu().numpy(),
                       cmap='gray'), plt.axis('off')
            plt.subplot(224)
            plt.imshow(outputs_edge[0, 1].cpu().numpy(),
                       cmap='gray'), plt.axis('off')
            plt.show()

            # update training loss for each iteration
            # self.writer.add_scalar('Train/loss', loss.item(), n_iter)

        finish = time.time()

        print('validating time consumed: {:.2f}s'.format(finish - start))
Exemplo n.º 13
0
class BaseEngine(object):
    """Skeleton train/eval engine.

    Wires dataset, model, optimizer, loss and metric together via the
    abstract ``_make_*`` hooks and provides checkpoint dump/load keyed by a
    run tag. Subclasses implement the hooks plus ``eval``/``test``/``train``.
    """

    def __init__(self, args):
        # Subclass hooks; _make_model must set self.model before the CUDA
        # handling below.
        self._make_dataset(args)
        self._make_model(args)
        tc.manual_seed(args.seed)
        if args.cuda and tc.cuda.is_available():
            tc.cuda.manual_seed_all(args.seed)
            if tc.cuda.device_count() > 1:
                # Scale the effective batch size with the number of GPUs.
                self.batch_size = args.batch_size * tc.cuda.device_count()
                # NOTE(review): no explicit .cuda() in this branch —
                # presumably _make_model already placed the model on the
                # GPU; confirm.
                self.model = DataParallel(self.model)
            else:
                self.batch_size = args.batch_size
                self.model = self.model.cuda()
        else:
            self.batch_size = args.batch_size
        self._make_optimizer(args)
        self._make_loss(args)
        self._make_metric(args)
        self.num_training_samples = args.num_training_samples
        # Tag distinguishes checkpoints/loggers of parallel runs.
        self.tag = args.tag or 'default'
        self.dump_dir = get_dir(args.dump_dir)
        self.train_logger = get_logger('train.{}.{}'.format(
            self.__class__.__name__, self.tag))

    def _make_dataset(self, args):
        """Create the dataset(s); must be provided by the subclass."""
        raise NotImplementedError

    def _make_model(self, args):
        """Create ``self.model``; must be provided by the subclass."""
        raise NotImplementedError

    def _make_optimizer(self, args):
        """Create ``self.optimizer``; must be provided by the subclass."""
        raise NotImplementedError

    def _make_loss(self, args):
        """Create the loss function; must be provided by the subclass."""
        raise NotImplementedError

    def _make_metric(self, args):
        """Create the evaluation metric; must be provided by the subclass."""
        raise NotImplementedError

    def dump(self, epoch, model=True, optimizer=True, decayer=True):
        """Save a checkpoint containing the epoch and selected components.

        The decayer (LR scheduler) state is saved only if the attribute
        exists and is not None.
        """
        state = {'epoch': epoch}
        if model:
            state['model'] = self.model.state_dict()
        if optimizer:
            state['optimizer'] = self.optimizer.state_dict()
        if decayer and (getattr(self, 'decayer', None) is not None):
            state['decayer'] = self.decayer.state_dict()
        tc.save(state,
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        self.train_logger.info('Checkpoint {} dumped'.format(self.tag))

    def load(self, model=True, optimizer=True, decayer=True):
        """Restore state from the tagged checkpoint, if it exists.

        Returns:
            The stored epoch number, or 0 when no checkpoint file is found
            (i.e. training should start from scratch).
        """
        try:
            state = tc.load(
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        except FileNotFoundError:
            return 0
        if model and (state.get('model') is not None):
            self.model.load_state_dict(state['model'])
        if optimizer and (state.get('optimizer') is not None):
            self.optimizer.load_state_dict(state['optimizer'])
        if decayer and (state.get('decayer') is not None) and (getattr(
                self, 'decayer', None) is not None):
            self.decayer.load_state_dict(state['decayer'])
        return state['epoch']

    def eval(self):
        """Evaluate the model; must be provided by the subclass."""
        raise NotImplementedError

    def test(self):
        """Run inference on the test set; must be provided by the subclass."""
        raise NotImplementedError

    def train(self, num_epochs, resume=False):
        """Run the training loop; must be provided by the subclass."""
        raise NotImplementedError
Exemplo n.º 14
0
def main(args):
    """Two-stage CamVid segmentation training.

    Stage 1 (skipped when ``args.finetune`` is set): train the selected
    architecture on the full 13-class label set. Stage 2: rebuild the model
    for the label-converted 5-class set, transfer the overlapping weights,
    and finetune with a 100x smaller learning rate for 50 epochs.

    Fixes over the previous revision:
    - restored the ``print`` dropped around the "Segmentation classes" message
      (the bare format-string expression was evaluated and discarded);
    - replaced Python-2 ``iterator.next()`` with ``next(iter(...))`` — the
      old form raises AttributeError on Python 3;
    - narrowed the bare ``except:`` around ``writer.add_graph`` so that
      KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Stage 1: training with the 13 classes of the CamVid dataset
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        # Import model
        if args.model == 'espnetv2':
            from model.segmentation.espnetv2 import espnetv2_seg
            args.classes = seg_classes
            model = espnetv2_seg(args)
        elif args.model == 'espdnet':
            from model.segmentation.espdnet import espdnet_seg
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            model = espdnet_seg(args)
        elif args.model == 'espdnetue':
            from model.segmentation.espdnet_ue import espdnetue_seg2
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            print(args.weights)
            model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
        else:
            print_error_message('Arch: {} not yet supported'.format(
                args.model))
            exit(-1)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        # Set learning rates: segmentation head gets lr * lr_mult,
        # the base network gets the plain lr.
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

        # Define an optimizer
        optimizer = optim.SGD(train_params,
                              lr=args.lr * args.lr_mult,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        except Exception:
            # Graph export goes through ONNX tracing and can fail for
            # unsupported ops; training should continue regardless.
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
        criterion = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
        nid_loss = NIDLoss(image_bin=32,
                           label_bin=seg_classes) if args.use_nid else None

        if num_gpus >= 1:
            if num_gpus == 1:
                # for a single GPU, we do not need DataParallel wrapper for Criteria.
                # So, falling back to its internal wrapper
                from torch.nn.parallel import DataParallel
                model = DataParallel(model)
                model = model.cuda()
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss.cuda()
            else:
                from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
                model = DataParallelModel(model)
                model = model.cuda()
                criterion = DataParallelCriteria(criterion)
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss = DataParallelCriteria(nid_loss)
                    nid_loss = nid_loss.cuda()

            if torch.backends.cudnn.is_available():
                import torch.backends.cudnn as cudnn
                cudnn.benchmark = True
                cudnn.deterministic = True

        # Get data loaders for training and validation data
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=20,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)

        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
        #
        # Main training loop of 13 classes
        #
        start_epoch = 0
        best_miou = 0.0
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Use different training functions for espdnetue
            if args.model == 'espdnetue':
                from utilities.train_eval_seg import train_seg_ue as train
                from utilities.train_eval_seg import val_seg_ue as val
            else:
                from utilities.train_eval_seg import train_seg as train
                from utilities.train_eval_seg import val_seg as val

            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device,
                                           use_depth=args.use_depth,
                                           add_criterion=nid_loss)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device,
                                     use_depth=args.use_depth,
                                     add_criterion=nid_loss)

            # Grab one batch from each loader for tensorboard visualization.
            batch_train = next(iter(train_loader))
            batch = next(iter(val_loader))
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            # On GPU the model is wrapped (DataParallel), so unwrap for saving.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # NOTE(review): value/step arguments look swapped here (best_miou
            # is logged as the value, flops/params as the step) — kept as-is
            # to match the rest of the codebase; confirm intent.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))

        # Save the pretrained weights before tearing the model down.
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Stage 2: finetuning with the label-converted (5-class) dataset
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    #set_parameters_for_finetuning()

    # Import model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Transfer every stage-1 weight whose name (minus the DataParallel
        # 'module.' prefix) and shape match the freshly built 5-class model.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity-check the rebuilt model with a dummy forward pass (on CPU,
    # before any DataParallel wrapping).
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())

    print(seg_classes)
    print(class_wts.size())
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates: finetuning uses a 100x smaller base lr.
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler; finetuning always runs 50 epochs.
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])
    #
    # Main finetuning loop (5 classes)
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # Grab one batch from each loader for tensorboard visualization.
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
Exemplo n.º 15
0
def main(args):
    """Train a segmentation model on PASCAL VOC or Cityscapes, with an
    optional floating-point warmup phase followed by quantization-aware
    training (unless ``args.fp_train`` is set).

    Fix over the previous revision: an unsupported ``args.dataset`` now
    fails fast with a clear message — previously ``crop_size`` stayed
    unbound and the code crashed with a NameError before reaching the
    dataset-selection error branch further down.
    """
    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    my_logger = Logger(60066, logdir)

    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)
    else:
        # Fail fast: crop_size would otherwise be referenced unbound below.
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    elif args.dataset == 'city':
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Per-class weights to counter Cityscapes' class imbalance; the
        # last (ignore) class gets weight 0.
        class_wts = torch.ones(seg_classes)
        class_wts[0] = 2.8149201869965
        class_wts[1] = 6.9850029945374
        class_wts[2] = 3.7890393733978
        class_wts[3] = 9.9428062438965
        class_wts[4] = 9.7702074050903
        class_wts[5] = 9.5110931396484
        class_wts[6] = 10.311357498169
        class_wts[7] = 10.026463508606
        class_wts[8] = 4.6323022842407
        class_wts[9] = 9.5608062744141
        class_wts[10] = 7.8698215484619
        class_wts[11] = 9.5168733596802
        class_wts[12] = 10.373730659485
        class_wts[13] = 6.6616044044495
        class_wts[14] = 10.260489463806
        class_wts[15] = 10.287888526917
        class_wts[16] = 10.289801597595
        class_wts[17] = 10.405355453491
        class_wts[18] = 10.138095855713
        class_wts[19] = 0.0
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Build per-parameter groups: depthwise convs (in_channels == 1) get no
    # weight decay, regular convs the full decay, everything else 1% of it.
    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': 0.0
                }]
            else:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]
        else:
            train_params += [{
                'params': [value],
                'lr': args.lr,
                'weight_decay': others
            }]

    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1], crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    start_epoch = 0
    epochs_len = args.epochs
    best_miou = 0.0

    #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    if args.dataset == 'city':
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)

    lr_scheduler = get_lr_scheduler(args)

    print_info_message(lr_scheduler)

    # Persist the run configuration (plus model stats) alongside checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")

        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            start_t = time.time()
            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device)
    if args.optimizer.startswith('Q'):
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')

    if not args.fp_train:
        # Switch the quantizable sub-network into QAT mode.
        model.module.quantized.fuse_model()
        model.module.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(model.module.quantized, inplace=True)

    if args.resume:
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            model.module.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')

    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            model_file_name = args.savedir + '/model_' + str(epoch +
                                                             1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))
        info = {
            'Segmentation/LR': round(lr, 6),
            'Segmentation/Loss/train': train_loss,
            'Segmentation/Loss/val': val_loss,
            'Segmentation/mIOU/train': miou_train,
            'Segmentation/mIOU/val': miou_val,
            'Segmentation/Complexity/Flops': best_miou,
            'Segmentation/Complexity/Params': best_miou,
        }

        for tag, value in info.items():
            if tag == 'Segmentation/Complexity/Flops':
                my_logger.scalar_summary(tag, value, math.ceil(flops))
            elif tag == 'Segmentation/Complexity/Params':
                my_logger.scalar_summary(tag, value, math.ceil(num_params))
            else:
                my_logger.scalar_summary(tag, value, epoch + 1)

    print_info_message("========== TRAINING FINISHED ===========")
Exemplo n.º 16
0
                          train_cfg=cfg.train_cfg,
                          test_cfg=cfg.test_cfg)
    logger.info('-' * 20 + 'finish build model' + '-' * 20)
    logger.info('Total Parameters: %d,   Trainable Parameters: %s',
                model.net_parameters['Total'],
                str(model.net_parameters['Trainable']))
    # build dataset
    datasets = build_dataset(cfg.data.train)
    logger.info('-' * 20 + 'finish build dataset' + '-' * 20)
    # put model on gpu
    if torch.cuda.is_available():
        if len(cfg.gpu_ids) == 1:
            model = model.cuda()
            logger.info('-' * 20 + 'model to one gpu' + '-' * 20)
        else:
            model = DataParallel(model.cuda(), device_ids=cfg.gpu_ids)
            logger.info('-' * 20 + 'model to multi gpus' + '-' * 20)
    # create data_loader
    data_loader = build_dataloader(datasets, cfg.data.samples_per_gpu,
                                   cfg.data.workers_per_gpu, len(cfg.gpu_ids))
    logger.info('-' * 20 + 'finish build dataloader' + '-' * 20)
    # create optimizer
    optimizer = build_optimizer(model, cfg.optimizer)
    Scheduler = build_scheduler(cfg.lr_config)
    logger.info('-' * 20 + 'finish build optimizer' + '-' * 20)

    visualizer = Visualizer()
    vis = visdom.Visdom()
    criterion_ssim_loss = build_loss(cfg.loss_ssim)
    criterion_l1_loss = build_loss(cfg.loss_l1)
    ite_num = 0
Exemplo n.º 17
0
def train(**args):
    """Train a lambda_resnet50 classifier on CIFAR-10.

    Expected keys in ``args``: 'gpu' (iterable of GPU ids), 'batch_size',
    'lr', 'weight_decay', 'num_epochs', 'log_interval'.  Trains with Adam +
    cosine warm restarts and reports validation loss/accuracy each epoch.
    """
    gpu = list(map(int, args['gpu']))
    print("* Using GPU - %s\n" % (str(gpu)))

    # Model
    model = lambda_resnet50(num_classes=10)

    # Set device & Data Parallelization
    torch.cuda.set_device(gpu[0])
    if len(gpu) > 1:  # Using multiple-GPUs
        model = DataParallel(model, device_ids=gpu)
    model.cuda()

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Dataset: normalize each channel to roughly [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset_train = datasets.CIFAR10(root='./data',
                                     train=True,
                                     download=True,
                                     transform=transform)
    dataset_valid = datasets.CIFAR10(root='./data',
                                     train=False,
                                     download=True,
                                     transform=transform)

    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args['batch_size'],
                                  num_workers=len(gpu) * 4)
    dataloader_valid = DataLoader(dataset_valid,
                                  shuffle=False,
                                  batch_size=args['batch_size'],
                                  num_workers=len(gpu) * 4)

    # Optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args['lr'],
                           weight_decay=args['weight_decay'])

    # Scheduler: cosine annealing with warm restarts, stepped once per epoch
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=10,
                                                               T_mult=2,
                                                               eta_min=0.0001)

    for epoch in range(args['num_epochs']):
        """ Training iteration """
        model.train()
        for batch_idx, (samples, labels) in enumerate(dataloader_train):
            optimizer.zero_grad()

            if gpu:
                samples = samples.cuda()
                labels = labels.cuda()

            logits = model(samples)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            if batch_idx % args['log_interval'] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(samples),
                    len(dataloader_train.dataset),
                    100. * batch_idx / len(dataloader_train), loss.item()))
        scheduler.step()
        """ Validation iteration """
        model.eval()
        with torch.no_grad():
            valid_loss = 0
            correct = 0
            for samples, labels in dataloader_valid:
                if gpu:
                    samples = samples.cuda()
                    labels = labels.cuda()

                logits = model(samples)
                # .item() accumulates a Python float instead of a 0-dim CUDA
                # tensor (the original kept tensors alive on the GPU).
                valid_loss += criterion(logits, labels).item()
                preds = logits.argmax(dim=1, keepdim=True)
                correct += preds.eq(labels.view_as(preds)).sum().item()
            valid_loss /= len(dataloader_valid)
            print(
                '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
                .format(valid_loss, correct, len(dataloader_valid.dataset),
                        100. * correct / len(dataloader_valid.dataset)))
Exemplo n.º 18
0
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """LeNet-style forward pass.

        Args:
            x: input batch; the conv/pool arithmetic implies a shape of
               (N, 1, 28, 28) so the flatten below yields 320 features
               -- TODO confirm with the caller.

        Returns:
            (N, 10) tensor of log-probabilities.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Pass dim explicitly: calling F.log_softmax without `dim` is
        # deprecated and only falls back to dim=1 with a warning.
        return F.log_softmax(x, dim=1)


# --- Minimal DataParallel training fragment (MNIST-style Net) ---
# NOTE(review): Net, train_loader, optim, nn, Variable and DataParallel are
# defined/imported elsewhere in this file; only this excerpt is visible here.
model = DataParallel(Net())
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss().cuda()  # pairs with a log_softmax model output

model.train()
for batch_idx, (data, target) in enumerate(train_loader):
    # Variable() is a no-op on PyTorch >= 0.4; kept for compatibility.
    input_var = Variable(data.cuda())
    target_var = Variable(target.cuda())

    print('Getting model output')
    output = model(input_var)
    print('Got model output')

    loss = criterion(output, target_var)
    optimizer.zero_grad()
    # NOTE(review): no loss.backward()/optimizer.step() is visible, so this
    # loop as shown never updates weights -- presumably the snippet is
    # truncated; confirm against the original file.
Exemplo n.º 19
0
def main(args):
    """Train the EDVR video super-resolution model.

    Builds the dataset/loader, constructs EDVR with a Charbonnier loss,
    optionally resumes from a checkpoint (file or directory of ``*.pth``),
    then runs the epoch loop and saves a checkpoint after every epoch.
    """
    print("===> Loading datasets")
    data_set = DatasetLoader(args.data_lr,
                             args.data_hr,
                             size_w=args.size_w,
                             size_h=args.size_h,
                             scale=args.scale,
                             n_frames=args.n_frames,
                             interval_list=args.interval_list,
                             border_mode=args.border_mode,
                             random_reverse=args.random_reverse)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=False,
                              drop_last=True)

    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True  # auto-tune conv kernels; trades determinism for speed
    #cudnn.deterministic = True

    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    criterion = CharbonnierLoss()
    print("===> Setting GPU")
    # args.gpus == 0 means "use every visible GPU"
    gups = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gups))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    # print(model)

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the lexicographically last *.pth in the directory
            # (assumes zero-padded epoch numbers in the filenames)
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']

            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                # Checkpoints store the unwrapped model (see torch.save below),
                # so prepend 'module.' to match the DataParallel wrapper.
                namekey = 'module.' + k
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # prefer the lr stored in the checkpoint over the CLI value
            args.lr = checkpoint.get('lr', args.lr)

        # an explicit --start_epoch overrides the epoch from the checkpoint
        start_epoch = args.start_epoch if args.start_epoch != 0 else start_epoch

    # a positive use_current_lr overrides any other lr choice
    args.lr = args.use_current_lr if args.use_current_lr > 0 else args.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 betas=(args.beta1, args.beta2),
                                 eps=1e-8)

    #### training
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        if args.use_tqdm == 1:
            losses, psnrs = one_epoch_train_tqdm(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        else:
            losses, psnrs = one_epoch_train_logger(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])

        # save model
        # if epoch %9 != 0:
        #     continue

        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edvr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        # Save the unwrapped module's weights (no 'module.' prefix).
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
def main():

    global args, best_prec1
    args = parser.parse_args()

    # Read list of training and validation data
    listfiles_train, labels_train = read_lists(TRAIN_OUT)
    listfiles_val, labels_val = read_lists(VAL_OUT)
    listfiles_test, labels_test = read_lists(TEST_OUT)
    dataset_train = Dataset(listfiles_train,
                            labels_train,
                            subtract_mean=False,
                            V=12)
    dataset_val = Dataset(listfiles_val, labels_val, subtract_mean=False, V=12)
    dataset_test = Dataset(listfiles_test,
                           labels_test,
                           subtract_mean=False,
                           V=12)

    # shuffle data
    dataset_train.shuffle()
    dataset_val.shuffle()
    dataset_test.shuffle()
    tra_data_size, val_data_size, test_data_size = dataset_train.size(
    ), dataset_val.size(), dataset_test.size()
    print 'training size:', tra_data_size
    print 'validation size:', val_data_size
    print 'testing size:', test_data_size

    batch_size = args.b
    print("batch_size is :" + str(batch_size))
    learning_rate = args.lr
    print("learning_rate is :" + str(learning_rate))
    num_cuda = cuda.device_count()
    print("number of GPUs have been detected:" + str(num_cuda))

    # creat model
    print("model building...")
    mvcnn = DataParallel(modelnet40_Alex(num_cuda, batch_size))
    #mvcnn = modelnet40(num_cuda, batch_size, multi_gpu = False)
    mvcnn.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint'{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            mvcnn.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    #print(mvcnn)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adadelta(mvcnn.parameters(), weight_decay=1e-4)
    # evaluate performance only
    if args.evaluate:
        print 'testing mode ------------------'
        validate(dataset_test, mvcnn, criterion, optimizer, batch_size)
        return

    print 'training mode ------------------'
    for epoch in xrange(args.start_epoch, args.epochs):
        print('epoch:', epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(dataset_train, mvcnn, criterion, optimizer, epoch, batch_size)

        # evaluate on validation set
        prec1 = validate(dataset_val, mvcnn, criterion, optimizer, batch_size)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
        elif epoch % 5 is 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
Exemplo n.º 21
0

# --- Fine-tuning setup: load pretrained weights, SGD + StepLR, DataParallel ---
model = load_network(model_structure)

#optimizer_ft = optim.SGD(model.parameters(), lr = 0.0, momentum=0.0, weight_decay=0)
optimizer_ft = optim.SGD(model.parameters(),
                         lr=0.0001,
                         momentum=0.9,
                         weight_decay=5e-4)

# Decay lr by 10x every args.lr_decay_epochs epochs.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft,
                                             step_size=args.lr_decay_epochs,
                                             gamma=0.1)

# Wrap AFTER creating the optimizer: optimizer_ft holds the same parameter
# objects, so this ordering still works.
model = DataParallel(model)
model = model.cuda()
def save_network(network, epoch_label):
    """Save the network's state_dict as ``net_<epoch_label>.pth``.

    The file is written under ``args.model_save_dir``, which is created on
    demand.
    """
    save_filename = 'net_%s.pth' % epoch_label
    save_path = os.path.join(args.model_save_dir, save_filename)
    # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair and also
    # creates intermediate directories.
    os.makedirs(args.model_save_dir, exist_ok=True)
    torch.save(network.state_dict(), save_path)


def train_model(model, optimizer, scheduler, num_epochs):
    """Epoch training loop.

    NOTE(review): only the first two statements are visible in this excerpt;
    the body continues past it.
    """
    # Stepping the scheduler before any optimizer.step() skips the initial lr
    # (pre-1.1.0 PyTorch call ordering) -- TODO confirm this is intended.
    scheduler.step()
    model.train()
Exemplo n.º 22
0
def load_network(network):
    """Load pretrained weights into ``network`` (non-strict).

    Reads ``pretrained_weight.pth`` from ``args.pretrained_path``, drops the
    final FC layer via ``remove_fc``, prefixes every key with ``'model.'`` to
    match this architecture's attribute layout, and returns the network.
    """
    weight_file = os.path.join(args.pretrained_path, 'pretrained_weight.pth')
    state = remove_fc(torch.load(weight_file))
    prefixed = {'model.' + key: value for key, value in state.items()}
    network.load_state_dict(prefixed, strict=False)
    return network

# --- Two-model setup: independent optimizers and schedulers per network ---
model = load_network(model_structure)
# NOTE(review): both calls pass the same model_structure object -- presumably
# the intent is two independently-trained copies; confirm load_network/
# model_structure semantics upstream.
model2 = load_network(model_structure)

optimizer_ft = optim.SGD(model.parameters(), lr = 0.0001, momentum=0.9, weight_decay=5e-4)
optimizer_ft2 = optim.SGD(model2.parameters(), lr = 0.0001, momentum=0.9, weight_decay=5e-4)

# Decay each lr by 10x every args.lr_decay_epochs epochs.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=args.lr_decay_epochs, gamma=0.1)
exp_lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer_ft2, step_size=args.lr_decay_epochs, gamma=0.1)

model = DataParallel(model)
model = model.cuda()
model2 = DataParallel(model2)
model2 = model2.cuda()

def save_network(network1, network2, epoch_label):
    """Save both networks' state_dicts under ``args.model_save_dir``.

    Files are named ``net1_<epoch_label>.pth`` and ``net2_<epoch_label>.pth``.
    The output directory is created on demand.
    """
    # makedirs(exist_ok=True) replaces the (duplicated, racy) exists()+mkdir
    # pair from the original and creates intermediate directories too.
    os.makedirs(args.model_save_dir, exist_ok=True)
    for tag, network in (('net1', network1), ('net2', network2)):
        save_path = os.path.join(args.model_save_dir,
                                 '%s_%s.pth' % (tag, epoch_label))
        torch.save(network.state_dict(), save_path)
Exemplo n.º 23
0
class TrainValProcess():
    """Training/validation driver for ET_Net edge + segmentation training.

    Loads ET_Net (from a weight file, or pretraining only the encoder),
    builds train/val datasets, and optimizes a weighted combination of
    Lovasz-softmax losses over the edge and segmentation heads with Adam
    under a polynomial LR decay.  Metrics go to a TensorBoard SummaryWriter.
    """
    def __init__(self):
        self.net = ET_Net()
        if (ARGS['weight']):
            self.net.load_state_dict(torch.load(ARGS['weight']))
        else:
            self.net.load_encoder_weight()
        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

        self.optimizer = Adam(self.net.parameters(), lr=ARGS['lr'])
        # Use / to get an approximate result, // to get an accurate result
        total_iters = len(
            self.train_dataset) // ARGS['batch_size'] * ARGS['num_epochs']
        # Polynomial decay of lr toward 0 over the whole run.
        self.lr_scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda iter:
            (1 - iter / total_iters)**ARGS['scheduler_power'])
        self.writer = SummaryWriter()

    def train(self, epoch):
        """Run one training epoch; log per-batch metrics and the epoch loss."""
        start = time.time()
        self.net.train()
        # NOTE(review): shuffle=False for a training loader is unusual --
        # confirm this is intentional.
        train_dataloader = DataLoader(self.train_dataset,
                                      batch_size=ARGS['batch_size'],
                                      shuffle=False)
        epoch_loss = 0.
        for batch_index, items in enumerate(train_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            self.optimizer.zero_grad()
            outputs_edge, outputs = self.net(images)
            # print('output edge min:', outputs_edge[0, 1].min(), ' max: ', outputs_edge[0, 1].max())
            # plt.imshow(outputs_edge[0, 1].detach().cpu().numpy() * 255, cmap='gray')
            # plt.show()
            loss_edge = lovasz_softmax(outputs_edge,
                                       edges)  # Lovasz-Softmax loss
            loss_seg = lovasz_softmax(outputs, labels)  #
            # Blend the two heads: combine_alpha weights segmentation vs edge.
            loss = ARGS['combine_alpha'] * loss_seg + (
                1 - ARGS['combine_alpha']) * loss_edge
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()  # per-iteration lr decay

            # NOTE(review): assumes 1-based epochs; train_val passes 0-based,
            # so the first epoch yields negative n_iter -- confirm upstream.
            n_iter = (epoch - 1) * len(train_dataloader) + batch_index + 1

            pred = torch.max(outputs, dim=1)[1]
            # Bitwise &/| treats labels as binary masks -- assumes labels are
            # {0, 1}; TODO confirm for multi-class datasets.
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            # print('edge min:', edges.min(), ' max: ', edges.max())
            # print('output edge min:', outputs_edge.min(), ' max: ', outputs_edge.max())

            print(
                'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tL_edge: {:0.4f}\tL_seg: {:0.4f}\tL_all: {:0.4f}\tIoU: {:0.4f}\tLR: {:0.4f}'
                .format(loss_edge.item(),
                        loss_seg.item(),
                        loss.item(),
                        iou.item(),
                        self.optimizer.param_groups[0]['lr'],
                        epoch=epoch,
                        trained_samples=batch_index * ARGS['batch_size'],
                        total_samples=len(train_dataloader.dataset)))

            epoch_loss += loss.item()

            # update training loss for each iteration
            # self.writer.add_scalar('Train/loss', loss.item(), n_iter)

        # Histogram every parameter, grouped as "layer/attribute".
        for name, param in self.net.named_parameters():
            layer, attr = os.path.splitext(name)
            attr = attr[1:]
            self.writer.add_histogram("{}/{}".format(layer, attr), param,
                                      epoch)

        epoch_loss /= len(train_dataloader)
        self.writer.add_scalar('Train/loss', epoch_loss, epoch)
        finish = time.time()

        print('epoch {} training time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def validate(self, epoch):
        """Run one validation pass; log per-batch metrics and the epoch loss."""
        start = time.time()
        self.net.eval()
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        epoch_loss = 0.
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            # print('label min:', labels[0].min(), ' max: ', labels[0].max())
            # print('edge min:', labels[0].min(), ' max: ', labels[0].max())

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            print('image shape:', images.size())

            with torch.no_grad():
                outputs_edge, outputs = self.net(images)
                loss_edge = lovasz_softmax(outputs_edge,
                                           edges)  # Lovasz-Softmax loss
                loss_seg = lovasz_softmax(outputs, labels)  #
                loss = ARGS['combine_alpha'] * loss_seg + (
                    1 - ARGS['combine_alpha']) * loss_edge

            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            print(
                'Validating Epoch: {epoch} [{val_samples}/{total_samples}]\tLoss: {:0.4f}\tIoU: {:0.4f}'
                .format(loss.item(),
                        iou.item(),
                        epoch=epoch,
                        val_samples=batch_index * val_batch_size,
                        total_samples=len(val_dataloader.dataset)))

            # .item() so we accumulate a Python float rather than a tensor
            # (the original kept device tensors alive across the epoch).
            epoch_loss += loss.item()

            # update training loss for each iteration
            # self.writer.add_scalar('Train/loss', loss.item(), n_iter)

        epoch_loss /= len(val_dataloader)
        self.writer.add_scalar('Val/loss', epoch_loss, epoch)

        finish = time.time()

        # Fixed copy-paste: this message previously said "training".
        print('epoch {} validation time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def train_val(self):
        """Alternate train/validate for num_epochs, checkpointing periodically."""
        print('Begin training and validating:')
        for epoch in range(ARGS['num_epochs']):
            self.train(epoch)
            self.validate(epoch)
            # (removed a bare self.net.state_dict() call here -- it built and
            # discarded the state dict, a no-op)
            print(f'Finish training and validating epoch #{epoch+1}')
            if (epoch + 1) % ARGS['epoch_save'] == 0:
                os.makedirs(ARGS['weight_save_folder'], exist_ok=True)
                torch.save(
                    self.net.state_dict(),
                    os.path.join(ARGS['weight_save_folder'],
                                 f'epoch_{epoch+1}.pth'))
                print(f'Model saved for epoch #{epoch+1}.')
        print('Finish training and validating.')