def _input(self, inputs):
    """Prepare the tensors consumed by the rest of the model.

    Everything is built inside the ``'input'`` variable scope.

    Args:
        inputs: Mapping that must contain ``self.KEYS.TENSOR.INPUT`` and may
            contain ``self.KEYS.TENSOR.LABEL`` and ``SRKeys.REPRESENTS``.

    Returns:
        A ``(upsampled, represents, label)`` tuple:

        * ``upsampled`` -- the input passed through ``UpSampling2D`` with the
          configured ``UPSAMPLE_RATIO``.
        * ``represents`` -- ``None`` when the ``INTERP`` config is truthy;
          otherwise either the upsampled and ``align_crop``-ed
          ``SRKeys.REPRESENTS`` entry, or a 5x5 ``'stem'`` convolution of
          ``upsampled`` when no representation was supplied.
        * ``label`` -- ``inputs[self.KEYS.TENSOR.LABEL]`` or ``None``.
    """
    label_key = self.KEYS.TENSOR.LABEL
    ratio_key = self.KEYS.CONFIG.UPSAMPLE_RATIO
    with tf.variable_scope('input'):
        upsampled = UpSampling2D(
            'upsampling_u',
            inputs=inputs[self.KEYS.TENSOR.INPUT],
            size=self.config(ratio_key),
        )()
        label = inputs[label_key] if label_key in inputs else None

        # Pure interpolation mode: no representation tensor is needed.
        if self.config(self.KEYS.CONFIG.INTERP):
            return upsampled, None, label

        if SRKeys.REPRESENTS not in inputs:
            # No representation supplied: derive one from the upsampled
            # input with a stem convolution.
            represents = tf.layers.conv2d(
                inputs=upsampled,
                filters=self.config(self.KEYS.CONFIG.FILTERS),
                kernel_size=5,
                padding='same',
                name='stem',
                reuse=tf.AUTO_REUSE,
            )
            return upsampled, represents, label

        # Representation supplied: upsample it by the same ratio and crop it
        # against the upsampled input.
        represents = UpSampling2D(
            'upsampling_r',
            inputs=inputs[SRKeys.REPRESENTS],
            size=self.config(ratio_key),
        )()
        represents = align_crop(represents, upsampled)
        return upsampled, represents, label
def kernel(self, inputs):
    """Build the super-resolution graph and return its output tensors.

    The input is upsampled by a fixed ``(2, 2)`` factor, turned into a
    representation, passed through the short-cut sub-graph, and reduced to a
    single-filter residual that is added to the cropped upsampled input.

    Args:
        inputs: Mapping that must contain ``self.KEYS.TENSOR.INPUT`` and may
            contain ``SRKeys.REPRESENTS`` and ``self.KEYS.TENSOR.LABEL``.

    Returns:
        dict with ``self.KEYS.TENSOR.INFERENCE``, ``SRKeys.REPRESENTS``,
        ``SRKeys.RESIDUAL`` and ``SRKeys.INTERP``; when a label is present it
        also carries ``self.KEYS.TENSOR.LOSS`` and ``SRKeys.ALIGNED_LABEL``.
    """
    label_key = self.KEYS.TENSOR.LABEL

    with tf.variable_scope('input'):
        upsampled = UpSampling2D(
            inputs=inputs[self.KEYS.TENSOR.INPUT], size=(2, 2))()
        if SRKeys.REPRESENTS in inputs:
            represents = UpSampling2D(
                inputs=inputs[SRKeys.REPRESENTS], size=(2, 2))()
            represents = align_crop(represents, upsampled)
            # Stack the representation with the upsampled input on axis 3.
            represents = tf.concat([represents, upsampled], axis=3)
        else:
            represents = tf.layers.conv2d(
                inputs=upsampled,
                filters=self.config(self.KEYS.CONFIG.FILTERS),
                kernel_size=5,
                name='stem0',
            )

    # NOTE(review): the short-cut graph and the 'inference' scope are placed
    # as siblings of the 'input' scope -- confirm against the intended
    # variable naming.
    graph_key = self.KEYS.GRAPHS.SHORT_CUT
    short_cut = self.get_or_create_graph(graph_key, self._short_cut(graph_key))
    features = short_cut(represents)

    with tf.variable_scope('inference'):
        residual = tf.layers.conv2d(
            inputs=features,
            filters=1,
            kernel_size=3,
            padding='same',
            name='stem1',
        )
        residual = boundary_crop(
            residual, self.config(self.KEYS.CONFIG.BOUNDARY_CROP))
        # Crop the upsampled input so it can be summed with the residual.
        interp = align_crop(upsampled, residual)
        inference = residual + interp

    result = {
        self.KEYS.TENSOR.INFERENCE: inference,
        SRKeys.REPRESENTS: features,
        SRKeys.RESIDUAL: residual,
        SRKeys.INTERP: interp,
    }

    if label_key in inputs:
        with tf.name_scope('loss'):
            aligned_label = align_crop(inputs[label_key], inference)
            loss = mean_square_error(aligned_label, inference)
        result.update({
            self.KEYS.TENSOR.LOSS: loss,
            SRKeys.ALIGNED_LABEL: aligned_label,
        })
    return result
def _loss(self, label, infer):
    """Build the training loss for an inference tensor.

    Args:
        label: Ground-truth tensor, or ``None`` when no label is available.
        infer: Inference tensor; ``label`` is ``align_crop``-ed against it.

    Returns:
        dict: empty when ``label`` is ``None``. Otherwise it contains
        ``self.KEYS.TENSOR.LOSS`` and ``SRKeys.ALIGNED_LABEL``, plus either
        the remaining outputs of ``CombinedSupervisedLoss`` (when the
        ``USE_COMBINED_LOSS`` config is truthy) or ``SRKeys.MSE_LOSS``.
    """
    if label is None:
        return {}

    with tf.name_scope('loss'):
        align_label = align_crop(label, infer)

        if self.config(self.KEYS.CONFIG.USE_COMBINED_LOSS):
            # NOTE(review): only the de-normalisation is placed inside
            # 'use_combine_loss'; confirm whether the loss graph belongs
            # inside that name scope as well.
            with tf.name_scope('use_combine_loss'):
                stdv = tf.constant(
                    self.config(self.KEYS.CONFIG.DENORM_STD), tf.float32)
                # The dtype belongs to tf.constant, not to self.config
                # (the original had the closing parenthesis misplaced).
                meanv = tf.constant(
                    self.config(self.KEYS.CONFIG.DENORM_MEAN), tf.float32)
                # Undo the normalisation on both tensors before the loss.
                labeld = align_label * stdv + meanv
                inferd = infer * stdv + meanv
            result = CombinedSupervisedLoss(
                self.name / 'loss',
                inputs={
                    self.KEYS.TENSOR.INPUT: inferd,
                    self.KEYS.TENSOR.LABEL: labeld,
                })()
            result.update({SRKeys.ALIGNED_LABEL: align_label})
            # Expose the combined loss output under the LOSS key.
            result.update({
                self.KEYS.TENSOR.LOSS: result[self.KEYS.TENSOR.OUTPUT]
            })
            result.pop(self.KEYS.TENSOR.OUTPUT)
            return result

        loss_mse = mean_square_error(align_label, infer) * self.config(
            self.KEYS.CONFIG.MES_LOSS_WEIGHT)
        loss = loss_mse
        if self.config(self.KEYS.CONFIG.WITH_POI_LOSS):
            loss_poi = poission_loss(align_label, infer) * self.config(
                self.KEYS.CONFIG.POI_LOSS_WEIGHT)
            loss = loss + loss_poi
        return {
            self.KEYS.TENSOR.LOSS: loss,
            SRKeys.MSE_LOSS: loss_mse,
            SRKeys.ALIGNED_LABEL: align_label,
        }
def _inference(self, represents, upsampled):
    """Combine the representation with the upsampled input.

    Everything is built inside the ``'inference'`` variable scope.

    Args:
        represents: Tensor fed to the single-filter 3x3 residual convolution;
            not used when the ``INTERP`` config is truthy.
        upsampled: Upsampled input; it is passed through ``boundary_crop``
            with the configured ``BOUNDARY_CROP`` offset first.

    Returns:
        dict: only ``self.KEYS.TENSOR.INFERENCE`` (the cropped ``upsampled``)
        when the ``INTERP`` config is truthy; otherwise
        ``self.KEYS.TENSOR.INFERENCE`` (residual + cropped ``upsampled``),
        ``SRKeys.RESIDUAL`` and ``SRKeys.INTERP``.
    """
    with tf.variable_scope('inference'):
        crop_offset = self.config(self.KEYS.CONFIG.BOUNDARY_CROP)
        upsampled = boundary_crop(input_=upsampled, offset=crop_offset)

        # Interpolation-only mode: the cropped upsampling is the answer.
        if self.config(self.KEYS.CONFIG.INTERP):
            return {self.KEYS.TENSOR.INFERENCE: upsampled}

        # Predict the residual, then crop it against the upsampled input so
        # the two can be summed.
        residual = align_crop(
            input_=tf.layers.conv2d(
                inputs=represents,
                filters=1,
                kernel_size=3,
                padding='same',
                reuse=tf.AUTO_REUSE,
            ),
            target=upsampled,
        )
        outputs = {
            self.KEYS.TENSOR.INFERENCE: residual + upsampled,
            SRKeys.RESIDUAL: residual,
            SRKeys.INTERP: upsampled,
        }
        return outputs