Exemple #1
0
def apply_gcc_weights_c(ks, cc_mat):
    """Apply coil compression weights.

    The data is inverse-Fourier transformed along the last axis with
    ``fftc.ifftc``. Each virtual channel is then formed as a weighted sum over
    the physical channels, with weights that vary along that last axis. The
    result is transformed back with ``fftc.fftc``.

    Input
      ks -- raw k-space data of dimensions (num_channels, num_readout, num_kx)
      cc_mat -- coil compression matrix of dimensions
                (num_virtual_channels, num_channels, num_kx); presumably from
                calc_gcc_weights_c. Note that calc_gcc_weights returns a
                different layout, (num_kx, num_channels, num_virtual_channels).
    Output
      ks_out -- coil compressed data of dimensions
                (num_virtual_channels, num_readout, num_kx). It is computed
                as complex64 and then passed through fftc.fftc.

    Raises
      ValueError -- if cc_mat does not match the channel or kx dimension of ks
    """
    me = "coilcomp.apply_gcc_weights_c"

    num_channels = ks.shape[0]
    num_kx = ks.shape[2]

    if cc_mat.shape[1] != num_channels or cc_mat.shape[2] != num_kx:
        # Fail loudly. Continuing would either crash later with an unrelated
        # reshape error or silently mix up channels.
        raise ValueError(
            "%s> ERROR! num channels does not match!\n"
            "%s>   ks: num channels = %d, num kx = %d\n"
            "%s>   cc_mat: num channels = %d, num kx = %d"
            % (me, me, num_channels, num_kx, me, cc_mat.shape[1],
               cc_mat.shape[2]))

    ks_x = fftc.ifftc(ks, axis=-1)
    # ks_out[v, r, x] = sum_c ks_x[c, r, x] * cc_mat[v, c, x]  (no conjugate)
    ks_out = np.einsum("crx,vcx->vrx", ks_x, cc_mat).astype(np.complex64)
    ks_out = fftc.fftc(ks_out, axis=-1)

    return ks_out
Exemple #2
0
def calc_gcc_weights(ks_calib, num_virtual_channels, correction=True):
    """Calculate coil compression weights.

    The calibration data is first circularly shifted along kx, so that the kx
    position with the most energy sits at index num_kx // 2. It is then
    inverse-Fourier transformed along that axis with ``fftc.ifftc``, and an
    SVD is computed at every resulting x position. The leading right singular
    vectors at a position form the compression matrix for that position.

    Input
      ks_calib -- raw k-space data of dimensions (num_kx, num_readout, num_channels)
      num_virtual_channels -- number of virtual channels to compress to
      correction -- if True (default), rotate the compression matrices so
                    that neighbouring x positions are aligned, using a unitary
                    Procrustes fit between adjacent positions
    Output
      cc_mat -- coil compression matrix of dimensions
                (num_kx, num_channels, num_virtual_channels)
                (use apply_gcc_weights). If num_virtual_channels exceeds the
                number of channels, a message is printed and a 2D identity of
                shape (num_channels, num_channels) is returned instead.
    """
    me = "coilcomp.calc_gcc_weights"

    num_kx = ks_calib.shape[0]
    num_channels = ks_calib.shape[2]

    if num_virtual_channels > num_channels:
        print(
            "%s> Num of virtual channels (%d) is more than the actual channels (%d)!"
            % (me, num_virtual_channels, num_channels))
        # NOTE(review): this early return is 2D, not the 3D layout returned
        # below -- confirm callers handle it.
        return np.eye(num_channels, dtype=complex)

    # Energy per kx position, summed over readout and channels.
    energy = np.sum(np.sum(np.power(np.abs(ks_calib), 2), axis=2), axis=1)
    i_xmax = np.argmax(energy)
    # Circularly shift the maximum to index num_kx // 2. np.roll returns a new
    # array, so ks_calib itself is not modified. Floor division is used
    # because int(num_kx / 2 - i_xmax) truncates toward zero, which is off by
    # one for odd num_kx when i_xmax > num_kx / 2.
    ks_calib_int = np.roll(ks_calib, num_kx // 2 - i_xmax, axis=0)
    ks_calib_int = fftc.ifftc(ks_calib_int, axis=0)

    cc_mat = np.zeros((num_kx, num_channels, num_virtual_channels),
                      dtype=complex)
    for i_x in range(num_kx):
        # Index without np.squeeze so the slice stays a 2D
        # (num_readout, num_channels) matrix even when num_readout == 1.
        _, _, vh = np.linalg.svd(ks_calib_int[i_x, :, :], full_matrices=False)
        cc_mat[i_x, :, :] = vh.conj().T[:, :num_virtual_channels]

    if correction:
        i_center = num_kx // 2
        # Backward pass: align each position to its right-hand neighbour.
        for i_x in range(i_center - 2, -1, -1):
            cc_mat[i_x, :, :] = _align_cc_mat(cc_mat[i_x + 1, :, :],
                                              cc_mat[i_x, :, :])
        # Forward pass: align each position to its left-hand neighbour. It
        # starts at i_center - 1, so that position is re-aligned to
        # i_center - 2, which the backward pass has just aligned to it.
        for i_x in range(i_center - 1, num_kx):
            cc_mat[i_x, :, :] = _align_cc_mat(cc_mat[i_x - 1, :, :],
                                              cc_mat[i_x, :, :])

    return cc_mat


def _align_cc_mat(v_ref, v):
    """Return v multiplied by the unitary matrix that best aligns it with v_ref.

    Both arguments are 2D (num_channels, num_virtual_channels) matrices. They
    are kept 2D so that num_virtual_channels == 1 works.
    """
    ua, _, vah = np.linalg.svd(v_ref.conj().T @ v, full_matrices=False)
    return v @ (ua @ vah).conj().T
def setup_data_tfrecords(
    dir_in_root,
    dir_out,
    data_divide=(0.75, 0.05, 0.2),
    min_shape=(80, 180),
    num_maps=1,
    crop_maps=False,
    verbose=False,
):
    """Set up training data as tfrecords.

    Every entry of ``dir_in_root`` is treated as a case directory holding a
    ``kspace`` cfl file. Cases are shuffled and split into the ``train``,
    ``validate`` and ``test`` subdirectories of ``dir_out``.

    For each case whose squeezed k-space has shape[1] and shape[2] of at
    least ``min_shape``:
      - ``bart ecalib`` writes a ``sensemap`` file next to the k-space;
      - one ``<case>_xNNN.tfrecords`` file is written per position along the
        last k-space axis.

    Args:
      dir_in_root -- directory containing one subdirectory per case
      dir_out -- output root; it and its train/validate/test subdirectories
                 are created if missing
      data_divide -- (train, validate, test) fractions. Only the first two
                     are used; the test split receives the remaining cases.
      min_shape -- minimum (shape[1], shape[2]) of the squeezed k-space, or
                   None to accept every case
      num_maps -- value passed to ``bart ecalib`` as ``-m``
      crop_maps -- if True, ``-c 1e-9`` is also passed to ``bart ecalib``
      verbose -- print progress information

    Example:
      prep_data.setup_data('/mnt/raid3/data/Studies_DCE/recon-ccomp6/',
          '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True)
    """
    import shlex  # used only to quote file paths for the shell command

    # TODO: check for two echoes (e.g. use glob to look for echo01).

    if verbose:
        print("Directory names:")
        print("  Input root:  %s" % dir_in_root)
        print("  Output root: %s" % dir_out)

    file_kspace = "kspace"
    file_sensemap = "sensemap"

    case_list = os.listdir(dir_in_root)
    random.shuffle(case_list)
    num_cases = len(case_list)

    # Cases [0, i_train_1) -> train, [i_train_1, i_validate_1) -> validate,
    # the rest -> test.
    i_train_1 = int(np.round(data_divide[0] * num_cases))
    i_validate_1 = i_train_1 + int(np.round(data_divide[1] * num_cases))

    for split in ("train", "validate", "test"):
        os.makedirs(os.path.join(dir_out, split), exist_ok=True)

    for i_case, case_name in enumerate(case_list):
        file_kspace_i = os.path.join(dir_in_root, case_name, file_kspace)
        file_sensemap_i = os.path.join(dir_in_root, case_name, file_sensemap)

        if i_case < i_train_1:
            dir_out_i = os.path.join(dir_out, "train")
        elif i_case < i_validate_1:
            dir_out_i = os.path.join(dir_out, "validate")
        else:
            dir_out_i = os.path.join(dir_out, "test")

        if verbose:
            print("Processing [%d] %s..." % (i_case, case_name))

        kspace = np.squeeze(cfl.read(file_kspace_i))
        if min_shape is not None and (
            kspace.shape[1] < min_shape[0] or kspace.shape[2] < min_shape[1]
        ):
            continue

        if verbose:
            print("  Slice shape: (%d, %d)" %
                  (kspace.shape[1], kspace.shape[2]))
            print("  Num channels: %d" % kspace.shape[0])
        shape_x = kspace.shape[-1]
        kspace = fftc.ifftc(kspace, axis=-1).astype(np.complex64)

        cmd_flags = ""
        if crop_maps:
            cmd_flags += " -c 1e-9"
        cmd_flags += " -m %d" % num_maps
        # Quote the paths: they are interpolated into a bash command line.
        cmd = "%s ecalib %s %s %s" % (
            BIN_BART,
            cmd_flags,
            shlex.quote(file_kspace_i),
            shlex.quote(file_sensemap_i),
        )
        if verbose:
            print("  Estimating sensitivity maps (bart espirit)...")
            print("    %s" % cmd)
        subprocess.check_call(["bash", "-c", cmd])
        sensemap = np.squeeze(cfl.read(file_sensemap_i))
        sensemap = np.expand_dims(sensemap, axis=0).astype(np.complex64)

        if verbose:
            print("  Creating tfrecords (%d)..." % shape_x)
        for i_x in range(shape_x):
            file_out = os.path.join(
                dir_out_i, "%s_x%03d.tfrecords" % (case_name, i_x)
            )
            kspace_x = kspace[:, :, :, i_x]
            sensemap_x = sensemap[:, :, :, :, i_x]

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "name": _bytes_feature(str.encode(case_name)),
                        "xslice": _int64_feature(i_x),
                        "ks_shape_x": _int64_feature(kspace.shape[3]),
                        "ks_shape_y": _int64_feature(kspace.shape[2]),
                        "ks_shape_z": _int64_feature(kspace.shape[1]),
                        "ks_shape_c": _int64_feature(kspace.shape[0]),
                        "map_shape_x": _int64_feature(sensemap.shape[4]),
                        "map_shape_y": _int64_feature(sensemap.shape[3]),
                        "map_shape_z": _int64_feature(sensemap.shape[2]),
                        "map_shape_c": _int64_feature(sensemap.shape[1]),
                        "map_shape_m": _int64_feature(sensemap.shape[0]),
                        "ks": _bytes_feature(kspace_x.tobytes()),
                        "map": _bytes_feature(sensemap_x.tobytes()),
                    }
                )
            )

            # The context manager closes the writer even if write() raises.
            with tf.python_io.TFRecordWriter(file_out) as tf_writer:
                tf_writer.write(example.SerializeToString())
Exemple #4
0
def setup_data_tfrecords(
    dir_in_root,
    dir_out,
    data_divide=(0.75, 0.05, 0.2),
    min_shape=(80, 180),
    num_maps=1,
    crop_maps=False,
    verbose=False,
):
    """Set up training data as tfrecords (exact-shape, 32-coil variant).

    Every entry of ``dir_in_root`` is treated as a case directory. Cases are
    shuffled and split into the ``train``, ``validate`` and ``test``
    subdirectories of ``dir_out``. A case is skipped, with a message, when any
    of the following holds:
      - it has no ``kspace.hdr``;
      - its squeezed k-space does not have 32 coils (shape[0]);
      - (shape[1], shape[2]) does not equal ``min_shape`` exactly.

    For every remaining case:
      - ``bart ecalib -S`` writes a ``sensemap`` file next to the k-space;
      - one ``<case>_xNNN.tfrecords`` file is written per position along
        axis 3.

    Args:
      dir_in_root -- directory containing one subdirectory per case
      dir_out -- output root; it and its train/validate/test subdirectories
                 are created if missing
      data_divide -- (train, validate, test) fractions. Only the first two
                     are used; the test split receives the remaining cases.
      min_shape -- required (shape[1], shape[2]) of the squeezed k-space
      num_maps -- value passed to ``bart ecalib`` as ``-m``
      crop_maps -- if True, ``-c 1e-9`` is also passed to ``bart ecalib``
      verbose -- print progress information

    Example:
      prep_data.setup_data('/mnt/raid3/data/Studies_DCE/recon-ccomp6/',
          '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True)
    """
    import shlex  # used only to quote file paths for the shell command

    if verbose:
        print("Directory names:")
        print("  Input root:  %s" % dir_in_root)
        print("  Output root: %s" % dir_out)

    file_kspace = "kspace"
    file_sensemap = "sensemap"

    case_list = os.listdir(dir_in_root)
    random.shuffle(case_list)
    num_cases = len(case_list)

    # Cases [0, i_train_1) -> train, [i_train_1, i_validate_1) -> validate,
    # the rest -> test. Skipped cases still consume a slot.
    i_train_1 = int(np.round(data_divide[0] * num_cases))
    i_validate_1 = i_train_1 + int(np.round(data_divide[1] * num_cases))

    for split in ("train", "validate", "test"):
        os.makedirs(os.path.join(dir_out, split), exist_ok=True)

    for i_case, case_name in enumerate(case_list):
        file_kspace_i = os.path.join(dir_in_root, case_name, file_kspace)
        file_sensemap_i = os.path.join(dir_in_root, case_name, file_sensemap)

        if verbose:
            print("Processing [%d] %s..." % (i_case, case_name))

        if i_case < i_train_1:
            dir_out_i = os.path.join(dir_out, "train")
        elif i_case < i_validate_1:
            dir_out_i = os.path.join(dir_out, "validate")
        else:
            dir_out_i = os.path.join(dir_out, "test")

        if not os.path.exists(file_kspace_i + ".hdr"):
            print("skipping due to kspace not existing in this folder")
            continue

        kspace = np.squeeze(cfl.read(file_kspace_i))
        print("original kspace shape")
        print(kspace.shape)

        shape_x = kspace.shape[3]
        num_coils = kspace.shape[0]

        # Value comparison: `is not 32` tested object identity.
        if num_coils != 32:
            print("skipping due to incorrect number of coils")
            continue

        if min_shape[0] != kspace.shape[1] or min_shape[1] != kspace.shape[2]:
            print("skipping due to wrong slice dimensions")
            continue

        if verbose:
            print("  Slice shape: (%d, %d)" %
                  (kspace.shape[1], kspace.shape[2]))
            print("  Num channels: %d" % kspace.shape[0])

        # Inverse transform along the last (readout) axis. Cast to complex64
        # like the other writers of this record schema, so "ks" and "map"
        # share a dtype.
        kspace = fftc.ifftc(kspace, axis=-1).astype(np.complex64)

        cmd_flags = ""
        if crop_maps:
            cmd_flags += " -c 1e-9"
        cmd_flags += " -S"  # smoothing flag
        cmd_flags += " -m %d" % num_maps
        # Quote the paths: they are interpolated into a bash command line.
        cmd = "%s ecalib %s %s %s" % (
            BIN_BART,
            cmd_flags,
            shlex.quote(file_kspace_i),
            shlex.quote(file_sensemap_i),
        )
        if verbose:
            print("  Estimating sensitivity maps (bart espirit)...")
            print("    %s" % cmd)
        subprocess.check_call(["bash", "-c", cmd])
        sensemap = np.squeeze(cfl.read(file_sensemap_i))
        sensemap = np.expand_dims(sensemap, axis=0).astype(np.complex64)

        if verbose:
            print("  Creating tfrecords (%d)..." % shape_x)
        for i_x in range(shape_x):
            file_out = os.path.join(
                dir_out_i, "%s_x%03d.tfrecords" % (case_name, i_x))
            kspace_x = kspace[:, :, :, i_x]
            sensemap_x = sensemap[:, :, :, :, i_x]

            example = tf.train.Example(features=tf.train.Features(
                feature={
                    "name": _bytes_feature(str.encode(case_name)),
                    "xslice": _int64_feature(i_x),
                    "ks_shape_x": _int64_feature(kspace.shape[3]),
                    "ks_shape_y": _int64_feature(kspace.shape[2]),
                    "ks_shape_z": _int64_feature(kspace.shape[1]),
                    "ks_shape_c": _int64_feature(kspace.shape[0]),
                    "map_shape_x": _int64_feature(sensemap.shape[4]),
                    "map_shape_y": _int64_feature(sensemap.shape[3]),
                    "map_shape_z": _int64_feature(sensemap.shape[2]),
                    "map_shape_c": _int64_feature(sensemap.shape[1]),
                    "map_shape_m": _int64_feature(sensemap.shape[0]),
                    "ks": _bytes_feature(kspace_x.tobytes()),
                    "map": _bytes_feature(sensemap_x.tobytes()),
                }))

            # The context manager closes the writer even if write() raises.
            with tf.python_io.TFRecordWriter(file_out) as tf_writer:
                tf_writer.write(example.SerializeToString())
Exemple #5
0
def setup_data_tfrecords_3d(
    dir_in_root,
    dir_out,
    data_divide=(0.8, 0.1, 0.2),
    min_shape=(80, 180),
    num_maps=1,
    crop_maps=False,
    verbose=False,
    shuffle=True,
):
    """Set up 2D+time training data as tfrecords (one record per slice).

    Inputs:
      - k-space is read from
        ``<dir_in_root>/sort-ccomp6/<case>/ks_sorted``;
      - precomputed sensitivity maps are read from
        ``<dir_in_root>/recon-ccomp6/<case>/map``.

    A case is skipped when it has no ``map.cfl``, or when the k-space header
    does not report dims[1] == 180 and dims[2] == 80. Each remaining case
    produces one ``<case>_xNNN.tfrecords`` file per slice along axis 1 of the
    transposed k-space, holding all time frames.

    Args:
      dir_in_root -- root containing the sort-ccomp6 and recon-ccomp6 dirs
      dir_out -- output root; it and its train/validate/test subdirectories
                 are created if missing
      data_divide -- (train, validate, test) fractions. Only the first two
                     are used; the test split receives the remaining cases.
      min_shape, num_maps, crop_maps -- currently unused (the accepted
                     dimensions are hard-coded to 180 x 80)
      verbose -- print progress information
      shuffle -- shuffle cases when exactly True; any other value keeps the
                 os.listdir order

    Example:
      prep_data.setup_data('/mnt/raid3/data/Studies_DCE/recon-ccomp6/',
          '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True)
    """
    # TODO: check for two echoes (e.g. use glob to look for echo01).

    if verbose:
        print("Directory names:")
        print("  Input root:  %s" % dir_in_root)
        print("  Output root: %s" % dir_out)

    # Layout of /mnt/dense/data/MFAST_DCE and /home_local/ekcole/MFAST_DCE.
    dir_kspace = "sort-ccomp6"
    dir_map = "recon-ccomp6"

    dir_cases = os.path.join(dir_in_root, dir_kspace)

    file_kspace = "ks_sorted"
    file_sensemap = "map"

    case_list = os.listdir(dir_cases)
    if shuffle is True:
        random.shuffle(case_list)
    else:
        print("don't shuffle dataset")
    num_cases = len(case_list)

    # Cases [0, i_train_1) -> train, [i_train_1, i_validate_1) -> validate,
    # the rest -> test. Skipped cases still consume a slot.
    i_train_1 = int(np.round(data_divide[0] * num_cases))
    i_validate_1 = i_train_1 + int(np.round(data_divide[1] * num_cases))

    for split in ("train", "validate", "test"):
        os.makedirs(os.path.join(dir_out, split), exist_ok=True)

    for i_case, case_name in enumerate(case_list):
        file_kspace_i = os.path.join(dir_in_root, dir_kspace, case_name,
                                     file_kspace)
        file_sensemap_i = os.path.join(dir_in_root, dir_map, case_name,
                                       file_sensemap)
        file_sensemap_i_check = file_sensemap_i + ".cfl"

        if i_case < i_train_1:
            dir_out_i = os.path.join(dir_out, "train")
        elif i_case < i_validate_1:
            dir_out_i = os.path.join(dir_out, "validate")
        else:
            dir_out_i = os.path.join(dir_out, "test")
        print("dir out")
        print(dir_out_i)

        if verbose:
            print("Processing [%d] %s..." % (i_case, case_name))

        if not os.path.exists(file_sensemap_i_check):
            print("Sensitivity map does not exist")
            continue

        # Dimensions are on the second line of the .hdr file.
        with open(file_kspace_i + ".hdr", "r") as hdr:
            hdr.readline()
            dims_line = hdr.readline()
        dims = [int(d) for d in dims_line.split()]
        print(dims)
        ky = dims[1]
        kz = dims[2]
        if ky != 180 or kz != 80:
            print("wrong dimensions")
            continue

        kspace = np.squeeze(cfl.read(file_kspace_i))
        # Reorder to coils x (x, y, z) x frames, as labelled in the records.
        kspace = np.transpose(kspace, [1, -1, -2, 2, 0])

        if verbose:
            print("  Slice shape: (%d, %d)" %
                  (kspace.shape[2], kspace.shape[3]))
            print("  Num channels: %d" % kspace.shape[0])
            print("  Num frames: %d" % kspace.shape[-1])
        num_slices = kspace.shape[1]

        kspace = fftc.ifftc(kspace, axis=1).astype(np.complex64)

        print("Exists")
        sensemap = np.squeeze(cfl.read(file_sensemap_i))
        sensemap = sensemap[0, :, :, :, :]
        # Reorder the 4D map to put the last axis second, matching the
        # (coils, x, y, z) labels used below; then add a leading map axis.
        sensemap = np.transpose(sensemap, [0, -1, 2, 1])
        sensemap = np.expand_dims(sensemap, axis=0).astype(np.complex64)

        if verbose:
            print("  Creating tfrecords (%d)..." % num_slices)
        # 2D + time: iterate over slices only; each record holds all frames.
        for i_slice in range(num_slices):
            kspace_x = kspace[:, i_slice, :, :, :]
            file_out = os.path.join(
                dir_out_i, "%s_x%03d.tfrecords" % (case_name, i_slice))
            sensemap_x = sensemap[:, :, i_slice, :, :]
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    "name": _bytes_feature(str.encode(case_name)),
                    "slice": _int64_feature(i_slice),
                    "ks_shape_x": _int64_feature(kspace.shape[1]),
                    "ks_shape_y": _int64_feature(kspace.shape[2]),
                    "ks_shape_z": _int64_feature(kspace.shape[3]),
                    "ks_shape_t": _int64_feature(kspace.shape[4]),
                    "ks_shape_c": _int64_feature(kspace.shape[0]),
                    "map_shape_x": _int64_feature(sensemap.shape[2]),
                    "map_shape_y": _int64_feature(sensemap.shape[3]),
                    "map_shape_z": _int64_feature(sensemap.shape[4]),
                    "map_shape_c": _int64_feature(sensemap.shape[1]),
                    "map_shape_m": _int64_feature(sensemap.shape[0]),
                    "ks": _bytes_feature(kspace_x.tobytes()),
                    "map": _bytes_feature(sensemap_x.tobytes()),
                }))

            # The context manager closes the writer even if write() raises.
            with tf.python_io.TFRecordWriter(file_out) as tf_writer:
                tf_writer.write(example.SerializeToString())
Exemple #6
0
def setup_data_tfrecords_DCE(
    dir_in_root,
    dir_out,
    data_divide=(0.75, 0.05, 0.2),
    min_shape=(80, 180),
    num_maps=1,
    crop_maps=False,
    verbose=False,
    shuffle=True,
):
    """Set up DCE training data as tfrecords (one record per slice and frame).

    Inputs:
      - k-space is read from
        ``<dir_in_root>/sort-ccomp6/<case>/ks_sorted``;
      - precomputed sensitivity maps are read from
        ``<dir_in_root>/recon-ccomp6/<case>/map``.

    A case is skipped when it has no ``map.cfl``, or when the k-space header
    does not report dims[1] == 180 and dims[2] == 80. Skipped cases do not
    advance the train/validate/test split counter.

    For every remaining case, each slice (axis 1 of the transposed k-space)
    is normalised by its maximum magnitude over all frames. One
    ``<case>_xNNN_fNNN.tfrecords`` file is then written per slice and frame.

    Args:
      dir_in_root -- root containing the sort-ccomp6 and recon-ccomp6 dirs
      dir_out -- output root; it and its train/validate/test subdirectories
                 are created if missing
      data_divide -- (train, validate, test) fractions. Only the first two
                     are used; the test split receives the remaining cases.
      min_shape, num_maps, crop_maps -- currently unused (the accepted
                     dimensions are hard-coded to 180 x 80)
      verbose -- print progress information
      shuffle -- shuffle cases when exactly True; any other value keeps the
                 os.listdir order

    Example:
      prep_data.setup_data('/mnt/raid3/data/Studies_DCE/recon-ccomp6/',
          '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True)
    """
    if verbose:
        print("Directory names:")
        print("  Input root:  %s" % dir_in_root)
        print("  Output root: %s" % dir_out)

    # Layout of /mnt/dense/data/MFAST_DCE and /home_local/ekcole/MFAST_DCE.
    dir_kspace = "sort-ccomp6"
    dir_map = "recon-ccomp6"

    dir_cases = os.path.join(dir_in_root, dir_kspace)

    file_kspace = "ks_sorted"
    file_sensemap = "map"

    case_list = os.listdir(dir_cases)
    # Shuffle cases (patients), not individual frames.
    if shuffle is True:
        random.shuffle(case_list)
    else:
        print("don't shuffle dataset")
    num_cases = len(case_list)

    # Accepted cases [0, i_train_1) -> train, [i_train_1, i_validate_1) ->
    # validate, the rest -> test.
    i_train_1 = int(np.round(data_divide[0] * num_cases))
    i_validate_1 = i_train_1 + int(np.round(data_divide[1] * num_cases))

    for split in ("train", "validate", "test"):
        os.makedirs(os.path.join(dir_out, split), exist_ok=True)

    i_case = 0
    for case_name in case_list:
        file_kspace_i = os.path.join(dir_in_root, dir_kspace, case_name,
                                     file_kspace)
        file_sensemap_i = os.path.join(dir_in_root, dir_map, case_name,
                                       file_sensemap)
        file_sensemap_i_check = file_sensemap_i + ".cfl"

        # If there is no map, skip this case.
        if not os.path.exists(file_sensemap_i_check):
            print("Does not exist")
            continue

        # Dimensions are on the second line of the .hdr file.
        with open(file_kspace_i + ".hdr", "r") as hdr:
            hdr.readline()
            dims_line = hdr.readline()
        dims = [int(d) for d in dims_line.split()]
        print(dims)
        ky = dims[1]
        kz = dims[2]
        if ky != 180 or kz != 80:
            print("wrong dimensions")
            continue

        if i_case < i_train_1:
            dir_out_i = os.path.join(dir_out, "train")
        elif i_case < i_validate_1:
            dir_out_i = os.path.join(dir_out, "validate")
        else:
            dir_out_i = os.path.join(dir_out, "test")

        if verbose:
            print("Processing [%d] %s..." % (i_case, case_name))
        i_case += 1

        kspace = np.squeeze(cfl.read(file_kspace_i))
        # Reorder to coils x (x, y, z) x frames, as labelled in the records.
        kspace = np.transpose(kspace, [1, -1, -2, 2, 0])
        if verbose:
            print("  Slice shape: (%d, %d)" %
                  (kspace.shape[2], kspace.shape[3]))
            print("  Num channels: %d" % kspace.shape[0])
            print("  Num frames: %d" % kspace.shape[-1])
        shape_f = kspace.shape[-1]  # number of frames
        shape_x = kspace.shape[1]  # number of slices in x

        kspace = fftc.ifftc(kspace, axis=1).astype(np.complex64)

        sensemap = np.squeeze(cfl.read(file_sensemap_i))
        sensemap = sensemap[0, :, :, :, :]
        # Reorder the 4D map to put the last axis second, matching the
        # (coils, x, y, z) labels used below; then add a leading map axis.
        sensemap = np.transpose(sensemap, [0, -1, 2, 1])
        sensemap = np.expand_dims(sensemap, axis=0).astype(np.complex64)

        if verbose:
            print("  Creating tfrecords (%d)..." % shape_x)
        # Iterate over both slices and frames. Frames are written in order
        # here; they are not shuffled at this stage.
        for i_x in range(shape_x):
            kspace_slice = kspace[:, i_x, :, :, :]
            # Normalise all frames of this slice by one common maximum.
            max_frames = np.max(np.abs(kspace_slice))
            if max_frames == 0:
                # All-zero slice: skip the division instead of writing NaNs.
                max_frames = 1
            sensemap_x = sensemap[:, :, i_x, :, :]  # same for every frame
            for i_f in range(shape_f):
                file_out = os.path.join(
                    dir_out_i,
                    "%s_x%03d_f%03d.tfrecords" % (case_name, i_x, i_f))
                kspace_x = kspace_slice[:, :, :, i_f] / max_frames

                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        "name": _bytes_feature(str.encode(case_name)),
                        "xslice": _int64_feature(i_x),
                        "ks_shape_x": _int64_feature(kspace.shape[1]),
                        "ks_shape_y": _int64_feature(kspace.shape[2]),
                        "ks_shape_z": _int64_feature(kspace.shape[3]),
                        "ks_shape_c": _int64_feature(kspace.shape[0]),
                        "map_shape_x": _int64_feature(sensemap.shape[2]),
                        "map_shape_y": _int64_feature(sensemap.shape[3]),
                        "map_shape_z": _int64_feature(sensemap.shape[4]),
                        "map_shape_c": _int64_feature(sensemap.shape[1]),
                        "map_shape_m": _int64_feature(sensemap.shape[0]),
                        "ks": _bytes_feature(kspace_x.tobytes()),
                        "map": _bytes_feature(sensemap_x.tobytes()),
                    }))

                # The context manager closes the writer even if write() raises.
                with tf.python_io.TFRecordWriter(file_out) as tf_writer:
                    tf_writer.write(example.SerializeToString())