Example 1
def convert_data2mnc(dataset_info, contrast='t1'):
    path_template = dataset_info['path_template']
    list_subjects = dataset_info['subjects']

    path_template_mask = create_mask_template(dataset_info, contrast)

    output_list = open('subjects.csv', "wb")
    writer = csv.writer(output_list, delimiter=',', quotechar='"', quoting=csv.QUOTE_ALL)

    timer_convert = sct.Timer(len(list_subjects))
    timer_convert.start()
    for subject_name in list_subjects:
        fname_nii = path_template + subject_name + '_' + contrast + '.nii.gz'
        fname_mnc = path_template + subject_name + '_' + contrast + '.mnc'

        # if the MNC file already exists, delete it before converting
        if os.path.isfile(fname_mnc):
            os.remove(fname_mnc)

        sct.run('nii2mnc ' + fname_nii + ' ' + fname_mnc)

        writer.writerow([fname_mnc, path_template_mask])

        timer_convert.add_iteration()
    timer_convert.stop()

    output_list.close()
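All of these examples consume the same dataset_info dictionary. Below is a minimal usage sketch, assuming hypothetical paths, subject names, and suffixes (the keys match the ones accessed throughout these examples, and the call assumes the helpers from this module are available):

# Hypothetical dataset_info; every value below is made up for illustration.
dataset_info = {
    'path_data': '/data/my_dataset/',             # per-subject input data (trailing slash expected)
    'path_template': '/data/my_template/',        # output folder of the template pipeline
    'subjects': ['sub-01', 'sub-02'],             # subject folder names
    'suffix_centerline': '_centerline_manual',    # suffix of the binary centerline images
    'suffix_disks': '_disks_manual',              # suffix of the intervertebral disk label images
}
convert_data2mnc(dataset_info, contrast='t1')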
Example 2
def generate_centerline(dataset_info, contrast='t1', regenerate=False):
    """
    This function generates the spinal cord centerline from binary images (either a centerline image or a segmentation)
    :param dataset_info: dictionary containing dataset information
    :param contrast: {'t1', 't2'}
    :param regenerate: recompute and overwrite the centerline even if a saved one already exists
    :return: list of Centerline objects
    """
    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']
    list_centerline = []

    current_path = os.getcwd()

    timer_centerline = sct.Timer(len(list_subjects))
    timer_centerline.start()
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
        fname_image_disks = path_data_subject + contrast + dataset_info['suffix_disks'] + '.nii.gz'

        # go to the subject folder
        sct.printv('\nExtracting centerline from ' + path_data_subject)
        os.chdir(path_data_subject)

        fname_centerline = 'centerline'
        # if the centerline already exists, load it; otherwise compute it
        if os.path.isfile(fname_centerline + '.npz') and not regenerate:
            centerline = Centerline(fname=path_data_subject + fname_centerline + '.npz')
        else:
            # extracting intervertebral disks
            im = Image(fname_image_disks)
            coord = im.getNonZeroCoordinates(sorting='z', reverse_coord=True)
            coord_physical = []
            for c in coord:
                if c.value <= 22 or c.value in [48, 49, 50, 51, 52]:  # 22 corresponds to L2
                    c_p = im.transfo_pix2phys([[c.x, c.y, c.z]])[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)

            # extract the centerline from the binary image and create a Centerline object with the vertebral distribution
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_image_centerline, algo_fitting='nurbs',
                verbose=0, nurbs_pts_number=4000, all_slices=False, phys_coordinates=True, remove_outliers=False)
            centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                    x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
            centerline.compute_vertebral_distribution(coord_physical)
            centerline.save_centerline(fname_output=fname_centerline)

        list_centerline.append(centerline)
        timer_centerline.add_iteration()
    timer_centerline.stop()

    os.chdir(current_path)

    return list_centerline
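The disk selection above keeps labels up to 22 (L2, per the inline comment) plus the special labels 48 to 52. A self-contained sketch of that filter on made-up label values:

def keep_disk_label(value):
    # same rule as above: labels up to 22 (L2) plus the special labels 48-52
    return value <= 22 or value in [48, 49, 50, 51, 52]

sample_labels = [1, 3, 21, 22, 23, 25, 48, 50, 60]   # hypothetical label values
print([v for v in sample_labels if keep_disk_label(v)])   # -> [1, 3, 21, 22, 48, 50]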
Example 3
def copy_preprocessed_images(dataset_info, contrast='t1'):
    path_data = dataset_info['path_data']
    path_template = dataset_info['path_template']
    list_subjects = dataset_info['subjects']

    fname_in = contrast + '_straight_norm.nii.gz'

    timer_copy = sct.Timer(len(list_subjects))
    timer_copy.start()
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        os.chdir(path_data_subject)
        shutil.copy(fname_in, path_template + subject_name + '_' + contrast + '.nii.gz')
        timer_copy.add_iteration()
    timer_copy.stop()
Example 4
def straighten_all_subjects(dataset_info, normalized=False, contrast='t1'):
    """
    This function straightens all images based on the template centerline
    :param dataset_info: dictionary containing dataset information
    :param normalized: True if images were normalized before straightening
    :param contrast: {'t1', 't2'}
    """
    path_data = dataset_info['path_data']
    path_template = dataset_info['path_template']
    list_subjects = dataset_info['subjects']

    if normalized:
        fname_in = contrast + '_norm.nii.gz'
        fname_out = contrast + '_straight_norm.nii.gz'
    else:
        fname_in = contrast + '.nii.gz'
        fname_out = contrast + '_straight.nii.gz'

    # straightening of each subject on the new template
    timer_straightening = sct.Timer(len(list_subjects))
    timer_straightening.start()
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'

        # go to the subject folder
        sct.printv('\nStraightening ' + path_data_subject)
        os.chdir(path_data_subject)
        sct.run('sct_straighten_spinalcord'
                ' -i ' + fname_in +
                ' -s ' + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
                ' -disks-input ' + contrast + dataset_info['suffix_disks'] + '.nii.gz'
                ' -ref ' + path_template + 'template_centerline.nii.gz'
                ' -disks-ref ' + path_template + 'template_disks.nii.gz'
                ' -disable-straight2curved'
                ' -param threshold_distance=1', verbose=1)

        image_straight = Image(sct.add_suffix(fname_in, '_straight'))
        image_straight.setFileName(fname_out)
        image_straight.save(type='float32')

        timer_straightening.add_iteration()
    timer_straightening.stop()
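copy_preprocessed_images (Example 3) copies contrast + '_straight_norm.nii.gz', which is exactly what this function writes when normalized=True, so one plausible ordering of the two steps looks like this (reusing the hypothetical dataset_info sketched after Example 1):

straighten_all_subjects(dataset_info, normalized=True, contrast='t1')   # writes t1_straight_norm.nii.gz per subject
copy_preprocessed_images(dataset_info, contrast='t1')                   # copies it into path_template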
Example 5
def normalize_intensity_template(dataset_info, fname_template_centerline=None, contrast='t1', verbose=1):
    """
    This function normalizes the intensity of each subject's image inside the spinal cord
    :param dataset_info: dictionary containing dataset information
    :param fname_template_centerline: path to the template centerline (binary image or npz)
    :param contrast: {'t1', 't2'}
    :param verbose: 1 for basic output, 2 to also plot the intensity profiles
    :return:
    """

    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']
    path_template = dataset_info['path_template']

    average_intensity = []
    intensity_profiles = {}

    timer_profile = sct.Timer(len(list_subjects))
    timer_profile.start()

    # computing the intensity profile for each subject
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        if fname_template_centerline is None:
            fname_image = path_data_subject + contrast + '.nii.gz'
            fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
        else:
            fname_image = path_data_subject + contrast + '_straight.nii.gz'
            if fname_template_centerline.endswith('.npz'):
                fname_image_centerline = None
            else:
                fname_image_centerline = fname_template_centerline

        image = Image(fname_image)
        nx, ny, nz, nt, px, py, pz, pt = image.dim

        if fname_image_centerline is not None:
            # open centerline from template
            number_of_points_in_centerline = 4000
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_image_centerline, algo_fitting='nurbs', verbose=0,
                nurbs_pts_number=number_of_points_in_centerline,
                all_slices=False, phys_coordinates=True, remove_outliers=True)
            centerline_template = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                             x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
        else:
            centerline_template = Centerline(fname=fname_template_centerline)

        x, y, z, xd, yd, zd = centerline_template.average_coordinates_over_slices(image)

        # Compute intensity values
        z_values, intensities = [], []
        extend = 1  # this means the mean intensity of the slice will be calculated over a 3x3 square
        for i in range(len(z)):
            coord_z = image.transfo_phys2pix([[x[i], y[i], z[i]]])[0]
            z_values.append(coord_z[2])
            intensities.append(np.mean(image.data[coord_z[0] - extend - 1:coord_z[0] + extend, coord_z[1] - extend - 1:coord_z[1] + extend, coord_z[2]]))

        # for slices not covered by the centerline, extend the min and max values to cover the whole image
        min_z, max_z = min(z_values), max(z_values)
        intensities_temp = copy(intensities)
        z_values_temp = copy(z_values)
        for cz in range(nz):
            if cz not in z_values:
                z_values_temp.append(cz)
                if cz < min_z:
                    intensities_temp.append(intensities[z_values.index(min_z)])
                elif cz > max_z:
                    intensities_temp.append(intensities[z_values.index(max_z)])
                else:
                    print 'error...', cz
        intensities = intensities_temp
        z_values = z_values_temp

        # Preparing data for smoothing
        arr_int = [[z_values[i], intensities[i]] for i in range(len(z_values))]
        arr_int.sort(key=lambda x: x[0])  # and make sure it is ordered with z

        def smooth(x, window_len=11, window='hanning'):
            """smooth the data using a window with requested size.
            """

            if x.ndim != 1:
                raise ValueError, "smooth only accepts 1 dimension arrays."

            if x.size < window_len:
                raise ValueError, "Input vector needs to be bigger than window size."

            if window_len < 3:
                return x

            if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
                raise ValueError, "Window must be one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'"

            s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
            if window == 'flat':  # moving average
                w = np.ones(window_len, 'd')
            else:
                w = getattr(np, window)(window_len)

            y = np.convolve(w / w.sum(), s, mode='same')
            return y[window_len - 1:-window_len + 1]

        # Smoothing
        intensities = [c[1] for c in arr_int]
        intensity_profile_smooth = smooth(np.array(intensities), window_len=50)
        average_intensity.append(np.mean(intensity_profile_smooth))

        intensity_profiles[subject_name] = intensity_profile_smooth

        if verbose == 2:
            import matplotlib.pyplot as plt
            plt.figure()
            plt.title(subject_name)
            plt.plot(intensities)
            plt.plot(intensity_profile_smooth)
            plt.show()

    # target average intensity inside the spinal cord (hard-coded; overrides the per-subject averages computed above)
    average_intensity = 1000.0

    # normalize the intensity of the image based on spinal cord
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        fname_image = path_data_subject + contrast + '_straight.nii.gz'

        image = Image(fname_image)
        nx, ny, nz, nt, px, py, pz, pt = image.dim

        image_image_new = image.copy()
        image_image_new.changeType(type='float32')
        for i in range(nz):
            image_image_new.data[:, :, i] *= average_intensity / intensity_profiles[subject_name][i]

        # Save the intensity-normalized image
        fname_image_normalized = sct.add_suffix(fname_image, '_norm')
        image_image_new.setFileName(fname_image_normalized)
        image_image_new.save()
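The nested smooth() helper is the core of the profile smoothing: it mirrors the signal at both ends, convolves it with a window, and crops back to the original length. Below is a standalone sketch of the same idea, restricted to a Hanning window and run on synthetic data (the signal and lengths are made up):

import numpy as np

def smooth_1d(x, window_len=11):
    # mirror the edges, convolve with a normalized Hanning window, crop to the input length
    s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
    w = np.hanning(window_len)
    y = np.convolve(w / w.sum(), s, mode='same')
    return y[window_len - 1:-window_len + 1]

z = np.linspace(0, 10, 200)
noisy = np.sin(z) + 0.3 * np.random.randn(200)    # synthetic intensity profile
print(smooth_1d(noisy, window_len=50).shape)      # (200,) -- same length as the input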
Example 6
def compute_properties_along_centerline(fname_seg_image,
                                        property_list,
                                        fname_disks_image=None,
                                        smooth_factor=5.0,
                                        interpolation_mode=0,
                                        remove_temp_files=1,
                                        verbose=1):

    # Check list of properties
    # If diameters is in the list, compute major and minor axis length and check orientation
    compute_diameters = False
    property_list_local = list(property_list)
    if 'diameters' in property_list_local:
        compute_diameters = True
        property_list_local.remove('diameters')
        property_list_local.append('major_axis_length')
        property_list_local.append('minor_axis_length')
        property_list_local.append('orientation')

    # TODO: make sure fname_segmentation and fname_disks are in the same space
    # create temporary folder and copy data
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end(
        'tmp.' + time.strftime("%y%m%d%H%M%S") + '_' +
        str(randint(1, 1000000)), 1)
    sct.run('mkdir ' + path_tmp, verbose)

    sct.run('cp ' + fname_seg_image + ' ' + path_tmp)
    if fname_disks_image is not None:
        sct.run('cp ' + fname_disks_image + ' ' + path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    fname_segmentation = os.path.abspath(fname_seg_image)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # Change orientation of the input segmentation to RPI
    sct.printv('\nOrient segmentation to RPI orientation...', verbose)
    im_seg = Image(file_data + ext_data)
    fname_segmentation_orient = 'segmentation_rpi' + ext_data
    image = set_orientation(im_seg, 'RPI')
    image.setFileName(fname_segmentation_orient)
    image.save()

    # Initialize some variables
    nx, ny, nz, nt, px, py, pz, pt = image.dim
    resolution = 0.5
    properties = {key: [] for key in property_list_local}
    properties['incremental_length'] = []
    properties['distance_from_C1'] = []
    properties['vertebral_level'] = []
    properties['z_slice'] = []

    # compute the spinal cord centerline based on the spinal cord segmentation
    number_of_points = 5 * nz
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
        fname_segmentation_orient,
        algo_fitting='nurbs',
        verbose=verbose,
        nurbs_pts_number=number_of_points,
        all_slices=False,
        phys_coordinates=True,
        remove_outliers=True)
    centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                            x_centerline_deriv, y_centerline_deriv,
                            z_centerline_deriv)

    # Compute vertebral distribution along centerline based on position of intervertebral disks
    if fname_disks_image is not None:
        fname_disks = os.path.abspath(fname_disks_image)
        path_data, file_data, ext_data = sct.extract_fname(fname_disks)
        im_disks = Image(file_data + ext_data)
        fname_disks_orient = 'disks_rpi' + ext_data
        image_disks = set_orientation(im_disks, 'RPI')
        image_disks.setFileName(fname_disks_orient)
        image_disks.save()

        image_disks = Image(fname_disks_orient)
        coord = image_disks.getNonZeroCoordinates(sorting='z',
                                                  reverse_coord=True)
        coord_physical = []
        for c in coord:
            c_p = image_disks.transfo_pix2phys([[c.x, c.y, c.z]])[0]
            c_p.append(c.value)
            coord_physical.append(c_p)
        centerline.compute_vertebral_distribution(coord_physical)

    sct.printv('Computing spinal cord shape along the spinal cord...')
    timer_properties = sct.Timer(
        number_of_iteration=centerline.number_of_points)
    timer_properties.start()
    # Extracting patches perpendicular to the spinal cord and computing spinal cord shape
    for index in range(centerline.number_of_points):
        # value_out = -5.0
        value_out = 0.0
        current_patch = centerline.extract_perpendicular_square(
            image,
            index,
            resolution=resolution,
            interpolation_mode=interpolation_mode,
            border='constant',
            cval=value_out)

        # check for pixels close to the spinal cord segmentation that are out of the image
        from skimage.morphology import dilation
        patch_zero = np.copy(current_patch)
        patch_zero[patch_zero == value_out] = 0.0
        patch_borders = dilation(patch_zero) - patch_zero
        """
        if np.count_nonzero(patch_borders + current_patch == value_out + 1.0) != 0:
            c = image.transfo_phys2pix([centerline.points[index]])[0]
            print 'WARNING: no patch for slice', c[2]
            timer_properties.add_iteration()
            continue
        """

        sc_properties = properties2d(patch_zero, [resolution, resolution])
        if sc_properties is not None:
            properties['incremental_length'].append(
                centerline.incremental_length[index])
            if fname_disks_image is not None:
                properties['distance_from_C1'].append(
                    centerline.dist_points[index])
                properties['vertebral_level'].append(
                    centerline.l_points[index])
            properties['z_slice'].append(
                image.transfo_phys2pix([centerline.points[index]])[0][2])
            for property_name in property_list_local:
                properties[property_name].append(sc_properties[property_name])
        else:
            c = image.transfo_phys2pix([centerline.points[index]])[0]
            print 'WARNING: no properties for slice', c[2]

        timer_properties.add_iteration()
    timer_properties.stop()

    # Adding centerline to the properties for later use
    properties['centerline'] = centerline

    # We assume that the major axis is in the right-left direction.
    # This block checks the orientation of the spinal cord and swaps the axes if necessary, so that the major axis is right-left.
    if compute_diameters:
        diameter_major = properties['major_axis_length']
        diameter_minor = properties['minor_axis_length']
        orientation = properties['orientation']
        for i, orientation_item in enumerate(orientation):
            if not -45.0 < orientation_item < 45.0:
                diameter_major[i], diameter_minor[i] = diameter_minor[i], diameter_major[i]

        properties['RL_diameter'] = properties['major_axis_length']
        properties['AP_diameter'] = properties['minor_axis_length']
        del properties['major_axis_length']
        del properties['minor_axis_length']

    # smooth the spinal cord shape with a Hann window if required
    # TODO: not all properties can be smoothed
    if smooth_factor != 0.0:  # smooth_factor is in mm
        import scipy
        window = scipy.signal.hann(smooth_factor /
                                   np.mean(centerline.progressive_length))
        for property_name in property_list_local:
            properties[property_name] = scipy.signal.convolve(
                properties[property_name], window,
                mode='same') / np.sum(window)

    if compute_diameters:
        property_list_local.remove('major_axis_length')
        property_list_local.remove('minor_axis_length')
        property_list_local.append('RL_diameter')
        property_list_local.append('AP_diameter')
        property_list = property_list_local

    # Display the properties along the cord; vertebral-level ticks require the intervertebral disks
    if verbose == 2:
        import matplotlib.pyplot as plt
        x_increment = 'distance_from_C1'
        if fname_disks_image is None:
            x_increment = 'incremental_length'

        # Display the image and plot all contours found
        fig, axes = plt.subplots(len(property_list_local),
                                 sharex=True,
                                 sharey=False)
        for k, property_name in enumerate(property_list_local):
            axes[k].plot(properties[x_increment], properties[property_name])
            axes[k].set_ylabel(property_name)

        if fname_disks_image is not None:
            properties['distance_disk_from_C1'] = centerline.distance_from_C1label  # distance between each disk and C1 (or the first disk)
            xlabel_disks = [
                centerline.convert_vertlabel2disklabel[label]
                for label in properties['distance_disk_from_C1']
            ]
            xtick_disks = [
                properties['distance_disk_from_C1'][label]
                for label in properties['distance_disk_from_C1']
            ]
            plt.xticks(xtick_disks, xlabel_disks, rotation=30)
        else:
            axes[-1].set_xlabel('Position along the spinal cord (in mm)')

        plt.show()

    # Removing temporary folder
    os.chdir('..')
    shutil.rmtree(path_tmp, ignore_errors=True)

    return property_list, properties
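The smoothing window length above comes from dividing smooth_factor (in mm) by the mean spacing between centerline points. A small worked sketch of that arithmetic and of the normalized convolution, with a hypothetical point spacing and an explicit integer window length, using scipy.signal.hann as in the code above (newer scipy versions move it to scipy.signal.windows.hann):

import numpy as np
from scipy import signal

smooth_factor = 5.0   # mm, the default used above
mean_spacing = 0.5    # mm between consecutive centerline points (hypothetical value)
window_len = int(round(smooth_factor / mean_spacing))   # -> 10 samples; hann() expects an integer length
window = signal.hann(window_len)

values = np.random.rand(100)                                        # hypothetical property profile
smoothed = signal.convolve(values, window, mode='same') / np.sum(window)
print(smoothed.shape)   # (100,) -- same length as the input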
Example 7
def main(argv=None):  # pylint: disable=unused-argument

    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default():
        # Setting U-net parameters
        depth = 3

        # Make sure the image size satisfies the U-Net requirement:
        # "select the input tile size such that all 2x2 max-pooling operations are applied to a
        # layer with an even x- and y-size."
        image_size_temp = IMAGE_SIZE
        for i in range(depth):
            if image_size_temp % 2 != 0:
                sct.printv('ERROR: image size must satisfy requirements (select the input tile size such that all 2x2 '
                           'max-pooling operations are applied to a layer with an even x- and y-size.)', type='error')
            image_size_temp = image_size_temp / 2
        image_size_bottom = image_size_temp

        # Compute the size of the image segmentation, based on depth
        for i in range(depth):
            image_size_temp *= 2
        segmentation_image_size = image_size_temp
        # offset_images = (IMAGE_SIZE - segmentation_image_size) / 2

        sct.printv('Original image size = ' + str(IMAGE_SIZE))
        sct.printv('Image size at bottom layer = ' + str(image_size_bottom))
        sct.printv('Image size of output = ' + str(segmentation_image_size))

        # Extracting datasets
        list_data = extract_data(TRAINING_SOURCE_DATA)
        list_labels = extract_label(TRAINING_LABELS_DATA)
        list_test_data = extract_data(TEST_SOURCE_DATA)
        list_test_labels = extract_label(TEST_LABELS_DATA)

        # Generate a validation set
        validation_data = list_data[:VALIDATION_SIZE]
        validation_labels = list_labels[:VALIDATION_SIZE]
        list_data = list_data[VALIDATION_SIZE:]
        list_labels = list_labels[VALIDATION_SIZE:]
        num_epochs = NUM_EPOCHS
        train_size = len(list_labels)

        # This is where training samples and labels are fed to the graph.
        # These placeholder nodes will be fed a batch of training data at each
        # training step using the {feed_dict} argument to the Run() call below.
        train_data_node = tf.placeholder(tf.float32, shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
        train_labels_node = tf.placeholder(tf.float32, shape=(BATCH_SIZE * segmentation_image_size * segmentation_image_size, NUM_LABELS))
        train_labels_weights = tf.placeholder(tf.float32, shape=(BATCH_SIZE * segmentation_image_size * segmentation_image_size))
        # For the validation and test data, we'll just hold the entire dataset in one constant node.

        unet = UNetModel(IMAGE_SIZE, depth)

        # Training computation: logits + cross-entropy loss.
        logits = unet.model(train_data_node, True)
        loss = tf.reduce_mean(tf.mul(train_labels_weights, tf.nn.softmax_cross_entropy_with_logits(logits, train_labels_node)))
        tf.scalar_summary('Loss', loss)

        # Optimizer: set up a variable that's incremented once per batch and
        # controls the learning rate decay.
        batch = tf.Variable(0, trainable=False)
        # Decay once per epoch, using an exponential schedule starting at 0.001.
        learning_rate = tf.train.exponential_decay(0.001,  # Base learning rate.
                                                   batch * BATCH_SIZE,  # Current index into the dataset.
                                                   train_size,  # Decay step.
                                                   0.95,  # Decay rate.
                                                   staircase=True)
        tf.scalar_summary('Learning rate', learning_rate)

        error_rate_batch = tf.Variable(0.0, name='batch_error_rate', trainable=False)
        tf.scalar_summary('Batch error rate', error_rate_batch)

        error_rate_validation = tf.Variable(0.0, name='validation_error_rate', trainable=False)
        tf.scalar_summary('Validation error rate', error_rate_validation)

        # Use simple gradient descent for the optimization.
        optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=batch)
        #optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=batch)
        #optimizer = tf.train.MomentumOptimizer(learning_rate).minimize(loss, global_step=batch)

        # Predictions for the minibatch, validation set and test set.
        train_prediction = tf.nn.softmax(logits)

        # Create a saver to checkpoint all variables.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.merge_all_summaries()

        import multiprocessing as mp
        import time
        number_of_cores = mp.cpu_count()
        with tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement, inter_op_parallelism_threads=number_of_cores, intra_op_parallelism_threads=number_of_cores)) as s:
            # Run all the initializers to prepare the trainable parameters.
            init = tf.initialize_all_variables()
            s.run(init)
            tf.train.start_queue_runners(sess=s)
            summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=s.graph_def)

            sct.printv('\nShuffling batches!')
            number_of_step = int(num_epochs * train_size / BATCH_SIZE)
            steps = range(number_of_step)
            shuffle(steps)
            sct.printv('Initialized!')
            sct.printv('Number of step = ' + str(number_of_step))
            timer_training = sct.Timer(number_of_step)
            timer_training.start()
            # Loop through training steps.
            for i, step in enumerate(steps):
                sct.printv('Step ' + str(i) + '/' + str(len(steps)))
                sct.printv('Epoch ' + str(round(float(i) * BATCH_SIZE / train_size, 2)))
                timer_training.iterations_done(i)
                # Compute the offset of the current minibatch in the data.
                # Note that we could use better randomization across epochs.
                offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
                batch_data = extract_data(TRAINING_SOURCE_DATA, list_images=list_data[offset:(offset + BATCH_SIZE)], verbose=0)
                batch_labels, batch_labels_weights = extract_label(TRAINING_LABELS_DATA, segmentation_image_size, list_labels[offset:(offset + BATCH_SIZE)], verbose=0)
                batch_labels = numpy.reshape(batch_labels, [batch_labels.shape[0] * batch_labels.shape[1] * batch_labels.shape[2], NUM_LABELS])
                batch_labels_weights = numpy.reshape(batch_labels_weights, [batch_labels_weights.shape[0] * batch_labels_weights.shape[1] * batch_labels_weights.shape[2]])
                # This dictionary maps the batch data (as a numpy array) to the
                # node in the graph it should be fed to.
                feed_dict = {train_data_node: batch_data, train_labels_node: batch_labels, train_labels_weights: batch_labels_weights}
                # Run the graph and fetch some of the nodes.
                _, l, lr, predictions = s.run([optimizer, loss, learning_rate, train_prediction], feed_dict=feed_dict)

                assert not numpy.isnan(l), 'Model diverged with loss = NaN'

                """
                if i % 100 == 0 or (i + 1) == FLAGS.max_steps:
                    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                    saver.save(s, checkpoint_path, global_step=i)
                """

                if i != 0 and i % 50 == 0:
                    error_rate_batch_tens = error_rate_batch.assign(error_rate(predictions, batch_labels))
                    validation_data_b = extract_data(TRAINING_SOURCE_DATA, list_images=validation_data, verbose=0)
                    validation_labels_b, validation_labels_weights = extract_label(TRAINING_LABELS_DATA, segmentation_image_size, validation_labels, verbose=0)
                    validation_labels_b = numpy.reshape(validation_labels_b, [validation_labels_b.shape[0] * validation_labels_b.shape[1] * validation_labels_b.shape[2], NUM_LABELS])
                    validation_data_node = tf.constant(validation_data_b)
                    validation_prediction = tf.nn.softmax(unet.model(validation_data_node))
                    error_rate_validation_tens = error_rate_validation.assign(error_rate(validation_prediction.eval(), validation_labels_b))
                else:
                    error_rate_batch_tens = error_rate_batch.assign(error_rate_batch.eval())
                    error_rate_validation_tens = error_rate_validation.assign(error_rate_validation.eval())

                if i != 0 and i % 5 == 0:
                    result = s.run([summary_op, learning_rate, error_rate_batch_tens, error_rate_validation_tens], feed_dict=feed_dict)
                    summary_str = result[0]
                    sct.printv('Minibatch loss: %.6f, learning rate: %.6f, error batch %.3f, error validation %.3f' % (l, lr, error_rate_batch.eval(), error_rate_validation.eval()))
                    summary_writer.add_summary(summary_str, i)

                del batch_data
                del batch_labels
                del batch_labels_weights

            test_data = extract_data(TEST_SOURCE_DATA, list_images=list_test_data, verbose=0)
            test_labels, test_labels_weights = extract_label(TEST_LABELS_DATA, segmentation_image_size, list_test_labels, verbose=0)
            test_labels = numpy.reshape(test_labels, [test_labels.shape[0] * test_labels.shape[1] * test_labels.shape[2], NUM_LABELS])
            test_data_node = tf.constant(test_data)
            test_prediction = tf.nn.softmax(unet.model(test_data_node))
            # Finally print the result!
            test_error = error_rate(test_prediction.eval(), test_labels)
            sct.printv('Test error: ' + str(test_error))
            timer_training.printTotalTime()

            #savePredictions(result_test_prediction, output_path, list_test_data, segmentation_image_size)

            save_path = saver.save(s, output_path + 'model.ckpt')
            sct.printv('Model saved in file: ' + save_path)
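The size check at the top of main() enforces the U-Net tile-size rule quoted in the comment: the x/y size must stay even through every 2x2 max-pooling. A self-contained sketch of that check for depth 3, on a few hypothetical tile sizes:

def valid_tile_size(image_size, depth):
    # the size must be even before each of the `depth` 2x2 max-pooling steps
    for _ in range(depth):
        if image_size % 2 != 0:
            return False
        image_size = image_size // 2
    return True

for size in (572, 576, 580):   # hypothetical tile sizes
    print('%d -> %s' % (size, valid_tile_size(size, depth=3)))   # 572 -> False, 576 -> True, 580 -> False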