-
Notifications
You must be signed in to change notification settings - Fork 16
/
sr_dataset.py
138 lines (121 loc) · 5.65 KB
/
sr_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
__author__ = 'Sherwin'
import numpy as np
from sr_util import sr_image_util
from sklearn.neighbors import NearestNeighbors
import multiprocessing
from multiprocessing import Process, Manager
# Default ``pyramid_level`` handed to ``sr_image.get_pyramid``.
DEFAULT_PYRAMID_LEVEL = 3
# Default ``downgrade_ratio`` handed to ``sr_image.get_pyramid``: the cube root of 2.
DEFAULT_DOWNGRADE_RATIO = 2 ** (1.0/3)
# Default number of nearest neighbours looked up (and merged) per queried patch.
DEFAULT_NEIGHBORS = 9


class SRDataSet(object):
    """Low resolution -> high resolution patch mapping with nearest-neighbour lookup.

    Row ``i`` of the low resolution patches corresponds to row ``i`` of the high
    resolution patches. A kd-tree index over the low resolution patches is rebuilt
    lazily (on the next query) whenever new patches have been added.
    """

    def __init__(self, low_res_patches, high_res_patches, neighbors=DEFAULT_NEIGHBORS):
        """Store the patch pairs and build the nearest-neighbour index right away.

        @param low_res_patches: low resolution patches, one patch per row
        @type low_res_patches: L{numpy.array}
        @param high_res_patches: high resolution patches, row-aligned with low_res_patches
        @type high_res_patches: L{numpy.array}
        @param neighbors: number of nearest neighbours used per query
        @type neighbors: int
        """
        self._low_res_patches = low_res_patches
        self._high_res_patches = high_res_patches
        self._nearest_neighbor = None
        self._neighbors = neighbors
        self._need_update = True
        self._update()

    @classmethod
    def from_sr_image(cls, sr_image, pyramid_level=DEFAULT_PYRAMID_LEVEL, downgrade_ratio=DEFAULT_DOWNGRADE_RATIO):
        """Create a SRDataset object from a SRImage object.

        The patches of sr_image itself are used as the high resolution side for
        every image returned by its pyramid.

        @param sr_image: image the dataset is built from
        @type sr_image: L{sr_image.SRImage}
        @param pyramid_level: forwarded to sr_image.get_pyramid
        @type pyramid_level: int
        @param downgrade_ratio: forwarded to sr_image.get_pyramid
        @type downgrade_ratio: float
        @return: SRDataset object, or None when the pyramid yields no image
        @rtype: L{sr_dataset.SRDataset}
        """
        high_res_patches = sr_image_util.get_patches_without_dc(sr_image)
        sr_dataset = None
        for downgraded_sr_image in sr_image.get_pyramid(pyramid_level, downgrade_ratio):
            low_res_patches = sr_image_util.get_patches_without_dc(downgraded_sr_image)
            if sr_dataset is None:
                # Use cls so that subclasses get an instance of their own type.
                sr_dataset = cls(low_res_patches, high_res_patches)
            else:
                sr_dataset.add(low_res_patches, high_res_patches)
        return sr_dataset

    @property
    def low_res_patches(self):
        """Low resolution patches currently held by the dataset."""
        return self._low_res_patches

    @property
    def high_res_patches(self):
        """High resolution patches currently held by the dataset."""
        return self._high_res_patches

    def _update(self):
        """Refit the kd-tree index on the current low resolution patches."""
        self._nearest_neighbor = NearestNeighbors(n_neighbors=self._neighbors,
                                                  algorithm='kd_tree').fit(self._low_res_patches)
        self._need_update = False

    def add(self, low_res_patches, high_res_patches):
        """Add low_res_patches -> high_res_patches mapping to the dataset.

        The index is not rebuilt here; it is only flagged stale and refreshed by
        the next query.

        @param low_res_patches: low resolution patches
        @type low_res_patches: L{numpy.array}
        @param high_res_patches: high resolution patches
        @type high_res_patches: L{numpy.array}
        """
        self._low_res_patches = np.concatenate((self._low_res_patches, low_res_patches))
        self._high_res_patches = np.concatenate((self._high_res_patches, high_res_patches))
        self._need_update = True

    def merge(self, sr_dataset):
        """Merge with the given dataset.

        @param sr_dataset: an instance of SRDataset
        @type sr_dataset: L{sr_dataset.SRDataset}
        """
        self.add(sr_dataset.low_res_patches, sr_dataset.high_res_patches)

    @staticmethod
    def _split_batches(low_res_patches, worker_count):
        """Split the patches row-wise into at most worker_count non-empty batches.

        @param low_res_patches: two dimensional array, one patch per row
        @type low_res_patches: L{numpy.array}
        @param worker_count: maximum number of batches to produce
        @type worker_count: int
        @return: consecutive row slices, in input order; empty list for zero rows
        @rtype: list of L{numpy.array}
        """
        patch_number, _ = np.shape(low_res_patches)
        worker_count = max(int(worker_count), 1)
        # Integer ceiling division: true division would give a float, which
        # cannot be used as a slice index on Python 3.
        batch_size = max(-(-patch_number // worker_count), 1)
        return [low_res_patches[start:start + batch_size, :]
                for start in range(0, patch_number, batch_size)]

    def parallel_query(self, low_res_patches):
        """Query the high resolution patches for the given low resolution patches using
        multiprocessing.

        The patches are split into one batch per CPU (fewer when there are not
        enough patches), each batch is answered by L{query} in its own process,
        and the per-batch answers are stitched back together in input order.

        @param low_res_patches: given low resolution patches
        @type low_res_patches: L{numpy.array}
        @return: high resolution patches in row vector form
        @rtype: L{numpy.array}
        @raise RuntimeError: if a worker process ended without reporting a result
        @raise ValueError: if low_res_patches has no rows (nothing to concatenate)
        """
        # Refresh the index before starting workers so that each of them does
        # not have to rebuild it on its own.
        if self._need_update:
            self._update()
        batches = self._split_batches(low_res_patches, multiprocessing.cpu_count())
        jobs = []
        result = Manager().dict()
        for batch_id, batch in enumerate(batches):
            job = Process(target=self.query, args=(batch, batch_id, result))
            jobs.append(job)
            job.start()
        for job in jobs:
            job.join()
        # Take one local snapshot of the shared dict instead of one proxy
        # round-trip per lookup.
        collected = result.copy()
        missing = [batch_id for batch_id in range(len(batches)) if batch_id not in collected]
        if missing:
            raise RuntimeError("parallel_query: no result from worker(s) %s" % missing)
        # Workers finish in arbitrary order; index by batch id so that output
        # row i still belongs to input row i.
        return np.concatenate([collected[batch_id] for batch_id in range(len(batches))])

    def query(self, low_res_patches, id=1, result=None):
        """Query the high resolution patches for the given low resolution patches.

        With more than one neighbour the neighbouring high resolution patches are
        blended by L{_merge_high_res_patches}; with a single neighbour they are
        returned as indexed, i.e. still carrying the neighbour axis.

        @param low_res_patches: low resolution patches
        @type low_res_patches: L{numpy.array}
        @param id: id for subprocess, used for multiprocessing (the name shadows
            the builtin but is kept for backward compatibility)
        @type id: int
        @param result: shared dict between processes, used for multiprocessing;
            when given, the answer is also stored under result[id]
        @type: L{multiprocessing.Manager.dict}
        @return: high resolution patches for the given low resolution patches
        @rtype: L{numpy.array}
        """
        if self._need_update:
            self._update()
        distances, indices = self._nearest_neighbor.kneighbors(low_res_patches,
                                                               n_neighbors=self._neighbors)
        neighbor_patches = self.high_res_patches[indices]
        if self._neighbors > 1:
            high_res_patches = self._merge_high_res_patches(neighbor_patches, distances)
        else:
            high_res_patches = neighbor_patches
        if result is not None:
            result[id] = high_res_patches
        return high_res_patches

    def _merge_high_res_patches(self, neighbor_patches, distances):
        """Get the high resolution patches by merging the neighboring patches with the given distance as weight.

        Raw weights are exp(-0.25 * distance), so closer neighbours weigh more;
        they are then passed through sr_image_util.normalize (presumably making
        the weights of each patch sum to one -- confirm against sr_image_util).

        @param neighbor_patches: neighboring high resolution patches, shaped
            (patch, neighbor, dimension)
        @type neighbor_patches: L{numpy.array}
        @param distances: distance vector associate with the neighboring patches
        @type distances: L{numpy.array}
        @return: high resolution patches by merging the neighboring patches
        @rtype: L{numpy.array}
        """
        patch_number, neighbor_number, _ = np.shape(neighbor_patches)
        weights = sr_image_util.normalize(np.exp(-0.25*distances))
        # Trailing axis of length one lets the weights broadcast over the patch dimension.
        weights = weights.reshape(patch_number, neighbor_number, 1)
        return np.sum(neighbor_patches*weights, axis=1)