Example #1
def main(args):
    # create model
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        model = net.CNNModel(args)
    elif args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        model = net.EESPNet(args)
    elif args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        model = net.CNNModel(args)
    else:
        raise NotImplementedError('Model {} not yet implemented'.format(args.model))

    num_params = model_parameters(model)
    flops = compute_flops(model)
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    if not args.weights:
        print_info_message(
            'Grabbing location of the ImageNet weights from the weight dictionary'
        )
        from model.weight_locations.classification import model_weight_map

        weight_file_key = '{}_{}'.format(args.model, args.s)
        assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
        args.weights = model_weight_map[weight_file_key]

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    weight_dict = torch.load(args.weights, map_location=torch.device(device))
    model.load_state_dict(weight_dict)

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Data loading code
    val_loader = loader(args)
    validate(val_loader, model, criteria=None, device=device)
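
A minimal sketch of driving this entry point, assuming args is an argparse-style namespace. Only the fields the snippet reads directly are shown and the values are illustrative; the model constructor and loader(args) will expect additional fields defined by the project's real argument parser:

from argparse import Namespace

args = Namespace(
    model='espnetv2',  # selects model.classification.espnetv2.EESPNet
    s=2.0,             # network scale factor used to build the weight-map key
    weights='',        # empty -> location is looked up in model_weight_map
)
main(args)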
Example #2
            print_error_message(
                'Select image size from 512x256, 1024x512, 2048x1024')
        print_log_message('Using scale = ({}, {})'.format(
            args.scale[0], args.scale[1]))
    elif args.dataset == 'greenhouse':
        args.scale = (0.5, 2.0)
    else:
        print_error_message('{} dataset not yet supported'.format(
            args.dataset))

    if not args.finetune:
        from model.weight_locations.classification import model_weight_map

        if args.model == 'espdnet':
            weight_file_key = '{}_{}'.format('espnetv2', args.s)
            assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
            args.weights = model_weight_map[weight_file_key]

            args.depth_weights = './results_segmentation/model_espnetv2_greenhouse/s_2.0_sch_hybrid_loss_ce_res_480_sc_0.5_2.0_autoenc/20200323-073331/espnetv2_2.0_480_checkpoint.pth.tar'
        else:
            weight_file_key = '{}_{}'.format(args.model, args.s)
            assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
            args.weights = model_weight_map[weight_file_key]
    else:
        args.weights = ''
        assert os.path.isfile(args.finetune), \
            '{} weight file does not exist'.format(args.finetune)

    assert len(args.crop_size) == 2, 'crop-size argument must contain 2 values'
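
The pretrain branch above resolves weights through a dictionary keyed on model name and scale. A minimal sketch of that lookup pattern with a stand-in dictionary (the real one is model.weight_locations.classification.model_weight_map; the path below is only a placeholder):

# Stand-in for model.weight_locations.classification.model_weight_map
model_weight_map = {
    'espnetv2_2.0': 'path/to/espnetv2_s_2.0_imagenet.pth',  # placeholder location
}

model_name, s = 'espnetv2', 2.0
weight_file_key = '{}_{}'.format(model_name, s)  # -> 'espnetv2_2.0'
assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
weights = model_weight_map[weight_file_key]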
Example #3
def main(args):
    # create model
    if args.model == 'dicenet':
        from model.classification import dicenet as net
        model = net.CNNModel(args)
    elif args.model == 'espnetv2':
        from model.classification import espnetv2 as net
        model = net.EESPNet(args)
    elif args.model == 'shufflenetv2':
        from model.classification import shufflenetv2 as net
        model = net.CNNModel(args)
    else:
        raise NotImplementedError('Model {} not yet implemented'.format(args.model))

    num_params = model_parameters(model)
    flops = compute_flops(model)
    print_info_message('FLOPs: {:.2f} million'.format(flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    if not args.weights:
        print_info_message(
            'Grabbing location of the ImageNet weights from the weight dictionary'
        )
        from model.weight_locations.classification import model_weight_map

        weight_file_key = '{}_{}'.format(args.model, args.s)
        assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
        args.weights = model_weight_map[weight_file_key]

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    weight_dict = torch.load(args.weights, map_location=torch.device(device))
    model.load_state_dict(weight_dict)

    if num_gpus >= 1:
        model = torch.nn.DataParallel(model)
        model = model.cuda()
        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Data loading code
    def load_npy(npy_path):
        # The file may hold either a pickled dict (wrapped by NumPy in a
        # 0-d object array, hence the .item() call) or a plain ndarray.
        try:
            npdata = np.load(npy_path, allow_pickle=True).item()
        except ValueError:
            npdata = np.load(npy_path)
        return npdata

    def loadData(data_path):
        # Split the stored dict into input signals and ground-truth labels.
        npy_data = load_npy(data_path)
        signals = npy_data['signals']
        gts = npy_data['gts']
        return signals, gts

    ht_img_width, ht_img_height = args.inpSize, args.inpSize
    ht_batch_size = 5
    signal_length = args.channels
    signals_val, gts_val = loadData('./data_train/fps7_sample10_2D_val_96.npy')
    from data_loader.classification.heart import HeartDataGenerator
    heart_val_data = HeartDataGenerator(signals_val, gts_val, ht_batch_size)
    val_loader = torch.utils.data.DataLoader(heart_val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)
    validate(val_loader, model, criteria=None, device=device)
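
For reference, loadData above assumes the .npy file stores a pickled dict with 'signals' and 'gts' entries. A sketch of writing such a file with dummy data (the array shapes are illustrative only, not the project's real format):

import numpy as np

dummy = {
    'signals': np.zeros((10, 96, 96), dtype=np.float32),  # illustrative sample shape
    'gts': np.zeros((10,), dtype=np.int64),               # one label per sample
}
# np.save pickles the dict, so load_npy recovers it via np.load(...).item()
np.save('./data_train/fps7_sample10_2D_val_96.npy', dummy)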
Example #4
def main(args):

    dataset_cityscapes = CityscapesSegmentation(
        root='./vision_datasets/cityscapes',
        train=False,
        coarse=False,
        scale=1.0,
        size=(224, 224))
    dataset_camvid = CamVidSegmentation(root='./vision_datasets/camvid',
                                        list_name='train_camvid.txt',
                                        train=True,
                                        scale=1.0,
                                        size=(224, 224))
    dataset_greenhouse = GreenhouseRGBDSegmentation(
        root='./vision_datasets/greenhouse',
        list_name='train_greenhouse_no_gamma.lst',
        train=True,
        use_depth=False,
        scale=1.0,
        size=(224, 224))
    dataset_greenhouse2 = GreenhouseRGBDSegmentation(
        root='./vision_datasets/greenhouse',
        list_name='val_greenhouse2.lst',
        train=True,
        use_depth=False,
        scale=1.0,
        size=(224, 224))
    #    dataset_ishihara = IshiharaRGBDSegmentation(
    #        root='./vision_datasets/ishihara_rgbd', list_name='ishihara_rgbd_val.txt', train=False,
    #        scale=1.0, size=(224,224), use_depth=False)
    dataset_gta5 = GTA5(root='./vision_datasets/gta5',
                        list_name='val_small.lst',
                        scale=1.0,
                        size=(224, 224))
    dataset_forest = FreiburgForestDataset(train=True,
                                           size=(224, 224),
                                           scale=1.0,
                                           normalize=True)

    model = net.EESPNet(args)

    if not args.weights:
        print_info_message(
            'Grabbing location of the ImageNet weights from the weight dictionary'
        )
        from model.weight_locations.classification import model_weight_map

        weight_file_key = '{}_{}'.format(args.model, args.s)
        assert weight_file_key in model_weight_map, '{} does not exist'.format(weight_file_key)
        args.weights = model_weight_map[weight_file_key]

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus >= 1 else 'cpu'
    weight_dict = torch.load(args.weights, map_location=torch.device(device))
    model.load_state_dict(weight_dict)

    model = model.cuda()
    #    if num_gpus >= 1:
    #        model = torch.nn.DataParallel(model)
    #        model = model.cuda()
    #        if torch.backends.cudnn.is_available():
    #            import torch.backends.cudnn as cudnn
    #            cudnn.benchmark = True
    #            cudnn.deterministic = True

    features_gta5, labels_gta5 = get_features(model, dataset_gta5, 0)
    features_cityscapes, labels_cityscapes = get_features(
        model, dataset_cityscapes, 1)
    features_camvid, labels_camvid = get_features(model, dataset_camvid, 2)
    features_greenhouse, labels_greenhouse = get_features(
        model, dataset_greenhouse, 3)
    features_greenhouse2, labels_greenhouse2 = get_features(
        model, dataset_greenhouse2, 4)
    features_forest, labels_forest = get_features(model, dataset_forest, 5)

    #    features_ishihara, labels_ishihara = get_features(model, dataset_ishihara, 0)

    #    features = np.concatenate([features_cityscapes, features_camvid, features_greenhouse])
    features = np.concatenate([
        features_cityscapes, features_greenhouse2, features_camvid,
        features_greenhouse, features_gta5, features_forest
    ])
    #    features = np.concatenate([features_camvid, features_greenhouse])
    #    target = np.concatenate([labels_cityscapes, labels_camvid, labels_greenhouse])
    target = np.concatenate([
        labels_cityscapes, labels_greenhouse2, labels_camvid,
        labels_greenhouse, labels_gta5, labels_forest
    ])
    #    target = np.concatenate([labels_camvid, labels_greenhouse])

    from sklearn.manifold import TSNE
    tsne = TSNE(n_components=2, random_state=0, perplexity=20, n_iter=1000)
    feature_embedded = tsne.fit_transform(features)

    plt.scatter(feature_embedded[:, 0],
                feature_embedded[:, 1],
                c=target,
                cmap='jet')
    plt.colorbar()
    plt.savefig('figure.png')
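
Here get_features is assumed to return an (N, D) feature matrix and an (N,) array filled with the integer id passed as its third argument; the t-SNE step itself can be exercised stand-alone with dummy data of that shape:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
features = rng.normal(size=(120, 256))  # stand-in for the concatenated features
target = np.repeat(np.arange(6), 20)    # stand-in dataset ids 0..5

tsne = TSNE(n_components=2, random_state=0, perplexity=20)
feature_embedded = tsne.fit_transform(features)  # (120, 2) embedding

plt.scatter(feature_embedded[:, 0], feature_embedded[:, 1], c=target, cmap='jet')
plt.colorbar()
plt.savefig('tsne_dummy.png')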