Example #1
def get(args):
    """ Entry point. Call this function to get all Charades dataloaders """
    normalize = arraytransforms.Normalize(mean=[0.502], std=[1.0])
    train_file = args.train_file
    val_file = args.val_file
    train_dataset = Charadesflow(args.data,
                                 'train',
                                 train_file,
                                 args.cache,
                                 transform=transforms.Compose([
                                     arraytransforms.RandomResizedCrop(224),
                                     arraytransforms.ToTensor(),
                                     normalize,
                                     transforms.Lambda(lambda x: torch.cat(x)),
                                 ]))
    val_transforms = transforms.Compose([
        arraytransforms.Resize(256),
        arraytransforms.CenterCrop(224),
        arraytransforms.ToTensor(),
        normalize,
        transforms.Lambda(lambda x: torch.cat(x)),
    ])
    val_dataset = Charadesflow(args.data,
                               'val',
                               val_file,
                               args.cache,
                               transform=val_transforms)
    valvideo_dataset = Charadesflow(args.data,
                                    'val_video',
                                    val_file,
                                    args.cache,
                                    transform=val_transforms)
    return train_dataset, val_dataset, valvideo_dataset
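Note on the trailing Lambda in Example #1: the array transforms run per optical-flow frame, so ToTensor hands the Lambda a list of 1-channel tensors, and torch.cat stacks them into a single multi-channel input. A minimal sketch of that last step (the frame count and size here are illustrative, not taken from the source):

import torch

# e.g. 10 flow frames, each contributing an x- and a y-component plane
flow_planes = [torch.randn(1, 224, 224) for _ in range(20)]
stacked = torch.cat(flow_planes)  # concatenate along the channel dim
print(stacked.shape)              # torch.Size([20, 224, 224])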
Example #2
def train_transform(rgb, depth):
    s = np.random.uniform(1.0, 1.5)  # random scaling
    # print("scale factor s={}".format(s))
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        transforms.Resize(
            250.0 / iheight
        ),  # this is for computational efficiency, since rotation is very slow
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    rgb_np = transform(rgb)

    # random color jittering
    rgb_np = color_jitter(rgb_np)

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    depth_np = transform(depth_np)

    return rgb_np, depth_np
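The detail worth copying from this train_transform is that s, angle, and do_flip are sampled once, outside the Compose, so the RGB image and the depth map go through identical geometry. A standalone sketch of the same idea (hypothetical paired_flip helper, NumPy only):

import numpy as np

def paired_flip(rgb, depth, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    # sample the random parameter once so both modalities agree
    if rng.uniform() < 0.5:
        rgb, depth = rgb[:, ::-1], depth[:, ::-1]
    return rgb, depth

rgb, depth = np.zeros((480, 640, 3)), np.zeros((480, 640))
rgb, depth = paired_flip(rgb, depth)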
Example #3
def train_transform(rgb, depth):
    s = np.random.uniform(1.0, 1.5)  # random scaling
    # print("scale factor s={}".format(s))
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        transforms.Crop(130, 10, 240, 1200),
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    rgb_np = transform(rgb)

    # random color jittering
    rgb_np = color_jitter(rgb_np)

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    # SciPy's affine_transform raised a RuntimeError when the depth map
    # was passed as a raw 'numpy.ndarray', so cast it to float32 first
    depth_np = np.asfarray(depth_np, dtype='float32')
    depth_np = transform(depth_np)

    return rgb_np, depth_np
Example #4
def train_transform(rgb, depth):
    s = np.random.uniform(1.0, 1.5)  # random scaling
    # print("scale factor s={}".format(s))
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # set zeros in depth as NaN
    depth_np[depth_np == 0] = np.nan

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        transforms.Resize(
            float(image_size) / iheight
        ),  # this is for computational efficiency, since rotation is very slow
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip),
    ])
    rgb_np = transform(rgb)

    # random color jittering
    rgb_np = color_jitter(rgb_np)

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    rgb_np = normalize(rgb_np)  # from [0,1] to [-1,1]

    depth_np = transform(depth_np)
    depth_np[np.isnan(depth_np)] = 0

    depth_np = depth_np / 10.0

    return rgb_np, depth_np
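The NaN round-trip is what sets this variant apart: zeros in the raw depth mean "no reading", and left as zeros they would be averaged into valid neighbours by any interpolating transform. Marking them as NaN makes the holes propagate instead of blend; a tiny NumPy illustration:

import numpy as np

row = np.array([0.0, 4.0])   # 0.0 is a missing reading, not a real depth
print(row.mean())            # 2.0 -- the hole silently corrupts the result
row[row == 0] = np.nan
print(row.mean())            # nan -- the hole stays visibly invalid
row[np.isnan(row)] = 0       # restore the sentinel afterwards, as above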
Example #5
def singleImageResult():
    cut_size = 90
    # Data
    transform_test = transforms.Compose([
        transforms.CenterCrop(cut_size),
        transforms.ToTensor(),
    ])
    raw_img = io.imread('images/2.jpg')
    gray = rgb2gray(raw_img)
    # skimage's resize returns floats in [0, 1]; rescale before the uint8 cast
    gray = (resize(gray, (96, 96), mode='symmetric') * 255).astype(np.uint8)
    img = gray[:, :, np.newaxis]
    img = np.concatenate((img, img, img), axis=2)
    img = Image.fromarray(img)
    inputs = transform_test(img)
    # inputs = inputs.cpu()
    inputs = inputs[np.newaxis, :, :, :]

    ncrops, c, h, w = np.shape(inputs)
    inputs = inputs.view(-1, c, h, w)
    inputs = inputs.cuda()
    inputs = Variable(inputs, volatile=True)  # legacy pre-0.4 PyTorch; use torch.no_grad() today

    outputs = net(inputs)

    score = F.softmax(outputs, dim=1)
    print(score)
Example #6
def loading_data():
    mean_std = cfg.DATA.MEAN_STD
    train_simul_transform = own_transforms.Compose([
        own_transforms.Scale(int(cfg.TRAIN.IMG_SIZE[0] / 0.875)),
        own_transforms.RandomCrop(cfg.TRAIN.IMG_SIZE),
        own_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = own_transforms.Compose([
        own_transforms.Scale(int(cfg.TRAIN.IMG_SIZE[0] / 0.875)),
        own_transforms.CenterCrop(cfg.TRAIN.IMG_SIZE)
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        own_transforms.MaskToTensor(),
        own_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        own_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train', simul_transform=train_simul_transform, transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=cfg.TRAIN.BATCH_SIZE, num_workers=16, shuffle=True)
    val_set = CityScapes('val', simul_transform=val_simul_transform, transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=cfg.VAL.BATCH_SIZE, num_workers=16, shuffle=False)

    return train_loader, val_loader, restore_transform
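restore_transform runs the normalization backwards so logged tensors can be viewed as images again. A sketch of what own_transforms.DeNormalize presumably looks like (the project's actual implementation may differ):

import torch

class DeNormalize(object):
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # undo (x - mean) / std channel by channel, in place
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor

x = DeNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch.zeros(3, 4, 4))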
Example #7
    def train_transform(self, rgb: np.ndarray, depth_raw: np.ndarray,
                        depth_fix: np.ndarray) -> TNpData:
        s = np.random.uniform(1.0, 1.5)  # random scaling
        depth_raw = depth_raw / s
        depth_fix = depth_fix / s
        angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
        do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip
        # perform 1st part of data augmentation
        transform = transforms.Compose([
            transforms.Resize(
                250.0 / self.iheight
            ),  # this is for computational efficiency, since rotation is very slow
            transforms.Rotate(angle),
            transforms.Resize(s),
            transforms.CenterCrop((self.oheight, self.owidth)),
            transforms.HorizontalFlip(do_flip)
        ])
        rgb = transform(rgb)

        # random color jittering
        rgb = color_jitter(rgb)

        rgb = np.asfarray(rgb, dtype='float') / 255
        depth_raw = transform(depth_raw)
        depth_fix = transform(depth_fix)

        return rgb, depth_raw, depth_fix
Example #8
def get(args):
    """ Entry point. Call this function to build the train/val/val_video Charades datasets """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_file = args.train_file
    val_file = args.val_file
    train_dataset = Charades(
        args.data,
        'train',
        train_file,
        args.cache,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(args.inputsize),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # missing PCA lighting jitter
            normalize,
        ]))
    val_dataset = Charades(args.data,
                           'val',
                           val_file,
                           args.cache,
                           transform=transforms.Compose([
                               transforms.Resize(
                                   int(256. / 224 * args.inputsize)),
                               transforms.CenterCrop(args.inputsize),
                               transforms.ToTensor(),
                               normalize,
                           ]))
    valvideo_dataset = Charades(args.data,
                                'val_video',
                                val_file,
                                args.cache,
                                transform=transforms.Compose([
                                    transforms.Resize(
                                        int(256. / 224 * args.inputsize)),
                                    transforms.CenterCrop(args.inputsize),
                                    transforms.ToTensor(),
                                    normalize,
                                ]))
    return train_dataset, val_dataset, valvideo_dataset
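The int(256. / 224 * args.inputsize) resize generalizes the classic ImageNet recipe (short side to 256, then a 224 centre crop) to any crop size, keeping the same resize-to-crop ratio. A quick check of what it yields:

for inputsize in (224, 288, 320):
    print(inputsize, int(256. / 224 * inputsize))  # 224->256, 288->329, 320->365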
Example #9
    def val_transform(self, rgb: np.ndarray, depth_raw: np.ndarray,
                      depth_fix: np.ndarray) -> TNpData:
        # perform 1st part of data augmentation
        transform = transforms.Compose([
            transforms.Resize(240.0 / self.iheight),
            transforms.CenterCrop((self.oheight, self.owidth)),
        ])
        rgb = transform(rgb)
        rgb = np.asfarray(rgb, dtype='float') / 255
        depth_raw = transform(depth_raw)
        depth_fix = transform(depth_fix)
        return rgb, depth_raw, depth_fix
Example #10
def imageNet_loader(train_size, valid_size, test_size, crop_size):
    # http://blog.outcome.io/pytorch-quick-start-classifying-an-image/
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        './data/kaggle/train',
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=train_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True)

    valid_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        './data/kaggle/valid',
        transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=valid_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=True)

    test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        './data/image_net/small_classes/',
        transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
                                              batch_size=test_size,
                                              shuffle=False)
    return train_loader, valid_loader, test_loader
Example #11
def val_transform(rgb, depth):
    depth_np = depth

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        transforms.Resize(240.0 / iheight),
        transforms.CenterCrop((oheight, owidth)),
    ])
    rgb_np = transform(rgb)
    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    depth_np = transform(depth_np)

    return rgb_np, depth_np
Example #12
def open_nii(volpaths, segpaths, ind, num, in_z, out_z, center_crop_sz,\
        series_names, seg_series_names, txforms=None,nrrd=True):
    vols = []
    segs = []
    if nrrd:
        series, seg_series = get_nii_nrrd(volpaths, segpaths)
    assert np.shape(series)[3] == np.shape(seg_series)[3]
    num_slices = np.arange(np.shape(series[0])[2])
    if in_z != 0:
        num_slices = num_slices[in_z:-in_z]
    sub_rand = np.random.choice(num_slices, size=num, replace=False)

    center = transforms.CenterCrop(center_crop_sz)
    depth_center = transforms.DepthCenterCrop(out_z)
    series = [vol.astype(float) for vol in series]  # np.float was removed in NumPy 1.24
    for i in sub_rand:
        if in_z == 0:
            nascent_series = [vol[:, :, i] for vol in series]
            nascent_seg_series = [seg[:, :, i] for seg in seg_series]
            nascent_series = np.expand_dims(nascent_series, axis=0)
            nascent_seg_series = np.expand_dims(nascent_seg_series, axis=0)
        else:
            nascent_series = [
                vol[:, :, i - in_z:i + 1 + in_z] for vol in series
            ]
            assert nascent_series[0].shape[2] == in_z * 2 + 1
            nascent_series = [np.squeeze(np.split(v,\
                    v.shape[2], axis=2)) for v in nascent_series]

            nascent_seg_series = [seg[:,:,i-in_z:i+1+in_z] for seg in \
                    seg_series]
            nascent_seg_series = [depth_center.engage(s) for s in \
                    nascent_seg_series]
            nascent_seg_series = [np.squeeze(np.split(s,\
                    s.shape[2], axis=2)) for s in nascent_seg_series]

            if out_z == 1:
                nascent_seg_series = \
                        np.expand_dims(nascent_seg_series, axis=0)

        if txforms is not None:
            for j in txforms:
                nascent_series, nascent_seg_series = \
                        j.engage(nascent_series, nascent_seg_series)

            vols.append(np.squeeze(nascent_series))

            segs.append(np.squeeze(center.engage(nascent_seg_series, \
                    out_z > 1)))

    return vols, segs
Example #13
    def _make_test_transform(self, crop_type, crop_size_img, crop_size_label,
                             pad_size):
        test_transform_ops = self.basic_transform_ops.copy()
        if pad_size is not None:
            test_transform_ops.append(transforms.Pad(pad_size, 0))
        if crop_type == 'center':
            test_transform_ops.append(
                transforms.CenterCrop(crop_size_img, crop_size_label))
        elif crop_type is None:
            pass
        else:
            raise RuntimeError('Unknown test crop type.')

        return transforms.Compose(test_transform_ops)
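The build-a-list-then-Compose pattern above is useful whenever pipeline stages are optional. A standalone version using stock torchvision transforms instead of the project's own ops (names here are illustrative):

import torchvision.transforms as T

def make_test_transform(pad_size=None, crop_size=None):
    ops = []
    if pad_size is not None:
        ops.append(T.Pad(pad_size, fill=0))
    if crop_size is not None:
        ops.append(T.CenterCrop(crop_size))
    ops.append(T.ToTensor())
    return T.Compose(ops)

tf = make_test_transform(pad_size=4, crop_size=224)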
Example #14
def val_transform(rgb, depth):
    # perform 1st part of data augmentation
    transform = transforms.Compose([
        transforms.Resize(float(image_size) / iheight),
        transforms.CenterCrop((oheight, owidth)),
    ])
    rgb_np = transform(rgb)
    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    rgb_np = normalize(rgb_np)  # from [0,1] to [-1,1]

    depth_np = transform(depth)
    depth_np = depth_np / 10.0

    return rgb_np, depth_np
Example #15
def test(valdir, bs, sz, rect_val=False):
    if rect_val:
        idx_ar_sorted = sort_ar(valdir)
        idx_sorted, _ = zip(*idx_ar_sorted)
        idx2ar = map_idx2ar(idx_ar_sorted, bs)

        ar_tfms = [transforms.Resize(int(sz * 1.14)), CropArTfm(idx2ar, sz)]
        val_dataset = ValDataset(valdir, transform=ar_tfms)
        return PaddleDataLoader(val_dataset,
                                concurrent=1,
                                indices=idx_sorted,
                                shuffle=False).reader()

    val_tfms = [transforms.Resize(int(sz * 1.14)), transforms.CenterCrop(sz)]
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))

    return PaddleDataLoader(val_dataset).reader()
Example #16
def val_transform(rgb, depth):
    depth_np = depth

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        #transforms.Resize(528.0 / iheight),
        transforms.Resize(240.0 / iheight),
        transforms.CenterCrop((oheight, owidth)),
    ])
    rgb_np = transform(rgb)
    # manual addition: force the RGB to 512x512 (nearest-neighbour)
    rgb_np = cv2.resize(rgb_np, (512, 512), interpolation=cv2.INTER_NEAREST)
    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    depth_np = transform(depth_np)
    # manual addition: force the depth to 512x512 (nearest-neighbour)
    depth_np = cv2.resize(depth_np, (512, 512),
                          interpolation=cv2.INTER_NEAREST)
    return rgb_np, depth_np
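INTER_NEAREST is the right call for the depth map: linear resampling averages across depth discontinuities and fabricates depths that occur nowhere in the scene, while nearest-neighbour only ever returns values that exist. A small demonstration:

import numpy as np
import cv2

edge = np.array([[1.0, 9.0]], dtype=np.float32)  # a depth discontinuity
print(cv2.resize(edge, (4, 1), interpolation=cv2.INTER_LINEAR))   # blended depths appear
print(cv2.resize(edge, (4, 1), interpolation=cv2.INTER_NEAREST))  # only 1.0 and 9.0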
Example #17
def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=[7, 7]):
    center_resize = 600
    Normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    data_transforms = {
        'swap':
        transforms.Compose([
            transforms.Randomswap((swap_num[0], swap_num[1])),
        ]),
        'common_aug':
        transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomCrop((crop_reso, crop_reso)),
            transforms.RandomHorizontalFlip(),
        ]),
        'train_totensor':
        transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            # ImageNetPolicy(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val_totensor':
        transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'test_totensor':
        transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.CenterCrop((crop_reso, crop_reso)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'None':
        None,
    }
    return data_transforms
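transforms.Randomswap is not a stock torchvision op; it comes from this project's patched transforms and implements the DCL-style tile shuffle. A rough standalone sketch of the idea (the real op likely restricts how far a tile may move):

import random
from PIL import Image

def random_swap(img, grid=(7, 7)):
    w, h = img.size
    tw, th = w // grid[0], h // grid[1]
    tiles = [img.crop((i * tw, j * th, (i + 1) * tw, (j + 1) * th))
             for j in range(grid[1]) for i in range(grid[0])]
    random.shuffle(tiles)
    out = Image.new(img.mode, (tw * grid[0], th * grid[1]))
    for k, tile in enumerate(tiles):
        out.paste(tile, ((k % grid[0]) * tw, (k // grid[0]) * th))
    return out

swapped = random_swap(Image.new('RGB', (448, 448)))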
Example #18
    def _make_train_transform(self, crop_type, crop_size_img, crop_size_label,
                              rand_flip, mod_drop_rate, balance_rate, pad_size,
                              rand_rot90, random_black_patch_size,
                              mini_positive):
        train_transform_ops = self.basic_transform_ops.copy()

        train_transform_ops += [
            transforms.RandomBlack(random_black_patch_size),
            transforms.RandomDropout(mod_drop_rate),
            transforms.RandomFlip(rand_flip)
        ]
        if pad_size is not None:
            train_transform_ops.append(transforms.Pad(pad_size, 0))

        if rand_rot90:
            train_transform_ops.append(transforms.RandomRotate2d())

        if crop_type == 'random':
            if mini_positive:
                train_transform_ops.append(
                    transforms.RandomCropMinSize(crop_size_img, mini_positive))
            else:
                train_transform_ops.append(
                    transforms.RandomCrop(crop_size_img))
        elif crop_type == 'balance':
            train_transform_ops.append(
                transforms.BalanceCrop(balance_rate, crop_size_img,
                                       crop_size_label))
        elif crop_type == 'center':
            train_transform_ops.append(
                transforms.CenterCrop(crop_size_img, crop_size_label))
        elif crop_type is None:
            pass
        else:
            raise RuntimeError('Unknown train crop type.')

        return transforms.Compose(train_transform_ops)
Example #19
def train_transform(rgb, depth):
    s = np.random.uniform(1.0, 1.5)  # random scaling
    # print("scale factor s={}".format(s))
    depth_np = depth / s
    angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    # perform 1st part of data augmentation
    transform = transforms.Compose([
        #transforms.Resize(530 / iheight), # this is for computational efficiency, since rotation is very slow
        transforms.Resize(250 / iheight),
        transforms.Rotate(angle),
        transforms.Resize(s),
        transforms.CenterCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])

    rgb_np = transform(rgb)
    # manual addition: force the RGB to 512x512 (nearest-neighbour)
    # rgb_np = Transform.resize(rgb_np, [512, 512])
    rgb_np = cv2.resize(rgb_np, (512, 512), interpolation=cv2.INTER_NEAREST)
    # random color jittering
    rgb_np = color_jitter(rgb_np)

    rgb_np = np.asfarray(rgb_np, dtype='float') / 255
    depth_np = transform(depth_np)

    # manual addition: force the depth to 512x512 (nearest-neighbour)
    depth_np = cv2.resize(depth_np, (512, 512),
                          interpolation=cv2.INTER_NEAREST)
    # depth_np = Transform.resize(depth_np, [512, 512])
    # debug: uncomment to view the augmented RGB
    # data = rgb_np * 255
    # data = Image.fromarray(data.astype(np.uint8))
    # data.show()
    return rgb_np, depth_np
Example #20
pred = 8
top_3 = [9, 8, 7]
out = np.zeros(10)
# Load model
print('Loading model...')

curr_folder = 'models_jester'
model = FullModel(batch_size=1, seq_lenght=16)  # 'seq_lenght' (sic) is the model's keyword
loaded_dict = torch.load(curr_folder + '/demo.ckp')
model.load_state_dict(loaded_dict)
model = model.cuda()
model.eval()

std, mean = [0.2674, 0.2676, 0.2648], [0.4377, 0.4047, 0.3925]
transform = Compose([
    t.CenterCrop((96, 96)),
    t.ToTensor(),
    t.Normalize(std=std, mean=mean),
])

print('Starting prediction')

s = time.time()
n = 0
hist = []
mean_hist = []
setup = True
plt.ion()
fig, ax = plt.subplots()
cooldown = 0
eval_samples = 2
Example #21
def test_net_cheap(test_volpath, test_segpath, mult_inds, in_z, model,\
        t_transform_plan, orig_dim, batch_size, out_file, num_labels,\
        num_labels_final, volpaths, segpaths,\
        nrrd=True, vol_only=False, get_dice=False, make_niis=False,\
        verbose=True):

    t_out_z, t_center_crop_sz = get_out_size(orig_dim, in_z,\
            t_transform_plan, model)
    t_center = transforms.CenterCrop(t_center_crop_sz)

    dices = []
    jaccards = []
    hds = []
    assds = []
    dice_inds = []
    times = []
    for ind in range(len(mult_inds)):
        t0 = time.time()
        if vol_only:
            series, seg_series = open_double_vol(volpaths[ind])
            seg_series = [a*0 for a in seg_series]
        else:
            series, seg_series = preprocess.get_nii_nrrd(volpaths[ind],\
                    segpaths[ind])
        num_slices = np.arange(np.shape(series[0])[2])
        if in_z != 0:
            num_slices = num_slices[in_z:-in_z]

        slice_inds = num_slices
        for slice_ind in slice_inds:
            assert slice_ind >= np.min(num_slices)\
                    and slice_ind <= np.max(num_slices)

        tout, tvol, tseg = get_subvols_cheap(series, seg_series, slice_inds,\
                in_z, t_out_z, t_center_crop_sz, model, num_labels,\
                batch_size, t_transform_plan, verbose=verbose)
        duration = time.time() - t0
        tseg = np.clip(tseg, 0,2)
        times.append(duration)


        if get_dice:
            hd, assd = get_dists_non_volumetric(tseg.astype(np.int64),\
                    np.argmax(tout, axis=0))
            tseg_hot = get_hot(tseg, num_labels_final)
            tout_hot = np.argmax(tout,axis=0)
            tout_hot = np.clip(tout_hot, 0,1)
            tout_hot = get_hot(tout_hot, num_labels_final)
            dce = dice(tseg_hot[1:],tout_hot[1:])
            jc = jaccard(tseg_hot[1:], tout_hot[1:])

            if verbose:
                print(('\r{}: Duration: {:.2f} ; Dice: {:.2f} ; Jaccard: {:.2f}' +\
                        ' ; Hausdorff: {:.2f} ; ASSD: {:.2f}').format(\
                        mult_inds[ind], duration, dce, jc, np.mean(hd),\
                        np.mean(assd)))
            jaccards.append(jc)
            dices.append(dce)
            hds.append(hd)
            assds.append(assd)
            dice_inds.append(mult_inds[ind])
        else:
            if verbose:
                print('\r{}'.format(mult_inds[ind]))


        com_fake = []
        com_real = []

        if make_niis:
            # out_out = tout
            out_out = np.zeros_like(tout[0])
            maxes = np.argmax(tout, axis=0)
            sparse_maxes = sparsify(maxes)

            for i in range(sparse_maxes.shape[1]):
                lw1, num1 = measurements.label(sparse_maxes[1,i])
                area1 = measurements.sum(sparse_maxes[1,i],lw1,\
                        index=np.arange(lw1.max() + 1))
                areaImg1 = area1[lw1]
                sparse_maxes[1,i] = np.where(areaImg1 < np.max(areaImg1), 0, 1)
                com_lateral = list(measurements.center_of_mass(sparse_maxes[1,i]))

                lw2, num2 = measurements.label(sparse_maxes[2,i])
                area2 = measurements.sum(sparse_maxes[2,i],lw2,\
                        index=np.arange(lw2.max() + 1))
                areaImg2 = area2[lw2]
                sparse_maxes[2,i] = np.where(areaImg2 < np.max(areaImg2), 0, 1)
                com_septal = list(measurements.center_of_mass(sparse_maxes[2,i]))
                com_fake.append(com_lateral + com_septal)

            maxes = np.argmax(sparse_maxes, axis=0)
            out_out = np.flip(maxes, -1)
            out_out = np.rot90(out_out, k=-1, axes=(-2,-1))
            out_out = np.transpose(out_out,[1,2,0])
            write_nrrd(out_out.astype(np.uint8), \
                    out_file + '/tout-{}.seg.nrrd'.format(\
                    mult_inds[ind]))

            seg_out = tseg
            sparse_seg = sparsify(tseg.astype(np.uint8))
            for i in range(sparse_seg.shape[1]):
                com_lateral_seg = list(measurements.center_of_mass(sparse_seg[1,i]))
                com_septal_seg = list(measurements.center_of_mass(sparse_seg[2,i]))
                com_real.append(com_lateral_seg + com_septal_seg)

            seg_out = np.flip(seg_out, -1)
            seg_out = np.rot90(seg_out, k=-1, axes=(-2,-1))
            seg_out = np.transpose(seg_out,[1,2,0])
            write_nrrd(seg_out.astype(np.uint8), \
                    out_file + '/tseg-{}.seg.nrrd'.format(\
                    mult_inds[ind]))

            tv = np.stack(t_center.engage(np.expand_dims(tvol, 0),True))
            vol_out = tv
            vol_out = np.flip(vol_out, -1)
            vol_out = np.rot90(vol_out, k=-1, axes=(-2,-1))
            vol_out = np.transpose(vol_out,[1,2,0])
            vol_out = nib.Nifti1Image(vol_out, np.eye(4))
            nib.save(vol_out, \
                    out_file + '/tvol-{}.nii'.format(\
                    mult_inds[ind]))
        # print('Jaccard summary: ' + str(get_CI(jaccards)))
            for a, b in zip(com_real, com_fake):
                a.extend(b)
            com_merged = com_real
            com_merged = [[round(b, 2) for b in a] for a in com_merged]
            headers = ['real_y_l', 'real_x_l', 'real_y_s', 'real_x_s', \
                    'fake_y_l', 'fake_x_l', 'fake_y_s', 'fake_x_s']
            df = pd.DataFrame(com_merged, columns=headers)
            df.to_csv(out_file + '/{}.csv'.format(mult_inds[ind]))

    # return vol_out, out_out, seg_out
    if get_dice:
        return np.array(dices), np.array(jaccards), np.array(hds), np.array(assds),\
                np.array(times)
    else:
        return
Example #22
def get_subvols_cheap(series, seg_series, slice_inds, in_z, out_z, \
        center_crop_sz, model, num_labels, batch_size, txforms=None,\
        verbose=True):

    # get beginning index of output since the z dim is smaller than vol
    z0 = (in_z*2+1 - out_z)//2

    sz = np.array([num_labels, slice_inds.shape[0]+2*in_z, center_crop_sz,\
            center_crop_sz])

    bigout = np.zeros(sz)
    bigvol = np.zeros(sz[1:])
    bigseg = np.zeros(sz[1:])

    center = transforms.CenterCrop(center_crop_sz)
    depth_center = transforms.DepthCenterCrop(out_z)
    vols = []
    segs = []
    batch_ind = 0
    absolute_ind = 0
    for i in slice_inds:
        if in_z == 0:
            nascent_series = [vol[:,:,i] for vol in series]
            nascent_seg_series = [seg[:,:,i] for seg in seg_series]
            nascent_series = np.expand_dims(nascent_series, axis=0)
            nascent_seg_series = np.expand_dims(nascent_seg_series, axis=0)
        else:
            nascent_series = [v[:,:,i-in_z:i+1+in_z] for v in series]
            assert nascent_series[0].shape[2]==in_z*2+1
            nascent_series = [np.squeeze(np.split(v,\
                    v.shape[2], axis=2)) for v in nascent_series]

            nascent_seg_series = [s[:,:,i-in_z:i+1+in_z] for s in seg_series]
            nascent_seg_series = [depth_center.engage(s) for s in\
                    nascent_seg_series]
            nascent_seg_series = [np.squeeze(np.split(s,\
                    s.shape[2], axis=2)) for s in nascent_seg_series]

            if out_z == 1:
                nascent_seg_series = \
                        np.expand_dims(nascent_seg_series, axis=0)

        if txforms is not None:
            for j in txforms:
                nascent_series, nascent_seg_series = \
                        j.engage(nascent_series, nascent_seg_series)

            vols.append(np.squeeze(nascent_series))

            segs.append(np.squeeze(center.engage(nascent_seg_series, \
                    out_z > 1)))

            absolute_ind += 1

        if (absolute_ind >= batch_size or (i >= slice_inds[-1] and vols)):
            # nascent_vol = np.array(vols).squeeze()
            # nascent_seg = np.array(segs).squeeze()
            nascent_series = np.array(vols)
            nascent_seg_series = np.array(segs)
            nascent_series = preprocess.rot_and_flip(nascent_series)
            nascent_seg_series = preprocess.rot_and_flip(nascent_seg_series)
            nascent_series = nascent_series-np.min(nascent_series)
            nascent_series = nascent_series/np.max(nascent_series)

            if len(nascent_series.shape) < 4:
                nascent_series = np.expand_dims(nascent_series, 0)

            tv = torch.from_numpy(nascent_series).float()
            tv = Variable(tv).cuda()
            # print(i)
            if verbose:
                sys.stdout.write('\r{:.2f}%'.format(100 * i / sz[1]))
                sys.stdout.flush()
            if in_z == 0:
                tv = tv.permute(1,0,2,3)
            tout = model(tv).data.cpu().numpy().astype(np.int8)
            if in_z == 0:
                nascent_series = nascent_series.squeeze()
                if np.array(nascent_series.shape).shape[0] < 3:
                    nascent_series = np.expand_dims(nascent_series, 0)
            for j in range(len(nascent_series)):

                bsz = len(nascent_series)
                beg = i - in_z + z0 - bsz + j + 1
                end = i - in_z + z0 - bsz + j + out_z + 1
                bigout[:,beg:end] += np.expand_dims(tout[j], 1)
                bigseg[beg:end] = nascent_seg_series[j]

                beg = i - in_z + 1 - bsz + j
                end = i + in_z - bsz + j + 2
                bigvol[beg:end] = nascent_series[j]

            absolute_ind = 0
            batch_ind += 1
            vols = []
            segs = []

    return bigout, bigvol, bigseg
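The z0 offset at the top is the subtle part: the model consumes a stack of in_z*2+1 slices but emits only out_z, so the output window has to be centred inside the input stack before results are written back into bigout. A quick worked check with hypothetical sizes:

# 2 context slices per side -> input stack of 5; an output depth of 1
# should land on the middle slice, i.e. offset 2 from the top of the stack
in_z, out_z = 2, 1
z0 = (in_z * 2 + 1 - out_z) // 2
assert z0 == 2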
Example #23
def open_nii(volpaths, segpaths, ind, num, in_z, out_z, center_crop_sz,\
        series_names, seg_series_names, txforms=None,nrrd=True):
    vols = []
    segs = []
    # volpath = os.path.join(volpaths, 'volume-' + str(ind) + '.nii')
    if nrrd:
        # segpath = os.path.join(segpath, 'segmentation-' + str(ind) + '.seg.nrrd')
        series, seg_series = get_nii_nrrd(volpaths, segpaths)
    # else:
        # segpath = os.path.join(segpath, 'segmentation-' + str(ind) + '.nii')
        # vol, seg = get_nii_nii(volpaths, segpaths)
    assert np.shape(series)[3] == np.shape(seg_series)[3]
    num_slices = np.arange(np.shape(series[0])[2])
    if in_z != 0:
        num_slices = num_slices[in_z:-in_z]
    if num_slices.size <= 2 and in_z != 0:
        print(ind)
        print(num_slices)
        print(num)
    sub_rand = np.random.choice(num_slices, size=num, replace=False)

    center = transforms.CenterCrop(center_crop_sz)
    depth_center = transforms.DepthCenterCrop(out_z)
    series = [vol.astype(float) for vol in series]  # np.float was removed in NumPy 1.24
    for i in sub_rand:
        if in_z == 0:
            nascent_series = [vol[:,:,i] for vol in series]
            nascent_seg_series = [seg[:,:,i] for seg in seg_series]
            nascent_series = np.expand_dims(nascent_series, axis=0)
            nascent_seg_series = np.expand_dims(nascent_seg_series, axis=0)
        else:
            nascent_series = [vol[:,:,i-in_z:i+1+in_z] for vol in series]
            assert nascent_series[0].shape[2] == in_z*2+1
            nascent_series = [np.squeeze(np.split(v,\
                    v.shape[2], axis=2)) for v in nascent_series]

            nascent_seg_series = [seg[:,:,i-in_z:i+1+in_z] for seg in \
                    seg_series]
            nascent_seg_series = [depth_center.engage(s) for s in \
                    nascent_seg_series]
            nascent_seg_series = [np.squeeze(np.split(s,\
                    s.shape[2], axis=2)) for s in nascent_seg_series]

            if out_z == 1:
                nascent_seg_series = \
                        np.expand_dims(nascent_seg_series, axis=0)

        if txforms is not None:
            for j in txforms:
                nascent_series, nascent_seg_series = \
                        j.engage(nascent_series, nascent_seg_series)
                # sanity check: geometric transforms must keep the
                # segmentation labels integral (no interpolation blending)
                bad = False
                for s in nascent_seg_series:
                    if np.mod(np.max(s), 1) != 0:
                        bad = True
                if bad:
                    print(j)

            vols.append(np.squeeze(nascent_series))

            segs.append(np.squeeze(center.engage(nascent_seg_series, \
                    out_z > 1)))

    return vols, segs
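The np.mod(m, 1) guard flags transforms that interpolate the segmentation: blending integer class ids produces fractional values that correspond to no class. A small demonstration of the failure it catches:

import numpy as np

labels = np.array([0.0, 2.0])                  # two valid integer class ids
blended = 0.75 * labels[0] + 0.25 * labels[1]  # what bilinear blending does
print(blended)                                 # 0.5 -- not a valid label
print(np.mod(blended, 1) != 0)                 # True: the mask was corrupted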
Example #24
 def train(self):
     use_cuda = torch.cuda.is_available()
     path = os.path.join('./out_models/' + self.model_name + '_' +
                         self.task_name + '_' + self.job_id)
     ## get logger
     logger = self.get_logger(self.model_name, self.task_name, self.job_id,
                              path)
     logger.info("Job_id : {}".format(self.job_id))
     logger.info("gpus_device_ids : {}".format(self.device_ids))
     logger.info("Task Name : {}".format(self.task_name))
     logger.info("Backbone_name : {}".format(self.model_name))
     logger.info("input_shape : ({},{}.{})".format(self.input_shape[0],
                                                   self.input_shape[1],
                                                   self.input_shape[2]))
     logger.info("batch_size : {}".format(self.batch_size))
     logger.info("num_epochs : {}".format(self.num_epochs))
     logger.info("warmup_steps : {}".format(self.warmup_steps))
     logger.info("resume_from : {}".format(self.resume_from))
     logger.info("pretrained : {}".format(self.pretrained))
     logger.info("mixup : {}".format(self.mixup))
     logger.info("cutmix : {}".format(self.cutmix))
     ## tensorboard writer
     log_dir = os.path.join(path, "{}".format("tensorboard_log"))
     if not os.path.isdir(log_dir):
         os.mkdir(log_dir)
     writer = SummaryWriter(log_dir)
     ## get model of train
     net = get_model(self.model_name)
     net = torch.nn.DataParallel(net, device_ids=self.device_ids)
     net = net.cuda(device=self.device_ids[0])
     ## loss
     criterion = nn.CrossEntropyLoss()
     ## optimizer
     if self.optimizers == 'SGD':
         optimizer = optim.SGD(net.parameters(),
                               lr=self.init_lr,
                               momentum=0.9,
                               weight_decay=self.weight_decay)
     elif self.optimizers == 'Adam':
         optimizer = optim.Adam(net.parameters(),
                                lr=self.init_lr,
                                weight_decay=self.weight_decay)
     milestones = [80, 150, 200, 300]
     scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                milestones=milestones,
                                                gamma=0.1)
     logger.info(("===========opti=========="))
     logger.info("Optimizer:{}".format(self.optimizers))
     logger.info("lr:{}".format(self.init_lr))
     logger.info("weight_decay:{}".format(self.weight_decay))
     logger.info("lr_scheduler: MultiStepLR")
     logger.info("milestones:{}".format(milestones))
     ## augmentation
     normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                      std=[0.5, 0.5, 0.5])
     ## train aug
     transform_train = transforms.Compose([
         transforms.RandomCrop(int(self.input_shape[-1])),
         transforms.RandomHorizontalFlip(),
         transforms.RandomBrightness(brightness=self.brightness,
                                     brightness_ratio=self.brightness_ratio),
         transforms.RandomBlur(blur_ratio=self.blur_ratio),
         transforms.RandomRotation(degrees=self.degrees, rotation_ratio=0.1),
         transforms.ColorJitter(brightness=self.color_brightnesss,
                                contrast=self.color_contrast,
                                saturation=self.color_saturation, hue=0),
         transforms.ToTensor(),
         #normalize,
     ])
     ## test aug
     transform_test = transforms.Compose([
         transforms.CenterCrop(int(self.input_shape[-1])),
         transforms.ToTensor(),
         #normalize,
     ])
     logger.info(("============aug==========="))
     logger.info("crop: RandomCrop")
     logger.info("RandomHorizontalFlip: True")
     logger.info("brightness:{}".format(self.brightness))
     logger.info("brightness_ratio:{}".format(self.brightness_ratio))
     logger.info("blur_ratio:{}".format(self.blur_ratio))
     logger.info("degrees:{}".format(self.degrees))
     logger.info("color_brightnesss:{}".format(self.color_brightnesss))
     logger.info("color_contrast:{}".format(self.color_contrast))
     logger.info("color_saturation:{}".format(self.color_saturation))
     ## prepare data
     print('==> Preparing data..')
     logger.info(("==========Datasets========="))
     logger.info("train_datasets:{}".format(self.train_datasets))
     logger.info("val_datasets:{}".format(self.val_datasets))
     logger.info("test_datasets:{}".format(self.test_datasets))
     #trainset = DataLoader(split = 'Training', transform=transform_train)
     trainset = DataLoader(self.train_datasets,
                           self.val_datasets,
                           self.test_datasets,
                           split='Training',
                           transform=transform_train)
     trainloader = torch.utils.data.DataLoader(trainset,
                                               batch_size=self.batch_size *
                                               len(self.device_ids),
                                               shuffle=True)
     Valset = DataLoader(self.train_datasets,
                         self.val_datasets,
                         self.test_datasets,
                         split='Valing',
                         transform=transform_test)
     Valloader = torch.utils.data.DataLoader(Valset,
                                             batch_size=64 *
                                             len(self.device_ids),
                                             shuffle=False)
     Testset = DataLoader(self.train_datasets,
                          self.val_datasets,
                          self.test_datasets,
                          split='Testing',
                          transform=transform_test)
     Testloader = torch.utils.data.DataLoader(Testset,
                                              batch_size=64 *
                                              len(self.device_ids),
                                              shuffle=False)
     ## train
     logger.info(("======Begain Training======"))
     #self.train_model(net, criterion, optimizer, scheduler, trainloader, Valloader, Testloader, logger, writer, path)
     self.train_model(net, criterion, optimizer, scheduler, trainloader,
                      Valloader, Testloader, logger, writer, path)
     logger.info(("======Finsh Training !!!======"))
     logger.info(("best_val_acc_epoch: %d, best_val_acc: %0.3f" %
                  (self.best_Val_acc_epoch, self.best_Val_acc)))
     logger.info(("best_test_acc_epoch: %d, best_test_acc: %0.3f" %
                  (self.best_Test_acc_epoch, self.best_Test_acc)))
Example #25
        df['is_manip'] = 0
        df = df[df['target'].notnull()]
        df['to_rotate'] = 0
        return df

    return None


train_transform = Compose([
    albu_trans.RandomCrop(target_size),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = Compose([
    albu_trans.CenterCrop(target_size),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def add_args(parser):
    arg = parser.add_argument
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=4)
    arg('--n-epochs', type=int, default=30)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--device-ids', type=str, help='For example 0,1 to run on two GPUs')
    arg('--model', type=str)
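Since add_args only registers options, it composes with whatever argparse parser the caller already has. Typical usage (the option values here are illustrative):

import argparse

parser = argparse.ArgumentParser()
add_args(parser)
args = parser.parse_args(['--batch-size', '8', '--device-ids', '0,1'])
print(args.batch_size, args.device_ids)  # 8 0,1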
Example #26
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(brightness=.5,contrast=.9,saturation=.5,hue=.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    transform_push = transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # train set
    train_dataset = datasets.ImageFolder(train_dir, transform_train)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=False)
    # push set
    train_push_dataset = datasets.ImageFolder(train_push_dir, transform_push)
    train_push_loader = torch.utils.data.DataLoader(
        train_push_dataset,
Example #27
def main(args):
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)
    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                            std=[0.22803, 0.22145, 0.216989])

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)),
        T.RandomHorizontalFlip(), normalize,
        T.RandomCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)

    transform_test = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Resize((128, 171)), normalize,
        T.CenterCrop((112, 112))
    ])

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache "
                  "on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15)
        if args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips,
                                      args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips,
                                      args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   collate_fn=collate_fn)

    print("Creating model")
    model = torchvision.models.video.__dict__[args.model](
        pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    lr = args.lr * args.world_size
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)

    # convert scheduler to be per iteration, not per epoch, for warmup that lasts
    # between different epochs
    warmup_iters = args.lr_warmup_epochs * len(data_loader)
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     milestones=lr_milestones,
                                     gamma=args.lr_gamma,
                                     warmup_iters=warmup_iters,
                                     warmup_factor=1e-5)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
                        device, epoch, args.print_freq, args.apex)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
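The milestone conversion near the middle simply rescales epoch counts into iteration counts so the warmup scheduler can step every batch rather than every epoch. With hypothetical numbers:

batches_per_epoch = 1000                  # stand-in for len(data_loader)
lr_warmup_epochs, epoch_milestones = 10, [20, 30]
warmup_iters = lr_warmup_epochs * batches_per_epoch           # 10000
lr_milestones = [batches_per_epoch * m for m in epoch_milestones]
print(warmup_iters, lr_milestones)        # 10000 [20000, 30000]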
Example #28
import utils

cut_size = 44
bs = 1
model_path = 'FER2013_Resnet18/PublicTest_model.t7'

epls = [0.001 * i for i in range(11)]

transform_train = transforms.Compose([
    transforms.RandomCrop(44),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.CenterCrop(cut_size),
    transforms.ToTensor(),
])

transform_eval = transforms.Compose([
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
])
tbs = 32
trainset = FER2013(split='Training', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=tbs,
                                          shuffle=True,
                                          num_workers=1)
PublicTestset = FER2013(split='PublicTest', transform=transform_eval)
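TenCrop is why transform_eval needs the stacking Lambda: one image becomes a (10, C, H, W) tensor, so a batch arrives as (B, 10, C, H, W) and must be flattened before the forward pass, then averaged over crops afterwards. A sketch of that evaluation step, reusing transform_eval from above (the model call is left commented since no net is built here):

from PIL import Image

img = Image.new('L', (48, 48))
crops = transform_eval(img)       # (10, 1, 44, 44)
batch = crops.unsqueeze(0)        # (1, 10, 1, 44, 44) for a single image
b, ncrops, c, h, w = batch.shape
flat = batch.view(-1, c, h, w)    # (10, 1, 44, 44) goes through the model
# outputs = net(flat).view(b, ncrops, -1).mean(1)  # average the ten crops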
Example #29
from PIL import Image
import transforms

size = 64
channel = 3
max_no = 202599
img_key = 'img_raw'
file_tpl = '%6d.jpg'
home_dir = os.path.expanduser('~')
celeb_source = os.path.join(home_dir, "Pictures/img_align_celeba")

default_attribs = {img_key: tf.FixedLenFeature([], tf.string)}

default_transf = transforms.Compose([
    transforms.Scale(size),
    transforms.CenterCrop(size),
    transforms.ToFloat(),
    transforms.Normalize(0.5, 0.5)
])


def process_celebA(dest='celebA',
                   celeb_source=celeb_source,
                   force=False,
                   transform=default_transf,
                   files=None):
    dest_file = '%s.tfr' % dest
    if os.path.exists(dest_file) and not force:
        return dest_file

    print('Processing celeb data into a TensorFlow Record file. It may '
          'take a while depending on your computer speed...')
def feature_extractor():
    # loading net
    with tf.variable_scope('RGB'):
        net = i3d.InceptionI3d(400, spatial_squeeze=True, final_endpoint='Logits')
    rgb_input = tf.placeholder(tf.float32, shape=(batch_size, _SAMPLE_VIDEO_FRAMES, _IMAGE_SIZE, _IMAGE_SIZE, 3))
    
    _, end_points = net(rgb_input, is_training=False, dropout_keep_prob=1.0)
    end_feature = end_points['avg_pool3d']
    sess = tf.Session()

    transform = torchvision.transforms.Compose([
        T.ToFloatTensorInZeroOne(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        T.Resize((224, 224)),
        T.CenterCrop((224, 224))
    ])

    # rgb_input = tf.placeholder(tf.float32, shape=(1, _SAMPLE_VIDEO_FRAMES, _IMAGE_SIZE, _IMAGE_SIZE, 3))
    # with tf.variable_scope('RGB'):
    #   rgb_model = i3d.InceptionI3d(
    #       400, spatial_squeeze=True, final_endpoint='Logits')
    #   rgb_logits, _ = rgb_model(
    #       rgb_input, is_training=False, dropout_keep_prob=1.0)

    rgb_variable_map = {}
    for variable in tf.global_variables():
        if variable.name.split('/')[0] == 'RGB':
            rgb_variable_map[variable.name.replace(':0', '')] = variable
    rgb_saver = tf.train.Saver(var_list=rgb_variable_map, reshape=True)

    rgb_saver.restore(sess, _CHECKPOINT_PATHS['rgb_imagenet'])
    
    video_list = open(VIDEO_PATH_FILE).readlines()
    video_list = [name.strip() for name in video_list]
    print('video_list', video_list)
    if not os.path.isdir(OUTPUT_FEAT_DIR):
        os.mkdir(OUTPUT_FEAT_DIR)

    print('Total number of videos: %d'%len(video_list))

    for cnt, video_name in enumerate(video_list):
        # video_path = os.path.join(VIDEO_DIR, video_name)
        video_path = os.path.join(VIDEO_DIR, video_name+'.avi')
        feat_path = os.path.join(OUTPUT_FEAT_DIR, video_name + '.npy')

        if os.path.exists(feat_path):
            print('Feature file for video %s already exists.'%video_name)
            continue

        print('video_path', video_path)

        vframes, _, info = torchvision.io.read_video(video_path, start_pts=0, end_pts=None, pts_unit='sec')
        vframes = T.frame_temporal_sampling(
            vframes, start_idx=0, end_idx=None,
            num_samples=int(round(len(vframes) / info['video_fps'] * 24)))
        vframes = transform(vframes).permute(1, 2, 3, 0).numpy()
        n_frame = vframes.shape[0]

        print('Total frames: %d'%n_frame)

        features = []

        n_feat = int(n_frame // 8)
        n_batch = n_feat // batch_size + 1
        print('n_frame: %d; n_feat: %d'%(n_frame, n_feat))
        print('n_batch: %d'%n_batch)

        for i in range(n_batch):
            input_blobs = []
            for j in range(batch_size):
                start_idx = i*batch_size*L + j*L if i==0 else i*batch_size*L + j*L - 8
                end_idx = min(n_frame, start_idx+L)
                input_blob = vframes[start_idx:end_idx].reshape(-1, resize_w, resize_h, 3)

                # input_blob = []
                # for k in range(L):
                #     idx = i*batch_size*L + j*L + k
                #     idx = int(idx)
                #     idx = idx%n_frame + 1
                #     frame = vframes[idx-1]
                #     # image = Image.open(os.path.join('/data/home2/hacker01/Share/Data/TACoS/images_256p/{}'.format(video_name), '%d.jpg'%idx))
                #     # image = image.resize((resize_w, resize_h))
                #     # image = np.array(image, dtype='float32')
                #     '''
                #     image[:, :, 0] -= 104.
                #     image[:, :, 1] -= 117.
                #     image[:, :, 2] -= 123.
                #     '''
                #     # image[:, :, :] -= 127.5
                #     # image[:, :, :] /= 127.5
                #     input_blob.append(frame)
                #
                # input_blob = np.array(input_blob, dtype='float32')

                input_blobs.append(input_blob)

            input_blobs = np.array(input_blobs, dtype='float32')

            clip_feature = sess.run(end_feature, feed_dict={rgb_input: input_blobs})
            clip_feature = np.reshape(clip_feature, (-1, clip_feature.shape[-1]))
            
            features.append(clip_feature)

        features = np.concatenate(features, axis=0)
        # features = features[:n_feat:2]   # 16 frames per feature  (since 64-frame snippet corresponds to 8 features in I3D)

        feat_path = os.path.join(OUTPUT_FEAT_DIR, video_name + '.npy')

        print('Saving features and probs for video: %s ...'%video_name)
        np.save(feat_path, features)
        
        print('%d: %s has been processed...'%(cnt, video_name))