def export_all_contours(contours, data_path):
    print('\nProcessing {:d} images and labels ...\n'.format(len(contours)))
    images = np.zeros((len(contours), SIZE, SIZE, 1))
    masks = np.zeros((len(contours), SIZE, SIZE, 1))
    for idx, contour in enumerate(contours):
        img, mask = read_contour(contour, data_path)
        if img.shape[0] > SIZE:
            img = center_crop(img, SIZE)
            mask = center_crop(mask, SIZE)
        images[idx] = img
        masks[idx] = mask

    return images, masks
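None of the snippets on this page include the helper itself; as a minimal sketch (hypothetical; each repository defines its own center_crop with a slightly different signature), a square NumPy centre crop matching the center_crop(img, SIZE) call above could look like this:

import numpy as np

def center_crop(img, size):
    # Hypothetical helper: take a size x size window from the centre of an
    # (H, W, C) array; assumes H >= size and W >= size.
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]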
 def forward(self, x):
     x = self.fully_connected(x)  # 4096
     x = x.reshape((-1, 256, 4, 4))  # 4x4x256
     x = self.deconv(x)  # 256x256x3
     x = center_crop(x, self.deconv_output_size,
                     self.desired_output_size)  # 227x227x3
     return x
 def forward(self, image, features):
     cropped_img = center_crop(image, current_size=227,
                               desired_size=224)  # 224x224x3
     x1 = self.conv(cropped_img)  # 1x1x256
     x1 = torch.flatten(x1, 1)  # 256
     x2 = self.features_fc(features)  # 512
     x = torch.cat((x1, x2), dim=1)  # 768
     x = self.fc(x)  # 1
     return x
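The two forward passes above instead call a tensor variant that takes the current and desired spatial sizes; a hedged PyTorch sketch of that signature (hypothetical, not the original implementation) crops only the last two dimensions:

import torch

def center_crop(x, current_size, desired_size):
    # Hypothetical tensor variant: crop the spatial dims of an N x C x H x W
    # tensor from current_size down to desired_size around the centre.
    off = (current_size - desired_size) // 2
    return x[..., off:off + desired_size, off:off + desired_size]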
Example No. 4
    def load_eval_dataset(self):
        (_, _), (x_test, self.y_test) = self.args.dataset.load_data()
        image_size = x_test.shape[1]
        x_test = np.reshape(x_test,[-1, image_size, image_size, 1])
        x_test = x_test.astype('float32') / 255
        x_eval = np.zeros([x_test.shape[0], *self.train_gen.input_shape])
        for i in range(x_eval.shape[0]):
            x_eval[i] = center_crop(x_test[i])

        self.x_test = x_eval
Example No. 5
    def forward(self, x):
        crop_h, crop_w = int(x.size()[-2]), int(x.size()[-1])
        x = self.stages[0](x)

        side = []
        side_out = []
        for i in range(1, len(self.stages)):
            x = self.stages[i](x)
            side_temp = self.side_prep[i - 1](x)
            side.append(
                center_crop(self.upscale[i - 1](side_temp), crop_h, crop_w))
            side_out.append(
                center_crop(
                    self.upscale_[i - 1](self.score_dsn[i - 1](side_temp)),
                    crop_h, crop_w))

        out = torch.cat(side[:], dim=1)
        out = self.fuse(out)
        side_out.append(out)
        return side_out
 def img_process(self, imgs, frames_num):  # imgs = the currently selected frames, frames_num = number of frames fed to the network
     images = np.zeros((frames_num, 224, 224, 3))
     orig_imgs = np.zeros_like(images)
     for i in range(frames_num):
         next_image = imgs[i]
         next_image = np.uint8(next_image)
         scaled_img = cv2.resize(next_image, (256, 256), interpolation=cv2.INTER_LINEAR)  # resize to 256x256
         cropped_img = center_crop(scaled_img)  # center crop 224x224
         final_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB)
         images[i] = final_img  # PyTorch expects RGB (unlike OpenCV's BGR), so use the converted image
         orig_imgs[i] = cropped_img  # keep the BGR copy for OpenCV use
     torch_imgs = torch.from_numpy(images.transpose(3, 0, 1, 2))  # channels / num frames / width / height
     torch_imgs = torch_imgs.float() / 255.0
     mean_3d = [124 / 255, 117 / 255, 104 / 255]
     std_3d = [0.229, 0.224, 0.225]
     for t, m, s in zip(torch_imgs, mean_3d, std_3d):
         t.sub_(m).div_(s)
     return np.expand_dims(orig_imgs, 0), torch_imgs.unsqueeze(0)  # return the original images for OpenCV, the tensor for PyTorch
Example No. 7
def layer_wise_relevance_propagation(conf):

    img_dir = conf["paths"]["image_dir"]
    res_dir = conf["paths"]["results_dir"]

    image_height = conf["image"]["height"]
    image_width = conf["image"]["width"]

    lrp = RelevancePropagation(conf)

    image_paths = list()
    for (dirpath, dirnames, filenames) in os.walk(img_dir):
        image_paths += [os.path.join(dirpath, file) for file in filenames]

    for i, image_path in enumerate(image_paths):
        print("Processing image {}".format(i+1))
        image = center_crop(np.array(Image.open(image_path)), image_height, image_width)
        relevance_map = lrp.run(image)
        plot_relevance_map(image, relevance_map, res_dir, i)
Example No. 8
def stylize_output(wct_model, content_img, styles):
    if args.crop_size > 0:
        styles = [center_crop(style) for style in styles]
    if args.keep_colors:
        styles = [preserve_colors_np(style, content_img) for style in styles]

    # Run the frame through the style network
    stylized = content_img
    for _ in range(args.passes):
        stylized = wct_model.predict(stylized, styles)

    # Stitch the style + stylized output together, but only if there's one style image
    if args.concat:
        # Resize style img to same height as frame
        style_img_resized = scipy.misc.imresize(
            styles[0], (stylized.shape[0], stylized.shape[0]))
        stylized = np.hstack([style_img_resized, stylized])

    return stylized
 def img_process(self, imgs, frames_num):
     images = np.zeros((frames_num, 224, 224, 3))
     orig_imgs = np.zeros_like(images)
     for i in range(frames_num):
         next_image = imgs[i]
         next_image = np.uint8(next_image)
         scaled_img = cv2.resize(
             next_image, (256, 256),
             interpolation=cv2.INTER_LINEAR)  # resize to 256x256
         cropped_img = center_crop(scaled_img)  # center crop 224x224
         final_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB)
         images[i] = final_img
         orig_imgs[i] = cropped_img
     torch_imgs = torch.from_numpy(images.transpose(3, 0, 1, 2))
     torch_imgs = torch_imgs.float() / 255.0
     mean_3d = [124 / 255, 117 / 255, 104 / 255]
     std_3d = [0.229, 0.224, 0.225]
     for t, m, s in zip(torch_imgs, mean_3d, std_3d):
         t.sub_(m).div_(s)
     return np.expand_dims(orig_imgs, 0), torch_imgs.unsqueeze(0)
Example No. 10
    def __getitem__(self, idx):
        self.current_item_path = self.inputs[idx]
        input_img = correct_dims(nib.load(self.inputs[idx]).get_data())
        label_img = gen_mask(self.labels[idx])

        # Resize the input image and label to size (self.size x self.size x self.size)
        if self.sampling_mode == "resize":
            ex, label = resize_img(input_img, label_img, self.size)

        # Constant center-crop sample of size (self.size x self.size x self.size)
        elif self.sampling_mode == 'center':
            ex, label = center_crop(input_img, label_img, self.size)

        # Find centers of lesion masks and crop image to include them
        # to measure consistent validation performance with small crops
        elif self.sampling_mode == "center_val":
            ex, label = find_and_crop_lesions(input_img, label_img, self.size,
                                              self.deterministic)

        # Randomly crop sample of size (self.size x self.size x self.size)
        elif self.sampling_mode == "random":
            ex, label = random_crop(input_img, label_img, self.size)

        else:
            print("Invalid sampling mode.")
            exit()

        ex = np.divide(ex, 255.0)
        label = np.array([(label > 0).astype(int)]).squeeze()
        # (experimental) APPLY RANDOM FLIPPING ALONG EACH AXIS
        if not self.deterministic:
            for i in range(3):
                if random() > 0.5:
                    ex = np.flip(ex, i)
                    label = np.flip(label, i)
        inputs = torch.from_numpy(ex.copy()).type(
            torch.FloatTensor).unsqueeze(0)
        labels = torch.from_numpy(label.copy()).type(
            torch.FloatTensor).unsqueeze(0)
        return inputs, labels
Example No. 11
def pre_process_input_data(data, sample=None):
    '''Apply transformation functions (normalization,
    frame sampling and center cropping) to a video array.

    Args:
        data--> a 4D array of video frames
        sample--> sample a fixed number of frames from the video (int)

    Returns:
        data --> a transformed 5D data array with batch size
    '''

    data = normalize(np.array(data))

    if sample is not None:
        data, _ = get_random_frames(data, label=0, num_frames=sample)

    data, _ = center_crop(data,
                          target_size=(config.get('HEIGHT'),
                                       config.get('WIDTH')))
    data = np.expand_dims(data, axis=0)

    return data.astype(np.float32)
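For reference, a hedged usage sketch of the function above, assuming frames is a (T, H, W, 3) video array and the config provides HEIGHT and WIDTH:

# Hypothetical call: normalises, optionally samples 16 frames, centre-crops,
# and adds a batch axis, giving a 5D array of shape
# (1, num_frames, HEIGHT, WIDTH, channels).
batch = pre_process_input_data(frames, sample=16)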
Example No. 12
def evaluate(model_path, file_dir, meta_file):
    '''Evaluates a trained model on the test samples and outputs the 
    overall and average accuracy values.
    Args:
        model_path--> path to the keras saved model
        file_dir --> path to the video files
        meta_file --> meta file containing the test video file names and labels
    Returns:
        report--> returns a text report showing main classification metrics using
                scikit-learn classification report method
    '''

    loaded_model = K.models.load_model(model_path, custom_objects={'B': B})
    test_data, test_label = get_file_list(file_dir,
                                          meta_file,
                                          file_type='.mp4')
    test_prediction = []

    for x, y in zip(test_data, test_label):
        video = read_video(x)
        video = normalize(video)
        video, _ = center_crop(video,
                               target_size=(config_data.get('HEIGHT'),
                                            config_data.get('WIDTH')))
        prob = get_pridiction(loaded_model, video)
        test_prediction.append(prob[0])

    label_map = read_json(LABEL_MAP_FILE)
    vocabulary = [
        label_map.get(str(i)) for i in range(config_data.get('NUM_CLASSES'))
    ]
    report = classification_report(test_label,
                                   test_prediction,
                                   target_names=vocabulary,
                                   zero_division=1)

    return report
Example No. 13
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints, 
                                relu_targets=args.relu_targets,
                                vgg_path=args.vgg_path, 
                                device=args.device,
                                ss_patch_size=args.ss_patch_size, 
                                ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'sytlized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else: # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else: # Single image file
        style_files = [args.style_path]

    print(style_files)
    import time
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    


    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)


        try:

            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:",out_v)
                if os.path.isfile(out_v):
                    print("SKIP" , out_v)
                    continue
                
                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp is True or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:',(time.time() - s))

            print('Processed in:',(time.time() - s))
 
        except Exception as e:
            print("EXCEPTION: ",e)
Example No. 14
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints, 
                                relu_targets=args.relu_targets,
                                vgg_path=args.vgg_path, 
                                device=args.device,
                                ss_patch_size=args.ss_patch_size, 
                                ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else: # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else: # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)
        
        for style_fullpath in style_files: 
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes-1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            
            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(count, time.time() - start))
Example No. 15
    def __getitem__(self, index):
        subj_file = self.datafiles[index]

        with h5py.File(self.dirname + subj_file, 'r') as fdset:
            h5data = fdset['kspace']
            sh = h5data.shape
            sliceindex = np.random.randint(0, sh[0])
            ksp_sli = h5data[sliceindex]

        C, H, W = ksp_sli.shape

        # Rotate ksp
        img_sli = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(ksp_sli,
                                                                 axes=(1, 2)),
                                                axes=(1, 2)),
                                   axes=(1, 2))
        ksp_sli = np.fft.fftshift(np.fft.fft2(img_sli, axes=(1, 2)),
                                  axes=(1, 2))

        if self.rss == 'norm_acl':
            #mean, std = mean_std_aclines(ksp_sli, 10)
            i_min, i_max = min_max_aclines(ksp_sli, 10)
            img_sli_singlec = np.sqrt(
                np.sum(np.square(np.abs(img_sli)), axis=0))
            #norm_img_sli_singlec = (img_sli_singlec-mean)/std
            #print(norm_img_sli_singlec.max(), norm_img_sli_singlec.mean(), norm_img_sli_singlec.min())
            norm_img_sli_singlec = (img_sli_singlec - i_min) / (i_max - i_min)

        elif not self.rss:
            try:
                est_coilmap = np.load(self.coil_path + subj_file +
                                      str(sliceindex) + '.npy')
            except:
                print('CANT FIND:',
                      self.coil_path + subj_file + str(sliceindex),
                      ' Creating new...')
                est_coilmap_gpu = mr.app.EspiritCalib(ksp_sli,
                                                      calib_width=W,
                                                      thresh=0.02,
                                                      kernel_width=6,
                                                      crop=0.01,
                                                      max_iter=100,
                                                      show_pbar=False,
                                                      device=-1).run()
                est_coilmap = np.fft.fftshift(est_coilmap_gpu, axes=(1, 2))
                np.save(self.coil_path + subj_file + str(sliceindex),
                        est_coilmap)

            if np.isnan(np.min(est_coilmap)):
                print('COILMAP Contains nan. Create new...')
                est_coilmap_gpu = mr.app.EspiritCalib(ksp_sli,
                                                      calib_width=W,
                                                      thresh=0.02,
                                                      kernel_width=6,
                                                      crop=0.01,
                                                      max_iter=100,
                                                      show_pbar=False,
                                                      device=-1).run()
                est_coilmap = np.fft.fftshift(est_coilmap_gpu, axes=(1, 2))
                np.save(self.coil_path + subj_file + str(sliceindex),
                        est_coilmap)
                if np.isnan(np.min(est_coilmap)):
                    print('COILMAP still contains nan...')

            # Normalise
            img_sli_singlec = np.sum(
                np.conjugate(est_coilmap) * img_sli, axis=0) / (np.sum(
                    np.conjugate(est_coilmap) * est_coilmap, axis=0) + 1e-6)
            norm_fac = 1 / (np.percentile(
                np.abs(img_sli_singlec).flatten(), 80))
            norm_img_sli_singlec = norm_fac * np.abs(img_sli_singlec) * np.exp(
                1j * np.angle(img_sli_singlec))  #
        else:
            img_sli_singlec = np.sqrt(
                np.sum(np.square(np.abs(img_sli)), axis=0))
            norm_fac = 1 / (np.percentile(
                np.abs(img_sli_singlec).flatten(), 80))
            norm_img_sli_singlec = norm_fac * np.abs(img_sli_singlec)

        if self.crop:
            norm_img_sli_singlec = center_crop(norm_img_sli_singlec) + 1j * 0

        if self.noise > 0:
            norm_img_sli_singlec = (
                np.abs(norm_img_sli_singlec) +
                np.random.normal(loc=0,
                                 scale=1 / self.noise,
                                 size=norm_img_sli_singlec.shape)) * np.exp(
                                     1j * np.angle(norm_img_sli_singlec))

        if np.isnan(np.min(norm_img_sli_singlec)):
            norm_img_sli_singlec = np.zeros_like(norm_img_sli_singlec)

        # Create Patches
        subj_tens = torch.from_numpy(norm_img_sli_singlec.real).type(
            torch.complex64) + 1j * torch.from_numpy(
                norm_img_sli_singlec.imag).type(torch.complex64)
        patches = subj_tens.unfold(0, self.patchsize, self.patchsize).unfold(
            1, self.patchsize, self.patchsize)
        patches = torch.cat(
            patches.unbind())  # Patches is now [Num_patches, 28,28]

        return patches
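The unfold/unbind step above tiles the slice into non-overlapping square patches; a hedged illustration, assuming a 320x320 slice and self.patchsize = 28 (the values that appear in the run_inference example further down), yields an 11x11 grid flattened to 121 patches:

import torch

x = torch.randn(320, 320)                         # stand-in for the normalised slice
patches = x.unfold(0, 28, 28).unfold(1, 28, 28)   # shape (11, 11, 28, 28)
patches = torch.cat(patches.unbind())             # shape (121, 28, 28)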
Example No. 16
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    # Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

            style_img = get_img(style_fullpath)

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)

            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(
        count,
        time.time() - start))
Example No. 17
def freeze_graph_test(pb_file, content_path, style_path):
    '''
    :param pb_file: path to the .pb file
    :param content_path: path to the test image
    :return:
    '''
    content_prefix, content_ext = os.path.splitext(content_path)
    content_prefix = os.path.basename(
        content_prefix)  # Extract filename prefix without ext
    style_prefix, _ = os.path.splitext(style_path)
    style_prefix = os.path.basename(
        style_prefix)  # Extract filename prefix without ext

    # Load the test images content_img and style_img
    content_img = get_img(content_path)
    if args.content_size > 0:
        content_img = resize_to(content_img, args.content_size)

    style_img = get_img(style_path)

    if args.style_size > 0:
        style_img = resize_to(style_img, args.style_size)
    if args.crop_size > 0:
        style_img = center_crop(style_img, args.crop_size)

    # if args.keep_colors:
    #     style_img = preserve_colors_np(style_img, content_img)

    content_img = preprocess(content_img)
    style_img = preprocess(style_img)

    # Print all the variables in the checkpoint
    chkp.print_tensors_in_checkpoint_file("models/inference/model.ckpt",
                                          tensor_name='',
                                          all_tensors=False)

    # for i in range(5):
    #     print("=" * 10, i + 1)
    #     chkp.print_tensors_in_checkpoint_file("models/relu" + str(i+1) + "_1/model.ckpt-15002", tensor_name='', all_tensors=True)

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        # with open(pb_file, "rb") as f:
        #     x = f.read()
        #     # print(x)
        #     output_graph_def.ParseFromString(x)
        #     output_graph_def.ParseFromString(f.read())
        #     tf.import_graph_def(output_graph_def, name="")
        #     for i, n in enumerate(output_graph_def.node):
        #         print("Name of the node - %s" % n.name)
        #         print(n)

        with tf.Session() as sess:
            # sess.run(tf.global_variables_initializer())
            # tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)

            # Define the input tensor names, matching the network's input tensors
            content_input_tensor = sess.graph.get_tensor_by_name(
                "encoder_decoder_relu5_1/content_encoder_relu5_1/content_imgs:0"
            )
            style_input_tensor = sess.graph.get_tensor_by_name("style_img:0")

            # Define the output tensor name
            output_tensor_name = sess.graph.get_tensor_by_name(
                "encoder_decoder_relu5_1/decoder_relu5_1/decoder_model_relu5_1/relu5_1_16/relu5_1_16/BiasAdd:0"
            )

            # Check that the loaded model is correct; note that the names passed here are the input/output tensor names, not the operation names
            stylized_rgb = sess.run(output_tensor_name,
                                    feed_dict={
                                        content_input_tensor: content_img,
                                        style_input_tensor: style_img
                                    })

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            save_img(out_f, stylized_rgb)
Example No. 18
import glob
import os
import utils
import scipy.misc

PATH_IN = '../celebA'
PATH_OUT = '../celeba_norm'

os.makedirs(PATH_OUT, exist_ok=True)

data_files = glob.glob(os.path.join(PATH_IN, "*.jpg"))


for pin in data_files:
    pout = pin.replace('celebA', 'celeba_norm')
    f = scipy.misc.imread(pin)
    f = utils.center_crop(f, 108, 64)
    scipy.misc.imsave(pout, f)
    print(pout)


print('Done')
Example No. 19
def loop(data_info, config, split_stack=True, test_time=True):
    up_thres, low_thres = config.up_thres, config.low_thres
    all_ims = []
    for info in data_info:
        depth_in_path, depth_ref_path, color_path, mask_path = info
        name = os.path.basename(depth_in_path)
        raw = Image.open(depth_in_path)
        gt = Image.open(depth_ref_path)
        rgb = Image.open(color_path).convert('L').resize(raw.size)
        mask = Image.open(mask_path)
        assert raw.size == gt.size, 'gt size not match raw size!'

        # Do center crop here
        if not split_stack:
            if config.image_size < min(raw.size):
                raw, gt, rgb, mask = utils.center_crop(raw, gt, rgb, mask, config.image_size)
            elif config.image_size >= max(raw.size):
                raw, gt, rgb, mask = utils.center_pad(raw, gt, rgb, mask, config.image_size)
            else:
                raise NotImplementedError('invalid config.image_size.')

        mask = np.array(mask, dtype=np.float32) / 255.0
        rgb = np.array(rgb, dtype=np.float32) / (127.0 - 1.0) * mask
        thres_range = (up_thres - low_thres) / 2.0
        raw = np.clip(np.array(raw, dtype=np.float32), low_thres, up_thres)
        gt = np.clip(np.array(gt, dtype=np.float32), low_thres, up_thres)
        if config.dnnet == "bilateral":
            raw = cv2.bilateralFilter(raw, 9, 75, 75)
        raw = (raw - low_thres) / thres_range - 1.0
        gt = (gt - low_thres) / thres_range - 1.0

        if split_stack:
            raw_arr, H, W = utils.split_patch(raw, config.image_size)
            gt_arr, _, _ = utils.split_patch(gt, config.image_size)
            rgb_arr, _, _ = utils.split_patch(rgb, config.image_size)
            all_ims.append((name, raw_arr, gt_arr, rgb_arr))
        else:
            all_ims.append((name, raw, gt, rgb, mask))

    ckpt_history = list()
    if test_time and not split_stack:
        path = tf.train.latest_checkpoint(config.checkpoint_dir)
        print("Load checkpoint: {}".format(path))
        params = build_model(config.image_size, config.image_size, config)
        sess = tf.Session()
        load_from_checkpoint(sess, path)

        tt_time = 0.0
        history_len = 3
        t_history = np.zeros(history_len, dtype=np.float32)
        for i, (name, raw, _, rgb, mask) in enumerate(all_ims):
            t_elapsed = loop_body_patch_time(sess, name, params, raw, rgb, mask, config)
            t_history[i % history_len] = t_elapsed
            tt_time += t_elapsed
            avg_time = 1000 * tt_time / (i + 1)  # ms
            mv_avg_time = 1000 * np.mean(t_history)
            print('iter {} | tt_time: {:.4f}s; avg_time: {:.2f}; mv_avg_time: {:.2f}'.format(i+1, tt_time, avg_time, mv_avg_time))
        tf.reset_default_graph()
    else:
        while True:
            # Wait until new checkpoint exist when training phase is not finished.
            path = wait_for_new_checkpoint(config.checkpoint_dir, ckpt_history)
            print("Loading from checkpoint: {}".format(path))

            if split_stack:
                print('evaluating {} imgs'.format(len(all_ims)))
                for i, (name, raw_arr, gt_arr, rgb_arr) in enumerate(all_ims):
                    loop_body_whole(i, path, raw_arr, gt_arr, rgb_arr, H, W, config)
            else:
                for i, (name, raw, gt, rgb, mask) in enumerate(all_ims):
                    loop_body_patch(i, path, raw, gt, rgb, mask, config)

            if not config.loop: break
Example No. 20
# image sequence numbering
frame_seq = change_fps(5, frames, length)
print('Total New Frame:', len(frame_seq))
frame_n = 0
i = 0

while (cap.isOpened()):
    # get frame-by-frame
    hasVideo, frame = cap.read()
    if hasVideo and frame_n == round(frame_seq[i]):
        # resize
        #frame = cv2.resize(frame, (1382, 512), interpolation=cv2.INTER_AREA)

        # crop image
        frame = center_crop(frame, (1382, 512))

        # display video
        cv2.imshow('Frame', frame)
        # write frame to image
        cv2.imwrite(f'sequence_04/{i:06d}.png', frame)
        # presss q to break
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
        i = i + 1
    frame_n = frame_n + 1
    if frame_n == frames:
        break

# release video
cap.release()
Example No. 21
def main():
    # start = time.time()
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else:  # Single image file
        content_files = [args.content_path]
    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else:  # Single image file
        style_files = [args.style_path]

    os.makedirs(args.out_path, exist_ok=True)

    thetotal = len(content_files) * len(style_files)

    count = 1

    ### Apply each style to each content image
    for content_fullpath in content_files:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(
            content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = resize_to(content_img, args.content_size)

        for style_fullpath in style_files:

            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(
                style_prefix)  # Extract filename prefix without ext

            if args.mask:
                mask_prefix_, _ = os.path.splitext(args.mask)
                mask_prefix = os.path.basename(mask_prefix_)

            if args.keep_colors:
                style_prefix = "KPT_" + style_prefix
            if args.concat:
                style_prefix = "CON_" + style_prefix
            if args.adain:
                style_prefix = "ADA_" + style_prefix
            if args.swap5:
                style_prefix = "SWP_" + style_prefix
            if args.mask:
                style_prefix = "MSK_" + mask_prefix + '_' + style_prefix
            if args.remaster:
                style_prefix = style_prefix + '_REMASTERED'

            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if os.path.isfile(out_f):
                print("SKIP", out_f)
                count += 1
                continue

            # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
            # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])
            style_img = get_img(style_fullpath)
            if isinstance(style_img, str) and style_img == "IMAGE IS BROKEN":
                continue

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)
            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img,
                                             args.alpha, args.swap5,
                                             args.ss_alpha, args.adain)

            if args.passes > 1:
                for _ in range(args.passes - 1):
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

            if args.mask:
                import cv2
                cv2.imwrite('./tmp.png', stylized_rgb)
                stylized_rgb = cv2.imread('./tmp.png')

                mask = cv2.imread(args.mask, cv2.IMREAD_GRAYSCALE)

                # from scipy.misc import bytescale
                # mask = bytescale(mask)
                # mask = scipy.ndimage.imread(args.mask,flatten=True,mode='L')
                height, width = stylized_rgb.shape[:2]
                # print(height, width)

                # Resize the mask to fit the image.
                mask = scipy.misc.imresize(mask, (height, width),
                                           interp='bilinear')
                stylized_rgb = cv2.bitwise_and(stylized_rgb,
                                               stylized_rgb,
                                               mask=mask)

            # Stitch the style + stylized output together, but only if there's one style image
            if args.concat:
                # Resize style img to same height as frame
                # style_prefix = style_prefix + "CON_"
                # content_img_resized = scipy.misc.imresize(content_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                style_img_resized = scipy.misc.imresize(
                    style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'
            out_f = os.path.join(
                args.out_path, '{}_{}{}'.format(content_prefix, style_prefix,
                                                content_ext))

            if args.remaster:
                # outf = os.path.join(args.out_path, '{}_{}_REMASTERED{}'.format(content_prefix, style_prefix, content_ext))
                stylized_rgb = remaster_pic(stylized_rgb)
            save_img(out_f, stylized_rgb)
            totalfiles = len([
                name for name in os.listdir(args.out_path)
                if os.path.isfile(os.path.join(args.out_path, name))
            ])
            # percent = math.floor(float(totalfiles/thetotal))
            print("{}/{} TOTAL FILES".format(count, thetotal))
            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))
Example No. 22
def main():
    start = time.time()

    session_conf = tf.ConfigProto(
        allow_soft_placement=args.allow_soft_placement,
        log_device_placement=args.log_device_placement)
    session_conf.gpu_options.per_process_gpu_memory_fraction = args.gpu_fraction

    with tf.Graph().as_default():
        # with tf.Session(config=session_conf) as sess:
        # Load the WCT model
        wct_model = WCT(checkpoints=args.checkpoints,
                        relu_targets=args.relu_targets,
                        vgg_path=args.vgg_path,
                        device=args.device,
                        ss_patch_size=args.ss_patch_size,
                        ss_stride=args.ss_stride)

        with wct_model.sess as sess:
            # style_img is not needed during training, so re-save a checkpoint here at inference time
            saver = tf.train.Saver()
            log_path = args.log_path if args.log_path is not None else os.path.join(
                args.checkpoint, 'log')
            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            # Some nodes are not saved in the checkpoint and would need re-initialization
            # sess.run(tf.global_variables_initializer())

            # Get content & style full paths
            if os.path.isdir(args.content_path):
                content_files = get_files(args.content_path)
            else:  # Single image file
                content_files = [args.content_path]
            if os.path.isdir(args.style_path):
                style_files = get_files(args.style_path)
                if args.random > 0:
                    style_files = np.random.choice(style_files, args.random)
            else:  # Single image file
                style_files = [args.style_path]

            os.makedirs(args.out_path, exist_ok=True)

            count = 0

            # Apply each style to each content image
            for content_fullpath in content_files:
                content_prefix, content_ext = os.path.splitext(
                    content_fullpath)
                content_prefix = os.path.basename(
                    content_prefix)  # Extract filename prefix without ext

                content_img = get_img(content_fullpath)
                if args.content_size > 0:
                    content_img = resize_to(content_img, args.content_size)

                for style_fullpath in style_files:
                    style_prefix, _ = os.path.splitext(style_fullpath)
                    style_prefix = os.path.basename(
                        style_prefix)  # Extract filename prefix without ext

                    # style_img = get_img_crop(style_fullpath, resize=args.style_size, crop=args.crop_size)
                    # style_img = resize_to(get_img(style_fullpath), content_img.shape[0])

                    style_img = get_img(style_fullpath)

                    if args.style_size > 0:
                        style_img = resize_to(style_img, args.style_size)
                    if args.crop_size > 0:
                        style_img = center_crop(style_img, args.crop_size)

                    if args.keep_colors:
                        style_img = preserve_colors_np(style_img, content_img)

                    # if args.noise:  # Generate textures from noise instead of images
                    #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
                    #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

                    # Run the frame through the style network
                    stylized_rgb = wct_model.predict(content_img, style_img,
                                                     args.alpha, args.swap5,
                                                     args.ss_alpha, args.adain)

                    if args.passes > 1:
                        for _ in range(args.passes - 1):
                            stylized_rgb = wct_model.predict(
                                stylized_rgb, style_img, args.alpha,
                                args.swap5, args.ss_alpha, args.adain)

                    save_path = saver.save(
                        sess, os.path.join(args.checkpoint, 'model.ckpt'))
                    print("Model saved in file: %s" % save_path)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(
                            style_img,
                            (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                        # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                        stylized_rgb = np.hstack(
                            [style_img_resized, stylized_rgb])

                    # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
                    out_f = os.path.join(
                        args.out_path,
                        '{}_{}{}'.format(content_prefix, style_prefix,
                                         content_ext))
                    # out_f = f'{content_prefix}_{style_prefix}.{content_ext}'

                    # print(stylized_rgb, stylized_rgb.shape, type(stylized_rgb))
                    # print(out_f)
                    save_img(out_f, stylized_rgb)

                    count += 1
                    print("{}: Wrote stylized output image to {} at {}".format(
                        count, out_f, time.time()))
                    print("breaking...")
                    break
                break

            print("Finished stylizing {} outputs in {}s".format(
                count,
                time.time() - start))
Example No. 23
def main():
    start = time.time()

    # Load the WCT model
    wct_model = WCT(input_checkpoint=args.input_checkpoint,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    print("model construct end !")

    # Get content & style full paths
    if os.path.isdir(args.content_path):
        content_files = get_files(args.content_path)
    else: # Single image file
        content_files = [args.content_path]

    content_seg_files = list(map(lambda x: x.replace("2017", "2017_seg").replace("jpg", "png"), content_files))
    assert reduce(lambda a, b : a + b, map(lambda x: int(os.path.exists(x)),content_seg_files + content_files)) > 0

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
        if args.random > 0:
            style_files = np.random.choice(style_files, args.random)
    else: # Single image file
        style_files = [args.style_path]

    style_seg_files = list(map(lambda x: x.replace("2017", "2017_seg").replace("jpg", "png"), style_files))
    assert reduce(lambda a, b : a + b, map(lambda x: int(os.path.exists(x)),style_seg_files + style_files)) > 0

    os.makedirs(args.out_path, exist_ok=True)

    count = 0

    ### Apply each style to each content image
    for i in range(len(content_files)):
        content_fullpath = content_files[i]
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)  # Extract filename prefix without ext

        content_img = get_img(content_fullpath)
        if args.content_size > 0:
            content_img = np.asarray(Image.fromarray(content_img.astype(np.uint8)).resize((args.content_size, args.content_size))).astype(np.float32)
        content_seg = get_img_ori(content_seg_files[i])
        if args.content_size > 0:
            content_seg = np.asarray(Image.fromarray(content_seg.astype(np.uint8)).resize((args.content_size, args.content_size))).astype(np.float32)
            content_seg = image_to_pixel_map(content_seg)[...,np.newaxis]
        content_img = np.concatenate([content_img, content_seg], axis=-1)

        for j in range(len(style_files)):
            style_fullpath = style_files[j]
            style_prefix, _ = os.path.splitext(style_fullpath)
            style_prefix = os.path.basename(style_prefix)  # Extract filename prefix without ext

            style_img = get_img(style_fullpath)
            style_seg = get_img_ori(style_seg_files[j])

            if args.style_size > 0:
                style_img = resize_to(style_img, args.style_size)
                style_seg = resize_to(style_seg, args.style_size)

            if args.crop_size > 0:
                style_img = center_crop(style_img, args.crop_size)
                style_seg = center_crop(style_seg, args.crop_size)

            style_seg = image_to_pixel_map(style_seg)[...,np.newaxis]
            style_img = np.concatenate([style_img, style_seg], axis=-1)

            assert not args.keep_colors
            if args.keep_colors:
                style_img = preserve_colors_np(style_img, content_img)

            # if args.noise:  # Generate textures from noise instead of images
            #     frame_resize = np.random.randint(0, 256, frame_resize.shape, np.uint8)
            #     frame_resize = gaussian_filter(frame_resize, sigma=0.5)

            # Run the frame through the style network
            stylized_rgb = wct_model.predict(content_img, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)


            if args.passes > 1:
                for _ in range(args.passes-1):
                    stylized_rgb = np.concatenate([stylized_rgb ,content_seg], axis=-1)
                    stylized_rgb = wct_model.predict(stylized_rgb, style_img, args.alpha, args.swap5, args.ss_alpha, args.adain)

            # Stitch the style + stylized output together, but only if there's one style image
            assert not args.concat
            if args.concat:
                # Resize style img to same height as frame
                style_img_resized = scipy.misc.imresize(style_img, (stylized_rgb.shape[0], stylized_rgb.shape[0]))
                # margin = np.ones((style_img_resized.shape[0], 10, 3)) * 255
                stylized_rgb = np.hstack([style_img_resized, stylized_rgb])

            ####+++++++++++++++++++++++++++++++++++

            # Format for out filename: {out_path}/{content_prefix}_{style_prefix}.{content_ext}
            out_f = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
            save_img(out_f, stylized_rgb)

            count += 1
            print("{}: Wrote stylized output image to {}".format(count, out_f))

    print("Finished stylizing {} outputs in {}s".format(count, time.time() - start))
Example No. 24
dataset = glob.glob(
    '/volume/annahung-project/image_generation/draw-the-music/annadraw/dataset/wikiart/*/*'
)
print('dataset:', len(dataset))

names = verify(dataset)

folds = 5

random.shuffle(names)

images = {}

count = len(names)
print("Count: %d" % count)
count_per_fold = count // folds

i = 0
im = 0
for imgfile in tqdm(names):
    image = center_crop(imageio.imread(imgfile))
    images[imgfile] = image
    im += 1

    if im == count_per_fold:
        output = open(dataset_name + '_data_fold_%d.pkl' % i, 'wb')
        pickle.dump(list(images.values()), output)
        output.close()
        i += 1
        im = 0
        images.clear()
Example No. 25
def run_inference(subj, R, mode, k, num_sampels, num_bootsamles, batch_size,
                  num_iter, step_size, phase_step, complex_rec, use_momentum,
                  log, device):
    # Some inits of paths... Edit these
    vae_model_name = 'T2-20210415-111101/450.pth'
    vae_path = '/cluster/scratch/jonatank/logs/ddp/vae/'
    data_path = '/cluster/work/cvl/jonatank/fastMRI_T2/validation/'
    log_path = '/cluster/scratch/jonatank/logs/ddp/restore/pytorch/'
    rss = True

    # Load pretrained VAE
    path = vae_path + vae_model_name
    vae = torch.load(path, map_location=torch.device(device))
    vae.eval()

    # Data loader setup
    subj_dataset = Subject(subj, data_path, R, rss=rss)
    subj_loader = data.DataLoader(subj_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)

    # Time model and init resulting matrices
    start_time = time.perf_counter()
    rec_subj = np.zeros((len(subj_loader), 320, 320))
    gt_subj = np.zeros((len(subj_loader), 320, 320))

    # Set basic parameters
    print('Subj: ', subj, ' R: ', R, ' mode: ', mode, ' k: ', k,
          ' num_sampels: ', num_sampels, ' num_bootsamles: ', num_bootsamles,
          ' batch_size: ', batch_size, ' num_iter: ', num_iter, ' step_size: ',
          step_size, ' phase_step: ', phase_step)

    # Log
    log_path = log_path + 'R' + str(R) + '_mode' + str(
        k) + mode + '_reg2lmb0.01_' + datetime.now().strftime("%Y%m%d-%H%M%S")
    if log:
        import wandb
        wandb.login()
        wandb.init(project='JDDP' + '_T2',
                   name=vae_model_name,
                   config={
                       "num_iter": num_iter,
                       "step_size": step_size,
                       "phase_step": phase_step,
                       "mode": mode,
                       'R': R,
                       'K': k,
                       'use_momentum': use_momentum
                   })
        #wandb.watch(vae)
    else:
        wandb = False

    print("num_iter", num_iter, " step_size ", step_size, " phase_step ",
          phase_step, " mode ", mode, ' R ', R, ' K ', k, 'use_momentum',
          use_momentum)

    for batch in tqdm(subj_loader, desc="Running inference"):
        ksp, coilmaps, rss, norm_fact, num_sli = batch

        rec_sli = vaerecon(ksp[0],
                           coilmaps[0],
                           mode,
                           vae,
                           rss[0],
                           log_path,
                           device,
                           writer=wandb,
                           norm=norm_fact.item(),
                           nsampl=num_sampels,
                           boot_samples=num_bootsamles,
                           k=k,
                           patchsize=28,
                           parfact=batch_size,
                           num_iter=num_iter,
                           stepsize=step_size,
                           lmb=phase_step,
                           use_momentum=use_momentum)

        rec_subj[num_sli] = np.abs(center_crop(rec_sli.detach().cpu().numpy()))
        gt_subj[num_sli] = np.abs(center_crop(rss[0]))

        rmse_sli = nmse(rec_subj[num_sli], gt_subj[num_sli])
        ssim_sli = ssim(rec_subj[num_sli], gt_subj[num_sli])
        psnr_sli = psnr(rec_subj[num_sli], gt_subj[num_sli])
        print('Slice: ', num_sli.item(), ' RMSE: ', str(rmse_sli), ' SSIM: ',
              str(ssim_sli), ' PSNR: ', str(psnr_sli))
        end_time = time.perf_counter()

        print(f"Elapsed time for {str(num_sli)} slices: {end_time-start_time}")

    rmse_v = nmse(rec_subj, gt_subj)
    ssim_v = ssim(rec_subj, gt_subj)
    psnr_v = psnr(rec_subj, gt_subj)
    print('Subject Done: ', 'RMSE: ', str(rmse_v), ' SSIM: ', str(ssim_v),
          ' PSNR: ', str(psnr_v))

    pickle.dump(
        rec_subj,
        open(log_path + subj + str(k) + mode + str(restore_sense) + str(R),
             'wb'))

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(subj_loader)} slices: {end_time-start_time}")
Exemplo n.º 26
0
import cv2
import numpy as np
import utils
import glob

img_fn = glob.glob(
    '/home/didi/Repository/computervision/image-handling/traffic02.jpg')

enlarge = 1

crop_dim = (1242, 375)

for i, img_path in enumerate(img_fn):
    img = cv2.imread(img_path)
    img_resize = utils.resize(enlarge, img)
    img_crop = utils.center_crop(img_resize, crop_dim)

    cv2.imwrite(
        f'/home/didi/Repository/computervision/image-handling/{i:06d}.png',
        img_crop)

print(f"Kitti Size: (375, 1242)")
print(f"Original: {img.shape}")
print(f"Resize: {img_resize.shape}")
print(f"Crop: {img_crop.shape}")

cv2.imshow('resize', img_resize)
cv2.imshow('after_crop', img_crop)
cv2.imshow('before crop', img)

cv2.waitKey(0)
Exemplo n.º 27
0
    def refresh_images(self):
        curr_refims = []
        curr_refimids = []

        # add images until nrefims satisfied
        while len(curr_refims) < self._nrefims:
            n_to_show = self._nrefims - len(curr_refims)

            # dynamic reference image dir takes precedence
            if self._dynrefimdir and os.path.isdir(self._dynrefimdir):
                dynref_rps_added = []
                dynref_rps_unseen = set(os.listdir(
                    self._dynrefimdir)) - self._dynref_rps_seen
                invalid_rps = set(rp for rp in dynref_rps_unseen
                                  if os.path.splitext(rp)[1].lower() not in
                                  self.image_extensions)
                self._dynref_rps_seen |= invalid_rps
                dynref_rps_unseen -= invalid_rps
                n_dynref_to_show = n_to_show if self._max_n_dynref is None \
                    else min(n_to_show, self._max_n_dynref - len(curr_refims))
                if dynref_rps_unseen and n_dynref_to_show > 0:
                    for imfn, i in zip(dynref_rps_unseen,
                                       range(n_dynref_to_show)):
                        self._dynref_rps_seen.add(imfn)
                        imid = os.path.splitext(imfn)[0]
                        try:  # if already loaded and cached
                            curr_refims.append(
                                self._loaded_refims['dyn'][imid])
                            curr_refimids.append(imid)
                            dynref_rps_added.append(imfn)
                        except KeyError:
                            im = utils.read_image(
                                os.path.join(self._dynrefimdir, imfn))
                            if im is not None:
                                if self._crop_center:
                                    im = utils.center_crop(im)
                                self._check_add_to_buffer(im,
                                                          imid,
                                                          is_dyn=True)
                                curr_refims.append(im)
                                curr_refimids.append(imid)
                                dynref_rps_added.append(imfn)
                    if len(dynref_rps_added) > 0:
                        print('added dynref images:', dynref_rps_added)
                    continue  # check dyn ref again; only check stat ref if no dyn imids processed

            # static reference image dir
            for i in range(n_to_show):
                idx_toshow = self._all_view[self._i_toshow %
                                            self._n_all_refims]
                if self._all_refims_is_valid[idx_toshow]:
                    rp = self._all_refimrpaths[idx_toshow]
                    imid = self._all_refimids[idx_toshow]
                    try:
                        curr_refims.append(self._loaded_refims['stat'][imid])
                        curr_refimids.append(imid)
                    except KeyError:  # not found in loaded images
                        im = utils.read_image(os.path.join(self._refimdir, rp))
                        if im is None:  # not a valid image
                            self._all_refims_is_valid[idx_toshow] = False
                            print(
                                f'{self.__class__.__name__}{self._idx_str}: invalid reference image: {rp}'
                            )
                        else:
                            if self._crop_center:
                                im = utils.center_crop(im)
                            self._check_add_to_buffer(im, imid)
                            curr_refims.append(im)
                            curr_refimids.append(imid)

                # check change of epoch
                self._i_toshow += 1
                epoch = self._i_toshow / self._n_all_refims
                dynref_epoch = epoch * self._dynref_ref_ratio
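                # dynref_epoch advances _dynref_ref_ratio times faster than the
                # static epoch; each time it crosses an integer the 'seen' set is
                # cleared below so dynamic reference images become eligible again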
                if epoch >= self._epoch + 1:
                    self._epoch = int(epoch)
                    if self._shuffle:
                        self._random_generator.shuffle(self._all_view)
                if dynref_epoch >= self._dynref_epoch + 1:
                    self._dynref_epoch = int(dynref_epoch)
                    self._dynref_rps_seen = set()

        self._curr_refims = curr_refims
        self._curr_refimids = curr_refimids
        if self._verbose:
            print(
                f'{self.__class__.__name__}{self._idx_str}: showing the following {self._nrefims} reference images'
            )
            print(curr_refimids)
Exemplo n.º 28
0
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(args.in_path):
        in_path = get_files(args.in_path)
    else:  # Single image file
        in_path = [args.in_path]

    if os.path.isdir(args.style_path):
        style_files = get_files(args.style_path)
    else:  # Single image file
        style_files = [args.style_path]

    print(style_files)
    # time.sleep(999)

    in_args = [
        'ffmpeg',
        '-i', args.in_path,
        '%s/frame_%%d.png' % in_dir
    ]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    s = time.time()
    for content_fullpath in in_path:
        content_prefix, content_ext = os.path.splitext(content_fullpath)
        content_prefix = os.path.basename(content_prefix)

        try:

            for style_fullpath in style_files:
                style_img = get_img(style_fullpath)
                if args.style_size > 0:
                    style_img = resize_to(style_img, args.style_size)
                if args.crop_size > 0:
                    style_img = center_crop(style_img, args.crop_size)

                style_prefix, _ = os.path.splitext(style_fullpath)
                style_prefix = os.path.basename(style_prefix)

                # print("ARRAY:  ", style_img)
                out_v = os.path.join(args.out_path, '{}_{}{}'.format(content_prefix, style_prefix, content_ext))
                print("OUT:", out_v)
                if os.path.isfile(out_v):
                    print("SKIP", out_v)
                    continue

                for in_f, out_f in zip(in_files, out_files):
                    print('{} -> {}'.format(in_f, out_f))
                    content_img = get_img(in_f)

                    if args.keep_colors:
                        style_rgb = preserve_colors_np(style_img, content_img)
                    else:
                        style_rgb = style_img

                    stylized = wct_model.predict(content_img, style_rgb, args.alpha, args.swap5, args.ss_alpha)

                    if args.passes > 1:
                        for _ in range(args.passes-1):
                            stylized = wct_model.predict(stylized, style_rgb, args.alpha)

                    # Stitch the style + stylized output together, but only if there's one style image
                    if args.concat:
                        # Resize style img to same height as frame
                        style_img_resized = scipy.misc.imresize(style_rgb, (stylized.shape[0], stylized.shape[0]))
                        stylized = np.hstack([style_img_resized, stylized])

                    save_img(out_f, stylized)

                fr = 30
                out_args = [
                    'ffmpeg',
                    '-i', '%s/frame_%%d.png' % out_dir,
                    '-f', 'mp4',
                    '-q:v', '0',
                    '-vcodec', 'mpeg4',
                    '-r', str(fr),
                    '"' + out_v + '"'
                ]
                print(out_args)

                subprocess.call(" ".join(out_args), shell=True)
                print('Video at: %s' % out_v)

                if args.keep_tmp is True or len(style_files) > 1:
                    continue
                else:
                    shutil.rmtree(args.tmp_dir)
                print('Processed in:', (time.time() - s))

            print('Processed in:', (time.time() - s))

        except Exception as e:
            print("EXCEPTION: ", e)
Exemplo n.º 29
0
    def process_single_data(self, text):
        """
            Process a single line of text in the data file
            INPUT:
                text: string    a line of text in the data file
                img_size: int   the output image size
                using_rotation bool if to use random rotation to make data enhancement
                kernel_size int the size of kernel in gauss kernel
                label_size: int the output label size
            OUTPUT:
                Normalized Image, Normalized Label Image, mask, Box Size, COM
                Normalized uvd groundtruth, Heatmap, Normalized Dmap
        """

        if self.process_mode == 'uvd':
            # decode the text and load the image with only hand and the uvd coordinate of joints
            _image, _joint_uvd, _com, cube_size = self.load_from_text(text)
        else: # process_mode == 'bb'
            assert self.test_only
            _image = self.load_from_text_bb(text)
            _com = None
            cube_size = None

        if _com is None:
            mean = np.mean(_image[_image > 0])
            _com = center_of_mass(_image > 0)
            _com = np.array([_com[1], _com[0], mean])

        if cube_size is None:
            cube_size = self.cube_size

        try:
            if not self.augmentation:
                raise Exception('load data without augmentation')

            image = _image.copy()
            joint_uvd = _joint_uvd.copy()
            com = _com.copy()

            if self.using_rotation:
                angle = random.random() * 60 - 30 # random rotation [-30, 30]
            else:
                angle = 0

            if self.using_scale:
                scale = 0.8 + random.random() * 0.4 # random scale [0.8, 1.2]
            else:
                scale = 1.0

            if self.using_shift:
                # random shift [-5, 5]
                shift_x = -5 + random.random() * 10 
                shift_y = -5 + random.random() * 10 
                com = self.uvd2xyz(com)
                com[0] += shift_x
                com[1] += shift_y
                com = self.xyz2uvd(com)

            # crop the image
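            # (pinhole model: an object of metric extent cube_size at depth com[2]
            #  spans roughly focal_length * cube_size / depth pixels, so du + dv
            #  gives the side length of the crop box around the hand)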
            du = cube_size / com[2] * self.fx
            dv = cube_size / com[2] * self.fy
            box_size = int(du + dv)

            box_size = max(box_size, 2)

            crop_img = center_crop(image, (com[1], com[0]), box_size)
            crop_img = crop_img * np.logical_and(crop_img > com[2] - cube_size, crop_img < com[2] + cube_size)

            # norm the image and uvd to COM
            crop_img[crop_img > 0] -= com[2] # center the depth image to COM
            
            com[0] = int(com[0])
            com[1] = int(com[1])

            box_size = crop_img.shape[0] # update box_size

            joint_uvd_centered = joint_uvd - com # center the uvd to COM

            if self.using_flip:
                if random.random() < 0.5: # probability of a horizontal flip
                    crop_img = crop_img[:, ::-1].copy() # mirror the depth crop left-right
                    joint_uvd_centered[:, 0] = - joint_uvd_centered[:, 0] # mirror u about the COM

            # resize the image and uvd
            try:
                img_resize = cv2.resize(crop_img, (self.image_size, self.image_size))
            except:
                # probably because size is zero
                print("resize error")
                raise ValueError("Resize error")

            joint_uvd_centered_resize = joint_uvd_centered.copy()
            joint_uvd_centered_resize[:,:2] = joint_uvd_centered_resize[:,:2] / (box_size - 1) * (self.image_size - 1)

            # random rotate the image and the label
            img_resize, joint_uvd_centered_resize = random_rotated(img_resize, joint_uvd_centered_resize, angle, scale)
            # change hand size
            img_resize = img_resize * scale 
            joint_uvd_centered_resize[:, 2] *= scale

            # Generate Heatmap
            joint_uvd_kernel = joint_uvd_centered_resize.copy()
            joint_uvd_kernel[:,:2] = joint_uvd_kernel[:,:2] / (self.image_size - 1) * (self.label_size - 1) + \
                np.array([self.label_size // 2, self.label_size // 2])

            # try generate heatmaps with augmented data, which may fail
            heatmaps = [generate_kernel(generate_heatmap(self.label_size, joint_uvd_kernel[i, 0], joint_uvd_kernel[i, 1]), \
                kernel_size=self.kernel_size, sigmoid=self.sigmoid)[:, :, np.newaxis] for i in range(self.joint_number)]

            # Generate label_image and mask
            label_image = cv2.resize(img_resize, (self.label_size, self.label_size))
            is_hand = label_image != 0
            mask = is_hand.astype(float)

        except: # not performing any data augmentation
            image = _image.copy()
            com = _com.copy()

            # crop the image
            du = cube_size / com[2] * self.fx
            dv = cube_size / com[2] * self.fy
            box_size = int(du + dv)
            box_size = max(box_size, 2)

            crop_img = center_crop(image, (com[1], com[0]), box_size)
            crop_img = crop_img * np.logical_and(crop_img > com[2] - cube_size, crop_img < com[2] + cube_size)

            # norm the image and uvd to COM
            crop_img[crop_img > 0] -= com[2] # center the depth image to COM
            
            com[0] = int(com[0])
            com[1] = int(com[1])
            box_size = crop_img.shape[0] # update box_size

            # resize the image and uvd
            try:
                img_resize = cv2.resize(crop_img, (self.image_size, self.image_size))
            except:
                # probably because size is zero
                print("resize error")
                raise ValueError("Resize error")

            # Generate label_image and mask
            label_image = cv2.resize(img_resize, (self.label_size, self.label_size))
            is_hand = label_image != 0
            mask = is_hand.astype(float)            

            if self.test_only:
                # Just return the basic elements we need to run the network
                # normalize the image first before return
                normalized_img = img_resize / cube_size
                normalized_label_img = label_image / cube_size

                # Convert to torch format
                normalized_img = torch.from_numpy(normalized_img).float().unsqueeze(0)
                normalized_label_img = torch.from_numpy(normalized_label_img).float().unsqueeze(0)
                mask = torch.from_numpy(mask).float().unsqueeze(0)
                box_size = torch.tensor(box_size).float()
                cube_size = torch.tensor(cube_size).float()
                com = torch.from_numpy(com).float()
                
                return normalized_img, normalized_label_img, mask, box_size, cube_size, com

            joint_uvd = _joint_uvd.copy()
            joint_uvd_centered = joint_uvd - com # center the uvd to COM
            joint_uvd_centered_resize = joint_uvd_centered.copy()
            joint_uvd_centered_resize[:,:2] = joint_uvd_centered_resize[:,:2] / (box_size - 1) * (self.image_size - 1)

            # Generate Heatmap
            joint_uvd_kernel = joint_uvd_centered_resize.copy()
            joint_uvd_kernel[:,:2] = joint_uvd_kernel[:,:2] / (self.image_size - 1) * (self.label_size - 1) + \
                np.array([self.label_size // 2, self.label_size // 2])
            try:
                heatmaps = [generate_kernel(generate_heatmap(self.label_size, joint_uvd_kernel[i, 0], joint_uvd_kernel[i, 1]), \
                    kernel_size=self.kernel_size, sigmoid=self.sigmoid)[:, :, np.newaxis] for i in range(self.joint_number)]
            except:
                path, _ = self.decode_line_txt(text)
                print("{} heatmap error".format(path))
                raise ValueError("{} heatmap error".format(path))

        heatmaps = np.concatenate(heatmaps, axis=2)

        # Generate Dmap
        Dmap = []
        for i in range(self.joint_number):
            heatmask = heatmaps[:, :, i] > 0
            heatmask = heatmask.astype(float) * mask
            Dmap.append(((joint_uvd_centered_resize[i, 2] - label_image.copy()) * heatmask)[:, :, np.newaxis])
        Dmap = np.concatenate(Dmap, axis=2)   

        # Normalize data
        normalized_img = img_resize / cube_size
        normalized_label_img = label_image / cube_size
        normalized_Dmap = Dmap / cube_size
        normalized_uvd = joint_uvd_centered_resize.copy()
        normalized_uvd[:, :2] = normalized_uvd[:, :2] / (self.image_size - 1)
        normalized_uvd[:, 2] = normalized_uvd[:, 2] / cube_size
        
        if np.any(np.isnan(normalized_img)) or np.any(np.isnan(normalized_uvd)) or \
            np.any(np.isnan(heatmaps)) or np.any(np.isnan(normalized_label_img)) or \
                np.any(np.isnan(normalized_Dmap)) or np.any(np.isnan(mask)) or np.sum(mask) < 10:
            path, data = self.decode_line_txt(text)
            print("Wired things happen, image contain Nan {}, {}".format(path, np.sum(mask)))
            raise ValueError("Wired things happen, image contain Nan {}, {}".format(path, np.sum(mask)))

        # Convert to torch format
        normalized_img = torch.from_numpy(normalized_img).float().unsqueeze(0)
        normalized_label_img = torch.from_numpy(normalized_label_img).float().unsqueeze(0)
        mask = torch.from_numpy(mask).float().unsqueeze(0)
        box_size = torch.tensor(box_size).float()
        cube_size = torch.tensor(cube_size).float()
        com = torch.from_numpy(com).float()
        normalized_uvd = torch.from_numpy(normalized_uvd).float()
        heatmaps = torch.from_numpy(heatmaps).float().permute(2, 0, 1).contiguous()
        normalized_Dmap = torch.from_numpy(normalized_Dmap).float().permute(2, 0, 1).contiguous()

        return normalized_img, normalized_label_img, mask, box_size, cube_size, com, normalized_uvd, heatmaps, normalized_Dmap        
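
Note that center_crop() is used here with a centre pixel and a box size, i.e. a different signature from the fixed-size crops in the other examples. Its implementation is not shown; a minimal sketch, assuming the box is simply clipped at the image borders (consistent with box_size being re-read from crop_img.shape afterwards) and using a hypothetical name, could be:

import numpy as np

def center_crop_at(image, center, box_size):
    # crop a box_size x box_size window centred on pixel (row, col),
    # clipped at the image borders, so the result can be smaller near edges
    row, col = int(center[0]), int(center[1])
    half = box_size // 2
    top, bottom = max(row - half, 0), min(row + half, image.shape[0])
    left, right = max(col - half, 0), min(col + half, image.shape[1])
    return image[top:bottom, left:right].copy()
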
Exemplo n.º 30
0
def main(_):
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    if not os.path.exists(os.path.join('./logs', time.strftime('%d%m'))):
        os.makedirs(os.path.join('./logs', time.strftime('%d%m')))

    gpu_config = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    #with tf.Session() as sess:
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_config)) as sess:
        if FLAGS.is_train:
            fcn = FCN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                      dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop,
                      checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            fcn = EVAL(sess, batch_size=1, rgb_image_shape=[224, 224, 3],
                       dataset_name=FLAGS.dataset, is_crop=False,
                       checkpoint_dir=FLAGS.checkpoint_dir)
        if FLAGS.is_train:
            fcn.train(FLAGS)
        else:
            VGG_mean = [103.939, 116.779, 123.68]
            train_rgb_data = open("db/Oxford_data1_RGB_train.txt")
            train_rgblist = train_rgb_data.readlines()
            train_depth_data = open("db/Oxford_data1_depth_train.txt")
            train_depthlist = train_depth_data.readlines()
            shuf = list(range(len(train_rgblist)))
            random.shuffle(shuf)
            shuf = shuf[:10]
            save_files = glob.glob(
                os.path.join(FLAGS.checkpoint_dir, FLAGS.dataset,
                             'FCN.model*'))
            save_files = natsorted(save_files)
            savepath = './Depth_seg'
            if not os.path.exists(os.path.join(savepath)):
                os.makedirs(os.path.join(savepath))

            model = save_files[-2]
            model = model.split('/')
            model = model[-1]
            fcn.load(FLAGS.checkpoint_dir, model)

            for m in range(len(shuf)):
                rgbpath = train_rgblist[shuf[m]]
                rgb_img = scipy.misc.imread(rgbpath[:-1]).astype(np.float32)
                rgb_img = center_crop(rgb_img, 224)
                rgb_img = np.reshape(rgb_img, (1, 224, 224, 3))
                depthpath = train_depthlist[shuf[m]]
                depth_img = sio.loadmat(depthpath[:-1])
                depth_img = depth_img['depth']
                depth_img = np.reshape(
                    depth_img, [depth_img.shape[0], depth_img.shape[1], 1])
                depth_img = center_crop(depth_img, 224)
                start_time = time.time()
                predict = sess.run(fcn.pred_seg,
                                   feed_dict={
                                       fcn.rgb_images: rgb_img,
                                       fcn.keep_prob: 1.0
                                   })
                predict = np.squeeze(predict).astype(np.float32)

                print('time: %.8f' % (time.time() - start_time))
                if not os.path.exists(os.path.join(savepath, '%s' % (model))):
                    os.makedirs(os.path.join(savepath, '%s' % (model)))
                savename = os.path.join(
                    savepath, '%s/predict_%03d.jpg' % (model, shuf[m]))
                scipy.misc.imsave(savename, predict.astype(np.uint8))
                savename = os.path.join(savepath,
                                        '%s/gt_%03d.jpg' % (model, shuf[m]))
                scipy.misc.imsave(savename,
                                  np.squeeze(depth_img).astype(np.uint8))
Exemplo n.º 31
0
def main():
    # Load the WCT model
    wct_model = WCT(checkpoints=args.checkpoints,
                    relu_targets=args.relu_targets,
                    vgg_path=args.vgg_path,
                    device=args.device,
                    ss_patch_size=args.ss_patch_size,
                    ss_stride=args.ss_stride)

    # Create needed dirs
    in_dir = os.path.join(args.tmp_dir, 'input')
    out_dir = os.path.join(args.tmp_dir, 'stylized')
    if not os.path.exists(in_dir):
        os.makedirs(in_dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    in_args = ['ffmpeg', '-i', args.in_path, '%s/frame_%%d.png' % in_dir]

    subprocess.call(" ".join(in_args), shell=True)
    base_names = os.listdir(in_dir)
    in_files = [os.path.join(in_dir, x) for x in base_names]
    out_files = [os.path.join(out_dir, x) for x in base_names]

    style_img = get_img(args.style_path)

    if args.style_size > 0:
        style_img = resize_to(style_img, args.style_size)
    if args.crop_size > 0:
        style_img = center_crop(style_img, args.crop_size)

    s = time.time()
    for in_f, out_f in zip(in_files, out_files):
        print('{} -> {}'.format(in_f, out_f))
        content_img = get_img(in_f)

        if args.keep_colors:
            style_rgb = preserve_colors_np(style_img, content_img)
        else:
            style_rgb = style_img

        stylized = wct_model.predict(content_img, style_rgb, args.alpha,
                                     args.swap5, args.ss_alpha)

        if args.passes > 1:
            for _ in range(args.passes - 1):
                stylized = wct_model.predict(stylized, style_rgb, args.alpha)

        # Stitch the style + stylized output together, but only if there's one style image
        if args.concat:
            # Resize style img to same height as frame
            style_img_resized = scipy.misc.imresize(
                style_rgb, (stylized.shape[0], stylized.shape[0]))
            stylized = np.hstack([style_img_resized, stylized])

        save_img(out_f, stylized)

    fr = 30
    out_args = [
        'ffmpeg', '-i',
        '%s/frame_%%d.png' % out_dir, '-f', 'mp4', '-q:v', '0', '-vcodec',
        'mpeg4', '-r',
        str(fr), args.out_path
    ]

    subprocess.call(" ".join(out_args), shell=True)
    print('Video at: %s' % args.out_path)

    if args.keep_tmp is False:
        shutil.rmtree(args.tmp_dir)

    print('Processed in:', (time.time() - s))
Exemplo n.º 32
0
    query_fn = cfg['qimlist'][i]
    print(str(i)+" QUERY "+query_fn); log.write(str(i)+" QUERY "+query_fn+"\n"); log.flush()
    
    # load target (query) image
    if bbxs is not None:
        target_img = img_loader(cfg['qim_fname'](cfg,i), im_size[dataset], cfg['gnd'][i]['bbx']).type(torch.cuda.FloatTensor)
    else:
        target_img = img_loader(cfg['qim_fname'](cfg,i), im_size[dataset]).type(torch.cuda.FloatTensor)

    # attack
    t = time.time()
    trials = 0
    converged = False
    while not converged and trials < max_trials:
        carrier_img = img_loader("data/input/"+carrier_fn+".jpg", im_size[dataset]).type(torch.cuda.FloatTensor)
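        # presumably crops the carrier image to the target's spatial size so the
        # two tensors line up (exact behaviour of this center_crop variant is assumed)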
        carrier_img = center_crop(target_img, carrier_img)
        alr = lr / divide_rate_lr**trials # reducing lr after every failure
        aiters = int(iters * multiply_rate_iters**trials) # increase iterations after every failure
        attack_img, loss_perf, loss_distort, converged = tma(train_networks, scale_factors, target_img, carrier_img, mode = mode, num_steps = aiters, lr = alr, lam = lam, sigma_blur = sigma_blur, verbose = True) 
        trials += 1

    # time and log
    total_time += time.time()-t
    log.write("performance loss  {:6f} distortion loss {:6f} total loss {:6f}\n".format(loss_perf.item(), (loss_distort).item(), (loss_distort+loss_perf).item())); log.flush()
    if trials == max_trials: print("Failed...")

    # save the attack in png file    
    savefn = output_folder+exp_name+"/"+query_fn+'.png'
    if not os.path.exists(savefn[0:savefn.rfind('/')]): os.makedirs(savefn[0:savefn.rfind('/')])
    imsave(savefn,np.transpose(attack_img.cpu().numpy(), (2,3,1,0)).squeeze())    
    print("Attack saved in "+savefn+"\n"); sys.stdout.flush()