Пример #1
0
 def smoothness_loss(true_y, pred_y):
     """Penalise spatial variation of a three-channel volume.

     Each of the first three channels of `pred_y` is passed through
     `volumeGradients`; the absolute gradients of the three channels are
     added, divided by a constant and summed per batch entry.

     # Arguments
         true_y: unused; present so the function has the `(y_true, y_pred)`
             loss signature.
         pred_y: tensor indexed here as 5-D with at least three channels on
             the last axis.
     # Returns
         Tensor holding one value per entry along the first axis (the sum
         runs over axes 1-4).
     """
     # Shape of a single channel kept as a 5-D tensor with one trailing channel.
     channel_shape = [tf.shape(pred_y)[0], *K.int_shape(pred_y)[1:4], 1]

     abs_grads = [
         tf.abs(
             volumeGradients(
                 tf.reshape(pred_y[:, :, :, :, c], channel_shape)))
         for c in range(3)
     ]

     # Normaliser: product of the static sizes of axes 1-4 times `batch_size`,
     # which is taken from the enclosing scope.
     volume = functools.reduce(lambda acc, extent: acc * extent,
                               K.int_shape(pred_y)[1:5])
     norm = volume * batch_size

     total = abs_grads[0] + abs_grads[1] + abs_grads[2]
     return tf.reduce_sum(total / norm, axis=[1, 2, 3, 4])
Пример #2
0
def volumeGradients(tf_vf):
    """Forward-difference gradients of a 5-D tensor along axes 1, 2 and 3.

    For every axis the difference `I(p + 1) - I(p)` is stored at position
    `p`; a slab of zeros is appended at the end of the axis so each gradient
    keeps the shape of the input.

    # Arguments
        tf_vf: 5-D tensor. The first axis may be dynamic; the remaining sizes
            are read from the static shape. The final reshape to a last axis
            of size 3 only succeeds when the input's last axis has size 1.
    # Returns
        Tensor of shape `[dim0, dim1, dim2, dim3, 3]` whose last axis holds
        the gradients along axes 1, 2 and 3, in that order.
    """
    dims = (tf.shape(tf_vf)[0], *K.int_shape(tf_vf)[1:])
    dynamic_shape = tf.shape(tf_vf)
    everything = slice(None)

    padded_diffs = []
    for axis in (1, 2, 3):
        # Index tuples selecting [1:] and [:-1] along `axis` only.
        ahead = [everything] * 5
        behind = [everything] * 5
        ahead[axis] = slice(1, None)
        behind[axis] = slice(None, -1)
        diff = tf_vf[tuple(ahead)] - tf_vf[tuple(behind)]

        # Zero slab of thickness one that restores the original extent.
        pad_dims = [dims[0], dims[1], dims[2], dims[3], dims[4]]
        pad_dims[axis] = 1
        pad = array_ops.zeros(tf.stack(pad_dims), tf_vf.dtype)

        diff = array_ops.concat([diff, pad], axis)
        padded_diffs.append(array_ops.reshape(diff, dynamic_shape))

    return tf.reshape(array_ops.stack(padded_diffs, 4),
                      [dims[0], dims[1], dims[2], dims[3], 3])
Пример #3
0
def invertDisplacements(args):
    """Build a field from an identity grid and its `remap3d`-warped copy.

    An identity coordinate grid is created for the spatial extent of `args`,
    each coordinate channel is resampled with `remap3d` using `args` as the
    displacement, and the resampled channel is subtracted from the identity
    channel.

    The shapes of the intermediate tensors are printed to stdout while the
    graph is being built.

    # Arguments
        args: displacement tensor; axes 1-3 must have static sizes.
    # Returns
        Tensor with the three per-axis differences stacked on axis 4.
    """
    disp = args
    size_x, size_y, size_z = K.int_shape(args)[1:4]

    # 'ij' indexing keeps the (x, y, z) axis order instead of (y, x, z).
    axes = [
        tf.linspace(0., extent - 1., extent)
        for extent in (size_x, size_y, size_z)
    ]
    identity = tf.expand_dims(
        tf.stack(tf.meshgrid(*axes, indexing='ij'), -1), 0)

    # One copy of the identity grid per batch entry.
    batch_identity = tf.tile(identity, (tf.shape(args)[0], 1, 1, 1, 1))
    print(batch_identity.shape)

    components = []
    for channel in range(3):
        components.append(
            tf.expand_dims(batch_identity[:, :, :, :, channel], 4))
    print(components[0].shape)

    warped = [remap3d(component, disp) for component in components]
    print(warped[0].shape)

    differences = [
        component - moved for component, moved in zip(components, warped)
    ]
    print(differences[0].shape)

    result = tf.stack([tf.squeeze(diff, 4) for diff in differences], 4)
    print(result.shape)
    return result
Пример #4
0
def avg_batch_mse_loss_exaustive(y_true, y_pred):
    """Average `mean_squared_error` over every unordered pair in the batch.

    # Arguments
        y_true: unused by the code shown here.
        y_pred: tensor whose first axis has a static size; entries are
            compared pairwise.

    NOTE(review): no `return` statement is visible for this function, so as
    shown the averaged `loss` is computed and then dropped - confirm against
    the full source.
    NOTE(review): `N` stays 0 when the batch holds fewer than two entries,
    which makes the final division raise `ZeroDivisionError`; a batch size of
    `None` makes `range` raise `TypeError`.
    """
    # Static batch size; must be a concrete int for the Python loops below.
    batch_size = K.int_shape(y_pred)[0]
    loss = 0
    N = 0  # number of pairs accumulated so far
    for i in range(0, batch_size):
        # j starts after i so each unordered pair is visited exactly once.
        for j in range(i + 1, batch_size):
            loss += mean_squared_error(y_pred[i], y_pred[j])
            N += 1
    loss /= N
    def smoothness_loss(true_y, pred_y):
        """Penalise spatial variation of a three-channel volume.

        The absolute `volumeGradients` of the first three channels of
        `pred_y` are added, divided by a constant and summed over axes 1-4.

        # Arguments
            true_y: unused; kept for the `(y_true, y_pred)` loss signature.
            pred_y: tensor indexed here as 5-D with at least three channels
                on the last axis.
        # Returns
            Tensor with one value per entry along the first axis.
        """
        abs_grads = [
            tf.abs(
                volumeGradients(tf.expand_dims(pred_y[:, :, :, :, c], -1)))
            for c in range(3)
        ]
        total = abs_grads[0] + abs_grads[1] + abs_grads[2]

        # Product of the static sizes of axes 1-4, scaled by `batch_size`
        # from the enclosing scope.
        volume = functools.reduce(lambda acc, extent: acc * extent,
                                  K.int_shape(pred_y)[1:5])
        norm = volume * batch_size

        return tf.reduce_sum(total / norm, axis=[1, 2, 3, 4])
Пример #6
0
def sampling(args):
    """Sample `z_mean + exp(z_log_sigma) * epsilon` with Gaussian noise.

    # Arguments
        args: sequence whose first two items are the mean tensor and the
            log-sigma tensor; the mean must be indexable as 2-D with a
            static second dimension.
    # Returns
        Tensor with the sampled values.
    """
    z_mean, z_log_sigma = args[0], args[1]

    # Dynamic first dimension, static second dimension.
    noise_shape = (K.shape(z_mean)[0], K.int_shape(z_mean)[1])
    noise = K.random_normal(shape=noise_shape, dtype=tf.float32)
    noise = tf.reshape(noise, noise_shape)

    return z_mean + K.exp(z_log_sigma) * noise
Пример #7
0
def sampling(args):
    """Reparameterization trick: sample from an isotropic unit Gaussian.

    # Arguments
        args (tensor): pair `(z_mean, z_log_var)`, the mean and log of
            variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector,
            `z_mean + exp(0.5 * z_log_var) * epsilon`
    """
    z_mean, z_log_var = args

    # Dynamic first dimension, static second dimension.
    noise_shape = (K.shape(z_mean)[0], K.int_shape(z_mean)[1])
    # random_normal defaults to mean=0 and std=1.0
    unit_noise = K.random_normal(shape=noise_shape)

    # exp(0.5 * log_var) turns the log-variance into a standard deviation.
    std = K.exp(0.5 * z_log_var)
    return z_mean + std * unit_noise
Пример #8
0
 def get_output_shape_for(self, input_shape):
     """Return the output shape for the given input shape.

     Resolution depends on `self._output_shape`:

     * tuple or list: the first entry of `input_shape` (or `None` when
       `input_shape` is empty) is prepended to it;
     * any other non-`None` value: it is called with `input_shape` and must
       return a list or tuple;
     * `None`: on the TensorFlow backend the shape is inferred by running
       `self.call` on placeholders, otherwise `input_shape` is returned
       unchanged.

     # Arguments
         input_shape: shape tuple, or list of shape tuples.
     # Returns
         Shape tuple, or a list of shape tuples when `self.call` returns a
         list.
     # Raises
         Exception: if the `_output_shape` callable returns something that
             is neither a list nor a tuple.
     """
     declared = self._output_shape

     if declared is not None:
         if type(declared) in {tuple, list}:
             nb_samples = input_shape[0] if input_shape else None
             return (nb_samples, ) + tuple(declared)

         shape = declared(input_shape)
         if type(shape) not in {list, tuple}:
             raise Exception('output_shape function must return a tuple')
         return tuple(shape)

     # Without a declared shape, only TensorFlow can infer it directly;
     # every other backend defaults to the input shape.
     if K._BACKEND != 'tensorflow':
         return input_shape

     if type(input_shape) is list:
         probe = [K.placeholder(shape=shape) for shape in input_shape]
     else:
         probe = K.placeholder(shape=input_shape)
     result = self.call(probe)

     if type(result) is list:
         return [K.int_shape(x_elem) for x_elem in result]
     return K.int_shape(result)
Пример #9
0
def _time_distributed_dense(x,
                            w,
                            b=None,
                            dropout=None,
                            input_dim=None,
                            output_dim=None,
                            timesteps=None,
                            training=None):
    """Apply `y . w + b` to every temporal slice `y` of `x`.

    # Arguments
        x: input tensor, indexed here as `(samples, timesteps, input_dim)`.
        w: weight matrix.
        b: optional bias vector.
        dropout: optional dropout rate; applied only when strictly between
            0 and 1, with the same mask for every temporal slice.
        input_dim: integer; optional dimensionality of the input.
        output_dim: integer; optional dimensionality of the output.
        timesteps: integer; optional number of timesteps.
        training: training phase tensor or boolean.
    # Returns
        Output tensor of shape `(samples, timesteps, output_dim)`.
    """
    # Any size that was not supplied is read from the tensors themselves.
    input_dim = input_dim or K.shape(x)[2]
    timesteps = timesteps or K.shape(x)[1]
    output_dim = output_dim or K.int_shape(w)[1]

    wants_dropout = dropout is not None and 0. < dropout < 1.
    if wants_dropout:
        # Draw one mask from the first timestep and repeat it along time so
        # every timestep sees the same dropout pattern.
        first_slice = K.reshape(x[:, 0, :], (-1, input_dim))
        mask = K.dropout(K.ones_like(first_slice), dropout)
        mask_over_time = K.repeat(mask, timesteps)
        x = K.in_train_phase(x * mask_over_time, x, training=training)

    # Merge the batch and time axes so a single matrix product covers all
    # timesteps.
    projected = K.dot(K.reshape(x, (-1, input_dim)), w)
    if b is not None:
        projected = K.bias_add(projected, b)

    # Restore the 3D layout.
    if K.backend() != 'tensorflow':
        return K.reshape(projected, (-1, timesteps, output_dim))

    out = K.reshape(projected, K.stack([-1, timesteps, output_dim]))
    out.set_shape([None, None, output_dim])
    return out
Пример #10
0
    def exponentialMap(args):
        """Pass `args` and a matching identity grid to `tfVectorFieldExp`.

        # Arguments
            args: tensor whose axes 1-3 have static sizes.
        # Returns
            Result of `tfVectorFieldExp`, called with `n_steps=steps`
            (`steps` comes from the enclosing scope).
        """
        grads = args
        size_x, size_y, size_z = K.int_shape(args)[1:4]

        # 'ij' indexing keeps the (x, y, z) axis order instead of (y, x, z).
        axes = [
            tf.linspace(0., extent - 1., extent)
            for extent in (size_x, size_y, size_z)
        ]
        identity = tf.expand_dims(
            tf.stack(tf.meshgrid(*axes, indexing='ij'), -1), 0)

        # One copy of the identity grid per batch entry.
        batch_identity = tf.tile(identity, (tf.shape(grads)[0], 1, 1, 1, 1))

        return tfVectorFieldExp(grads, batch_identity, n_steps=steps)
Пример #11
0
def tfVectorFieldExp(grad, grid, n_steps):
    """Iteratively compose a scaled field with itself via `remap3d`.

    The three channels of `grad` are divided by `2 ** n_steps` and added to
    the matching channels of `grid`. Then, `n_steps - 1` times, the current
    offset from the grid is used to resample each channel with `remap3d`
    and is added back on top.

    # Arguments
        grad: 5-D tensor with static sizes on axes 1-4; channels 0-2 of the
            last axis are used.
        grid: tensor with the same leading four axes as `grad`; channels
            0-2 of the last axis are used as the reference grid.
        n_steps: integer controlling both the initial scaling and the
            number of composition iterations.
    # Returns
        Tensor of shape `[batch, size_x, size_y, size_z, 3]` holding the
        final offset from `grid`.
    """
    batch_size = tf.shape(grad)[0]
    size_x, size_y, size_z, _channels = K.int_shape(grad)[1:5]
    column_shape = [batch_size, size_x, size_y, size_z, 1]
    field_shape = [batch_size, size_x, size_y, size_z, 3]
    scale = 2.0**n_steps

    # Reference grid, one single-channel tensor per component.
    identity = [
        tf.reshape(grid[:, :, :, :, c], column_shape) for c in range(3)
    ]
    # Grid plus the scaled-down input field.
    warped = [
        ident + tf.reshape(grad[:, :, :, :, c] / scale, column_shape)
        for c, ident in enumerate(identity)
    ]

    for _ in range(n_steps - 1):
        # Offsets are taken before any component is updated, so all three
        # components are resampled with the same field.
        offsets = tf.reshape(
            tf.stack([w - ident for w, ident in zip(warped, identity)], 4),
            field_shape)
        warped = [
            remap3d(w, offsets) + tf.expand_dims(offsets[:, :, :, :, c], -1)
            for c, w in enumerate(warped)
        ]

    displacement = [w - ident for w, ident in zip(warped, identity)]
    return tf.reshape(tf.stack(displacement, 4), field_shape)
Пример #12
0
def _bottleneck(inputs: "Layer",
                filters: int,
                kernel: "Union[int, Tuple[int, int]]",
                expansion_factor: int,
                strides: "Union[int, Tuple[int, int]]",
                residuals: bool = False) -> "Layer":
    """Bottleneck
    This function defines a basic bottleneck structure: a 1x1 expansion
    (`_conv_block`), a depthwise convolution with batch normalization and
    `relu6`, and a 1x1 projection with batch normalization.
    # Arguments
        inputs: Tensor, input tensor of conv layer.
        filters: Integer, the dimensionality of the output space.
        kernel: An integer or tuple/list of 2 integers, specifying the
            width and height of the 2D convolution window.
        expansion_factor: Integer, multiplied with the size of axis 1 of
            `inputs` to get the number of filters of the expansion block.
        strides: An integer or tuple/list of 2 integers, specifying the
            strides of the depthwise convolution along the width and
            height. A single integer is used for both spatial dimensions.
        residuals: Boolean, whether to add `inputs` to the result.
    # Returns
        Output tensor.
    """

    # NOTE(review): axis 1 is the channel axis only for channels_first data;
    # confirm the data format this model is built with.
    tchannel = K.int_shape(inputs)[1] * expansion_factor

    # An int is expanded to a pair; a pair is passed through as-is instead
    # of being wrapped into an invalid nested tuple.
    if isinstance(strides, (tuple, list)):
        depthwise_strides = tuple(strides)
    else:
        depthwise_strides = (strides, strides)

    layer = _conv_block(inputs, tchannel, (1, 1), (1, 1))

    layer = DepthwiseConv2D(kernel,
                            strides=depthwise_strides,
                            depth_multiplier=1,
                            padding='same')(layer)
    layer = BatchNormalization()(layer)
    layer = relu6(layer)

    # Projection back to `filters` channels; no activation follows it.
    layer = Conv2D(filters, (1, 1), strides=(1, 1), padding='same')(layer)
    layer = BatchNormalization()(layer)

    if residuals:
        layer = add([layer, inputs])
    return layer
    def exponentialMap(args):
        """Clip `args` and pass it with an identity grid to `tfVectorFieldExp`.

        Values are clipped to `+/- 0.5 * 2 ** steps` first; `steps` comes
        from the enclosing scope and is also forwarded as `n_steps`.

        # Arguments
            args: tensor whose axes 1-3 have static sizes.
        # Returns
            Result of `tfVectorFieldExp` for the clipped tensor.
        """
        size_x, size_y, size_z = K.int_shape(args)[1:4]

        # clip too large values:
        limit = 0.5 * (2**steps)
        velo = tf.clip_by_value(args, -limit, limit)

        # 'ij' indexing keeps the (x, y, z) axis order instead of (y, x, z).
        axes = [
            tf.linspace(0., extent - 1., extent)
            for extent in (size_x, size_y, size_z)
        ]
        identity = tf.expand_dims(
            tf.stack(tf.meshgrid(*axes, indexing='ij'), -1), 0)

        # One copy of the identity grid per batch entry.
        batch_identity = tf.tile(identity, (tf.shape(velo)[0], 1, 1, 1, 1))

        return tfVectorFieldExp(velo, batch_identity, n_steps=steps)
Пример #14
0
    def build(self,
              block_config: tuple,
              bottleneck: bool,
              name: str = 'ResNet'):
        """
        build a ResNet model

        The network is a convolutional stem, four residual stages, a
        batch-norm/ReLU pair, an average pooling over the whole remaining
        feature map and a softmax classifier. `bottleneck` is stored on
        `self.bottleneck` before any layer is created.

        :param block_config: number blocks in stage 2 ~ 5
        :param bottleneck: whether to use a bottleneck residual block
        :param name: model name
        :return: model
        """
        self.bottleneck = bottleneck

        img_input = Input(shape=self.input_shape)

        # Stage 1
        stem = self.Conv2D(64, kernel_size=7, strides=2,
                           padding='same')(img_input)
        stem = self.BatchNormalization()(stem)
        stem = Activation('relu')(stem)
        features = MaxPooling2D(pool_size=3, strides=2)(stem)

        # Stage 2 ~ 5: stage number `n` uses `block_config[n - 2]`.
        for stage in range(2, 6):
            features = self.apply_residual_stage(features, stage,
                                                 block_config[stage - 2])

        features = self.BatchNormalization()(features)
        features = Activation('relu')(features)

        # Pool window equal to the sizes of axes 1 and 2 of the feature map.
        dims = K.int_shape(features)
        pooled = AveragePooling2D(pool_size=(dims[1], dims[2]),
                                  strides=1)(features)

        logits_in = Flatten()(pooled)
        outputs = self.Dense(self.classes, 'softmax')(logits_in)

        return Model(inputs=img_input, outputs=outputs, name=name)
Пример #15
0
    def get_constants(self, inputs, training=None):
        """Build the dropout masks used by the recurrent step.

        # Arguments
            inputs: input tensor, indexed here as 3-D.
            training: training phase tensor or boolean, forwarded to
                `K.in_train_phase`.
        # Returns
            List `[dp_mask, rec_dp_mask, inputs]`. Each mask entry is a list
            of four items: dropout masks when the corresponding rate is
            strictly between 0 and 1 (and, for the input mask, when
            `self.implementation != 0`), otherwise the scalar 1.
        """

        def make_masks(width, rate):
            # Column of ones with one row per entry of axis 0, widened to
            # `width` columns.
            base = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            base = K.tile(base, (1, width))

            def dropped_inputs():
                return K.dropout(base, rate)

            return [
                K.in_train_phase(dropped_inputs, base, training=training)
                for _ in range(4)
            ]

        def no_masks():
            return [K.cast_to_floatx(1.) for _ in range(4)]

        if self.implementation != 0 and 0 < self.dropout < 1:
            input_dim = K.int_shape(inputs)[-1]
            input_masks = make_masks(int(input_dim), self.dropout)
        else:
            input_masks = no_masks()

        if 0 < self.recurrent_dropout < 1:
            recurrent_masks = make_masks(self.units, self.recurrent_dropout)
        else:
            recurrent_masks = no_masks()

        # The input itself is appended as well for use later.
        return [input_masks, recurrent_masks, inputs]
Пример #16
0
    def bottleneck_encoder(self,
                           tensor,
                           nfilters,
                           downsampling=False,
                           dilated=False,
                           asymmetric=False,
                           normal=False,
                           drate=0.1,
                           name=''):
        """
        Encoder bottleneck: a main branch (projection, optional middle
        convolution, expansion, spatial dropout) added to a skip branch.

        The middle convolution is chosen by the first truthy flag among
        `normal`, `asymmetric` and `dilated`; when none is set the main
        branch has no middle convolution.

        :param tensor: input tensor
        :param nfilters: Number of output filters; the inner convolutions
            use `nfilters // 4`
        :param downsampling: Downsample the feature map (2x2 stride-2
            projection on the main branch, max pooling plus channel zero
            padding on the skip branch)
        :param dilated: if truthy, use a dilated 3x3 convolution; the value
            itself is passed as the dilation rate
        :param asymmetric: if truthy, use a 5x1 followed by a 1x5 convolution
        :param normal: if truthy, use a regular 3x3 convolution
        :param drate: rate passed to the final `SpatialDropout2D`
        :param name: suffix for every layer name.
        :return: encoder output
        """
        inner_filters = nfilters // 4
        # Filters operating on downsampled images have a bigger receptive
        # field and hence gather more context.
        step = 2 if downsampling else 1

        shortcut = tensor
        if downsampling:
            pooled = MaxPooling2D(pool_size=(2, 2),
                                  name=f'max_pool_{name}')(tensor)
            # (B, H, W, C) -> (B, H, C, W) so the channel axis can be padded
            # by a spatial zero-padding layer.
            swapped = Permute((1, 3, 2), name=f'permute_1_{name}')(pooled)
            missing_channels = nfilters - K.int_shape(tensor)[-1]
            padded = ZeroPadding2D(padding=((0, 0), (0, missing_channels)),
                                   name=f'zeropadding_{name}')(swapped)
            # (B, H, C, W) -> (B, H, W, C)
            shortcut = Permute((1, 3, 2), name=f'permute_2_{name}')(padded)

        main = Conv2D(filters=inner_filters,
                      kernel_size=(step, step),
                      kernel_initializer='he_normal',
                      strides=(step, step),
                      padding='same',
                      use_bias=False,
                      name=f'1x1_conv_{name}')(tensor)
        main = BatchNormalization(momentum=0.1, name=f'bn_1x1_{name}')(main)
        main = PReLU(shared_axes=[1, 2], name=f'prelu_1x1_{name}')(main)

        # Options shared by every variant of the middle convolution.
        middle_opts = dict(filters=inner_filters,
                           kernel_initializer='he_normal',
                           padding='same')
        if normal:
            # regular 3x3 convolution
            main = Conv2D(kernel_size=(3, 3),
                          name=f'3x3_conv_{name}',
                          **middle_opts)(main)
        elif asymmetric:
            # 5x5 convolution decomposed into a 5x1 and a 1x5 layer
            main = Conv2D(kernel_size=(5, 1),
                          use_bias=False,
                          name=f'5x1_conv_{name}',
                          **middle_opts)(main)
            main = Conv2D(kernel_size=(1, 5),
                          name=f'1x5_conv_{name}',
                          **middle_opts)(main)
        elif dilated:
            main = Conv2D(kernel_size=(3, 3),
                          dilation_rate=(dilated, dilated),
                          name=f'dilated_conv_{name}',
                          **middle_opts)(main)
        main = BatchNormalization(momentum=0.1, name=f'bn_main_{name}')(main)
        main = PReLU(shared_axes=[1, 2], name=f'prelu_{name}')(main)

        main = Conv2D(filters=nfilters,
                      kernel_size=(1, 1),
                      kernel_initializer='he_normal',
                      use_bias=False,
                      name=f'final_1x1_{name}')(main)
        main = BatchNormalization(momentum=0.1,
                                  name=f'bn_final_{name}')(main)
        main = SpatialDropout2D(rate=drate,
                                name=f'spatial_dropout_final_{name}')(main)

        merged = Add(name=f'add_{name}')([main, shortcut])
        return PReLU(shared_axes=[1, 2], name=f'prelu_out_{name}')(merged)
Пример #17
0
    def step(self, inputs, states):
        """Run one recurrent step with soft attention and LSTM-style gates.

        # Arguments
            inputs: input for the current step.
            states: list unpacked as
                `[h_tm1, c_tm1, dp_mask, rec_dp_mask, x_input]`: previous
                hidden state, previous cell state, the two dropout mask
                lists and the tensor the attention is computed over
                (presumably the full input sequence - TODO confirm against
                the code that builds the constants).
        # Returns
            Tuple `(output, [h, c])`. `output` is the attention-weighted
            `x_input` (before summation) when `self.return_attention` is
            set, otherwise the new hidden state `h`.
        # Raises
            ValueError: if `self.implementation` is not 0, 1 or 2.
        """
        h_tm1 = states[0]
        c_tm1 = states[1]
        dp_mask = states[2]
        rec_dp_mask = states[3]
        x_input = states[4]

        # alignment model
        # Previous hidden state repeated `self.timestep_dim` times so it can
        # be added to the per-timestep projection of `x_input`.
        h_att = K.repeat(h_tm1, self.timestep_dim)
        att = _time_distributed_dense(x_input,
                                      self.attention_weights,
                                      self.attention_bias,
                                      output_dim=K.int_shape(
                                          self.attention_weights)[1])
        attention_ = self.attention_activation(
            K.dot(h_att, self.attention_recurrent_weights) + att)
        # Project to one score per timestep and drop the trailing axis.
        attention_ = K.squeeze(
            K.dot(attention_, self.attention_recurrent_bias), 2)

        # Exponentiate and normalise along axis 1 (a softmax written out by
        # hand, with an optional mask applied before normalisation).
        alpha = K.exp(attention_)

        # NOTE(review): `dp_mask[0]` is also the input dropout mask used
        # further down, and it is indexed there without this `None` check -
        # confirm both uses are intended.
        if dp_mask is not None:
            alpha *= dp_mask[0]

        alpha /= K.sum(alpha, axis=1, keepdims=True)
        # Repeat the weights `self.input_dim` times and swap the last two
        # axes so they line up with `x_input` for the product below.
        alpha_r = K.repeat(alpha, self.input_dim)
        alpha_r = K.permute_dimensions(alpha_r, (0, 2, 1))

        # make context vector (soft attention after Bahdanau et al.)
        z_hat = x_input * alpha_r
        # Keep the unsummed weighted input; it is what gets returned when
        # `self.return_attention` is set.
        context_sequence = z_hat
        z_hat = K.sum(z_hat, axis=1)

        if self.implementation == 2:
            # Fused kernels: one product per source, then the result is
            # sliced into the four gate pre-activations.
            z = K.dot(inputs * dp_mask[0], self.kernel)
            z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
            z += K.dot(z_hat, self.attention_kernel)

            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units:2 * self.units]
            z2 = z[:, 2 * self.units:3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)
        else:
            if self.implementation == 0:
                # `inputs` is sliced directly into the four gate inputs, so
                # no input kernel is applied in this mode.
                x_i = inputs[:, :self.units]
                x_f = inputs[:, self.units:2 * self.units]
                x_c = inputs[:, 2 * self.units:3 * self.units]
                x_o = inputs[:, 3 * self.units:]
            elif self.implementation == 1:
                # Separate kernel, bias and dropout mask per gate.
                x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
                x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
                x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
                x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
            else:
                raise ValueError('Unknown `implementation` mode.')

            # Each gate combines its input term, a recurrent term with its
            # own recurrent dropout mask, and an attention-context term.
            i = self.recurrent_activation(
                x_i + K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) +
                K.dot(z_hat, self.attention_i))
            f = self.recurrent_activation(
                x_f + K.dot(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) +
                K.dot(z_hat, self.attention_f))
            c = f * c_tm1 + i * self.activation(
                x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) +
                K.dot(z_hat, self.attention_c))
            o = self.recurrent_activation(
                x_o + K.dot(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) +
                K.dot(z_hat, self.attention_o))
        h = o * self.activation(c)
        # Flag the output as learning-phase dependent whenever any dropout
        # rate is non-zero.
        if 0 < self.dropout + self.recurrent_dropout:
            h._uses_learning_phase = True

        if self.return_attention:
            return context_sequence, [h, c]
        else:
            return h, [h, c]
Пример #18
0
def generateAvgFromVolumes(vol_center, volumes, model):
    """Warp a centre volume with the average predicted velocity field.

    For every path in `volumes` the registration network is run on the pair
    (volume, centre volume) and its third output is collected as a velocity
    field. The fields are averaged, clipped, integrated with
    `tfVectorFieldExpHalf`, upsampled with `toUpscaleResampled` and used to
    resample the centre volume. The result is written to
    `new_volume.nii.gz` in the current working directory.

    # Arguments
        vol_center: path of the centre volume.
        volumes: iterable of volume paths to register against the centre.
        model: path of the weights file loaded into the network.
    # Returns
        None; the warped volume is written to disk.
    # Raises
        ValueError: if `volumes` yields no entries (the average would
            otherwise be a division by zero producing NaNs).
    """
    session = tf.Session()
    try:
        model_config = {
            'batchsize': 1,
            'split': 0.9,
            'validation': 0.1,
            'half_res': True,
            'epochs': 200,
            'groupnorm': True,
            'GN_groups': 32,
            'atlas': 'atlas.nii.gz',
            'model_output': 'model.pkl',
            'exponentialSteps': 7,
        }

        atlas, itk_atlas = DataGenerator.loadAtlas(model_config)

        m = DiffeomorphicRegistrationNet.create_model(model_config)
        m.load_weights(model)
        shapes = atlas.squeeze().shape
        # Per-axis half resolution; computed with integer arithmetic so the
        # buffer below always matches it, also for odd sizes.
        half_shapes = [s // 2 for s in shapes]

        print("First is : {}".format(vol_center))
        np_vol_center = readNormalizedVolumeByPath(
            vol_center, itk_atlas).reshape(1, *shapes).astype(np.float32)

        velocities = []
        for vol in volumes:
            np_vol = readNormalizedVolumeByPath(vol, itk_atlas).reshape(
                1, *shapes).astype(np.float32)

            # Channel 0: moving volume, channel 1: centre volume.
            np_stack = np.empty((1, *shapes, 2), dtype=np.float32)
            np_stack[:, :, :, :, 0] = np_vol
            np_stack[:, :, :, :, 1] = np_vol_center

            predictions = m.predict(np_stack)
            velocities.append(predictions[2][0, :, :, :, :])

        if not velocities:
            raise ValueError(
                "volumes must contain at least one volume path")

        # compute avg velocities
        avg_velocity = np.zeros((1, *half_shapes, 3), dtype=np.float32)
        for v in velocities:
            avg_velocity += v
        avg_velocity /= float(len(velocities))

        # apply squaring&scaling
        steps = model_config['exponentialSteps']
        tf_velo = tf.convert_to_tensor(
            avg_velocity.reshape(1, *half_shapes, 3))
        tf_vol_center = tf.convert_to_tensor(
            np_vol_center.reshape(1, *shapes, 1))

        x, y, z = K.int_shape(tf_velo)[1:4]

        # clip too large values:
        v_max = 0.5 * (2**steps)
        v_min = -v_max
        velo = tf.clip_by_value(tf_velo, v_min, v_max)

        # ij indexing doesn't change (x,y,z) to (y,x,z)
        grid = tf.expand_dims(
            tf.stack(
                tf.meshgrid(tf.linspace(0., x - 1., x),
                            tf.linspace(0., y - 1., y),
                            tf.linspace(0., z - 1., z),
                            indexing='ij'), -1), 0)

        # replicate along batch size
        stacked_grids = tf.tile(grid, (tf.shape(velo)[0], 1, 1, 1, 1))

        displacement = tfVectorFieldExpHalf(velo,
                                            stacked_grids,
                                            n_steps=steps)
        displacement_highres = toUpscaleResampled(displacement)
        # warp center volume
        new_warped = remap3d(tf_vol_center, displacement_highres)
        with session.as_default():
            new_volume = new_warped.eval(session=session).reshape(*shapes)
    finally:
        # Release the session's resources even when a step above fails.
        session.close()

    vol_dirs = np.array(itk_atlas.GetDirection()).reshape(3, 3)
    # reapply directions
    warp_np = np.flip(new_volume,
                      [a for a in range(3) if vol_dirs[a, a] == -1.])
    # prepare axes swap from xyz to zyx
    warp_np = np.transpose(warp_np, (2, 1, 0))
    # write image
    warp_img = sitk.GetImageFromArray(warp_np)
    warp_img.SetOrigin(itk_atlas.GetOrigin())
    warp_img.SetDirection(itk_atlas.GetDirection())
    sitk.WriteImage(warp_img, "new_volume.nii.gz")