    def _input(self, inputs):
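        """Prepare the input tensors for the super-resolution kernel.

        Upsamples the input image by the configured ratio, looks up the
        optional label tensor, and returns early in pure interpolation
        mode; otherwise builds the initial representation, either by
        upsampling an upstream representation or by applying a stem
        convolution to the upsampled image.
        """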
        with tf.variable_scope('input'):
            u = UpSampling2D('upsampling_u',
                             inputs=inputs[self.KEYS.TENSOR.INPUT],
                             size=self.config(
                                 self.KEYS.CONFIG.UPSAMPLE_RATIO))()

            if self.KEYS.TENSOR.LABEL in inputs:
                l = inputs[self.KEYS.TENSOR.LABEL]
            else:
                l = None

            if self.config(self.KEYS.CONFIG.INTERP):
                return u, None, l

            if SRKeys.REPRESENTS in inputs:
                r = UpSampling2D('upsampling_r',
                                 inputs=inputs[SRKeys.REPRESENTS],
                                 size=self.config(
                                     self.KEYS.CONFIG.UPSAMPLE_RATIO))()
                r = align_crop(r, u)
            else:
                r = tf.layers.conv2d(inputs=u,
                                     filters=self.config(
                                         self.KEYS.CONFIG.FILTERS),
                                     kernel_size=5,
                                     padding='same',
                                     name='stem',
                                     reuse=tf.AUTO_REUSE)

            return u, r, l

    def kernel(self, inputs):
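        """Build the inference graph and, when a label is given, the loss.

        Upsamples the input (and any upstream representation) by a factor
        of 2, runs the shared short-cut sub-graph, predicts a residual and
        adds it to the aligned upsampled image; if a label is supplied, a
        mean squared error loss is added to the result dictionary.
        """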
        with tf.variable_scope('input'):
            u = UpSampling2D(inputs=inputs[self.KEYS.TENSOR.INPUT],
                             size=(2, 2))()
            if SRKeys.REPRESENTS in inputs:
                r = UpSampling2D(inputs=inputs[SRKeys.REPRESENTS],
                                 size=(2, 2))()
                r = align_crop(r, u)
                r = tf.concat([r, u], axis=3)
            else:
                r = tf.layers.conv2d(inputs=u,
                                     filters=self.config(
                                         self.KEYS.CONFIG.FILTERS),
                                     kernel_size=5,
                                     name='stem0')

        key = self.KEYS.GRAPHS.SHORT_CUT
        x = self.get_or_create_graph(key, self._short_cut(key))(r)
        with tf.variable_scope('inference'):
            res = tf.layers.conv2d(
                inputs=x,
                filters=1,
                kernel_size=3,
                padding='same',
                name='stem1',
            )
            res = boundary_crop(res,
                                self.config(self.KEYS.CONFIG.BOUNDARY_CROP))
            u_c = align_crop(u, res)
            y = res + u_c

        result = {
            self.KEYS.TENSOR.INFERENCE: y,
            SRKeys.REPRESENTS: x,
            SRKeys.RESIDUAL: res,
            SRKeys.INTERP: u_c
        }
        if self.KEYS.TENSOR.LABEL in inputs:
            with tf.name_scope('loss'):
                aligned_label = align_crop(inputs[self.KEYS.TENSOR.LABEL], y)
                l = mean_square_error(aligned_label, y)
            result.update({
                self.KEYS.TENSOR.LOSS: l,
                SRKeys.ALIGNED_LABEL: aligned_label
            })

        return result

    def _loss(self, label, infer):
        """Build the training loss for `infer` against `label`.

        Returns an empty dict when no label is given. Otherwise uses either
        a combined supervised loss on de-normalized tensors or a weighted
        MSE loss, optionally with an additional Poisson loss term.
        """
        if label is not None:
            with tf.name_scope('loss'):
                align_label = align_crop(label, infer)
                if self.config(self.KEYS.CONFIG.USE_COMBINED_LOSS):
                    with tf.name_scope('use_combine_loss'):
                        # De-normalize labels and predictions before
                        # computing the combined loss.
                        stdv = tf.constant(
                            self.config(self.KEYS.CONFIG.DENORM_STD),
                            tf.float32)
                        meanv = tf.constant(
                            self.config(self.KEYS.CONFIG.DENORM_MEAN),
                            tf.float32)
                        labeld = align_label * stdv + meanv
                        inferd = infer * stdv + meanv
                    result = CombinedSupervisedLoss(self.name / 'loss',
                                                    inputs={
                                                        self.KEYS.TENSOR.INPUT:
                                                        inferd,
                                                        self.KEYS.TENSOR.LABEL:
                                                        labeld
                                                    })()
                    result.update({SRKeys.ALIGNED_LABEL: align_label})
                    # Re-key the combined loss output under the LOSS key.
                    result.update({
                        self.KEYS.TENSOR.LOSS:
                        result[self.KEYS.TENSOR.OUTPUT]
                    })
                    result.pop(self.KEYS.TENSOR.OUTPUT)
                    return result
                else:
                    loss_mse = mean_square_error(
                        align_label, infer) * self.config(
                            self.KEYS.CONFIG.MES_LOSS_WEIGHT)
                    loss = loss_mse
                    if self.config(self.KEYS.CONFIG.WITH_POI_LOSS):
                        loss_poi = poission_loss(
                            align_label, infer) * self.config(
                                self.KEYS.CONFIG.POI_LOSS_WEIGHT)
                        loss = loss + loss_poi
                    result = {
                        self.KEYS.TENSOR.LOSS: loss,
                        SRKeys.MSE_LOSS: loss_mse,
                        SRKeys.ALIGNED_LABEL: align_label
                    }
                    return result
        else:
            return {}

    def _inference(self, represents, upsampled):
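        """Combine the learned representation with the upsampled image.

        Crops the upsampled tensor by the configured boundary offset and
        returns it directly in interpolation mode; otherwise predicts a
        single-channel residual from `represents` and adds it to the
        aligned upsampled image.
        """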
        with tf.variable_scope('inference'):
            upsampled = boundary_crop(input_=upsampled,
                                      offset=self.config(
                                          self.KEYS.CONFIG.BOUNDARY_CROP))
            if self.config(self.KEYS.CONFIG.INTERP):
                return {self.KEYS.TENSOR.INFERENCE: upsampled}
            residual = tf.layers.conv2d(inputs=represents,
                                        filters=1,
                                        kernel_size=3,
                                        padding='same',
                                        reuse=tf.AUTO_REUSE)
            residual = align_crop(input_=residual, target=upsampled)
            inference = residual + upsampled

        return {
            self.KEYS.TENSOR.INFERENCE: inference,
            SRKeys.RESIDUAL: residual,
            SRKeys.INTERP: upsampled
        }
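
    # A minimal sketch (assumption: the surrounding graph base class invokes a
    # `kernel`-style method with an `inputs` dict) of how the helper methods
    # above compose; `_short_cut_graph` is a hypothetical name standing in for
    # the shared short-cut sub-graph used in `kernel`:
    #
    #     def kernel(self, inputs):
    #         upsampled, represents, label = self._input(inputs)
    #         represents = self._short_cut_graph(represents)  # hypothetical
    #         result = self._inference(represents, upsampled)
    #         result.update(
    #             self._loss(label, result[self.KEYS.TENSOR.INFERENCE]))
    #         return result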