Beispiel #1
0
def second_order_motifs(file_path, model, motif):
    """Score a motif flanked by every 2-mer on both sides and append the scores to a file.

    For each of the 4**4 flank combinations the 10-character word
    ``flank[:2] + motif + flank[2:]`` is written into positions 245..254 of a
    500-long sequence that is otherwise 'N'. The sequence is one-hot encoded
    with ``make_onehot`` and scored with ``get_embeddings_low_mem`` together
    with an all-zero chromatin input of shape (1, 130). Scores of a word and
    of its reverse complement are pooled under whichever of the two was
    encountered first.

    Args:
        file_path: Path of the text file the rows are appended to.
        model: Model object forwarded to ``get_embeddings_low_mem``.
        motif: Core motif string. It must have 6 characters so that it fits
            the ``[250:256]`` slot and the flanked word fills ``[245:255]``.

    Returns:
        The saved table: one row per score, holding the k-mer (with its
        outermost base on each side masked to 'N') and the score, as built
        by ``np.transpose(np.vstack(...))``.
    """
    # All-zero chromatin input; the code uses 130 features.
    simulated_chromatin = np.zeros((1, 130))
    # Start from an all-'N' background sequence with the motif embedded.
    simulated_sequence = np.repeat('N', 500)
    simulated_sequence[250:256] = list(motif)
    # NOTE(review): the flanked word below overwrites positions 245..254 only,
    # so position 255 keeps the last motif base written by the line above.
    # Confirm whether that extra base next to the word is intended.

    bases = ['A', 'T', 'G', 'C']
    # Scores keyed by k-mer; forward and reverse-complement words share a key.
    kmer_dict = defaultdict(list)
    for flank in itertools.product(bases, repeat=4):
        word = flank[:2] + tuple(motif) + flank[2:]
        kmer = ''.join(word)
        # Embed the flanked word into the background sequence.
        simulated_sequence[245:255] = word
        simulated_sequence = list(simulated_sequence)
        sequence_onehot = make_onehot(simulated_sequence, 500)
        simulated_input = (sequence_onehot, simulated_chromatin)
        scores = get_embeddings_low_mem(model, simulated_input)
        sequence_score = scores[0][0]
        # File the score under the reverse complement if that was seen before.
        rc_kmer = reverse_complement(kmer)
        key = rc_kmer if rc_kmer in kmer_dict else kmer
        kmer_dict[key].append(sequence_score)

    # Aggregate by first-order k-mer: mask the outermost base on each side.
    k = []
    score_list = []
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for kmer, val in kmer_dict.items():
        seq = list(kmer)
        seq[0] = 'N'
        seq[9] = 'N'
        masked_kmer = ''.join(seq)
        for score in val:
            k.append(masked_kmer)
            score_list.append(score)

    # Append the table to the output file; the handle is closed even on error.
    dat = np.transpose(np.vstack((k, score_list)))
    with open(file_path, "a") as fp:
        np.savetxt(fp, dat, fmt='%s')
    return dat
Beispiel #2
0
def motifs_in_ns(file_path, model, motif):
    """Score a motif flanked by one base on each side, embedded in an all-'N' sequence.

    For each of the 4**2 flank combinations the 8-character word
    ``flank[0] + motif + flank[1]`` is written into positions 246..253 of a
    500-long sequence that is otherwise 'N'. All sequences are one-hot encoded
    with ``make_onehot`` and scored in a single ``get_embeddings_low_mem``
    call together with an all-zero chromatin input. Scores of a word and of
    its reverse complement are pooled under whichever was encountered first,
    and the mean score per k-mer is appended to ``file_path``.

    Args:
        file_path: Path of the text file the (k-mer, mean score) rows are
            appended to.
        model: Model object forwarded to ``get_embeddings_low_mem``.
        motif: Core motif string. It must have 6 characters so that it fits
            the ``[250:256]`` slot and the flanked word fills ``[246:254]``.
    """
    # Start from an all-'N' background sequence with the motif embedded.
    simulated_sequence = np.repeat('N', 500)
    simulated_sequence[250:256] = list(motif)
    # NOTE(review): the flanked word below overwrites positions 246..253 only,
    # so positions 254 and 255 keep the last two motif bases written by the
    # line above. Confirm whether those extra bases are intended.

    # Build every word made of the motif with a single base on either side.
    bases = ['A', 'T', 'G', 'C']
    seq_list = []
    kmer_list = []
    for flank in itertools.product(bases, repeat=2):
        word = flank[:1] + tuple(motif) + flank[1:]
        simulated_sequence[246:254] = word
        simulated_sequence = list(simulated_sequence)
        seq_list.append(make_onehot(simulated_sequence, 500))
        kmer_list.append(''.join(word))

    # Score all sequences in one batch; the batch size follows the data.
    n_inputs = len(seq_list)
    X = np.reshape(seq_list, (n_inputs, 500, 4))
    C = np.zeros(shape=(n_inputs, 130))
    scores = get_embeddings_low_mem(model, (X, C))

    # Pool forward and reverse-complement words under a single key.
    kmer_dict = defaultdict(list)
    for kmer, score in zip(kmer_list, scores):
        rc_kmer = reverse_complement(kmer)
        key = rc_kmer if rc_kmer in kmer_dict else kmer
        kmer_dict[key].append(score[0])

    # Reduce each k-mer to its mean score.
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    k = []
    score_list = []
    for kmer, vals in kmer_dict.items():
        k.append(kmer)
        score_list.append(np.mean(vals))

    # Append the table to the output file; the handle is closed even on error.
    dat = np.transpose(np.vstack((k, score_list)))
    with open(file_path, "a") as fp:
        np.savetxt(fp, dat, fmt='%s')
Beispiel #3
0
def first_order_motifs(file_path, model, motif, num):
    """Score a motif flanked by one base on each side, embedded in random sequences.

    For each of the 4**2 flank combinations the word
    ``flank[0] + motif + flank[1]`` is written into positions 246..253 of
    every sequence produced by ``random_sequences(num)``. All sequences are
    one-hot encoded with ``make_onehot`` and scored in a single
    ``get_embeddings_low_mem`` call together with an all-zero chromatin
    input. Scores of a word and of its reverse complement are pooled under
    whichever was encountered first, and one (k-mer, score) row per scored
    sequence is appended to ``file_path``.

    Args:
        file_path: Path of the text file the rows are appended to.
        model: Model object forwarded to ``get_embeddings_low_mem``.
        motif: Core motif string; the flanked word is assigned to an
            8-position slice, so a 6-character motif is expected.
        num: Argument forwarded to ``random_sequences`` -- presumably the
            number of background sequences generated per flank (TODO confirm).
    """
    bases = ['A', 'T', 'G', 'C']
    seq_list = []
    kmer_list = []
    for flank in itertools.product(bases, repeat=2):
        word = flank[:1] + tuple(motif) + flank[1:]
        # The k-mer is the string version of the word.
        kmer = ''.join(word)
        # assumes random_sequences yields mutable 500-long sequences that
        # support slice assignment -- verify against its definition
        for sequence in random_sequences(num):
            sequence[246:254] = word
            seq_list.append(make_onehot(list(sequence), 500))
            kmer_list.append(kmer)

    # Score everything in one batch; the batch size follows the number of
    # sequences actually collected rather than a hard-coded 16 * num.
    n_inputs = len(seq_list)
    X = np.reshape(seq_list, (n_inputs, 500, 4))
    C = np.zeros(shape=(n_inputs, 130))
    scores = get_embeddings_low_mem(model, (X, C))

    # Pool forward and reverse-complement words under a single key.
    kmer_dict = defaultdict(list)
    for kmer, score in zip(kmer_list, scores):
        rc_kmer = reverse_complement(kmer)
        key = rc_kmer if rc_kmer in kmer_dict else kmer
        kmer_dict[key].append(score[0])

    # Flatten to parallel lists: one row per individual score.
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    k = []
    score_list = []
    for kmer, vals in kmer_dict.items():
        for score in vals:
            k.append(kmer)
            score_list.append(score)

    # Append the table to the output file; the handle is closed even on error.
    dat = np.transpose(np.vstack((k, score_list)))
    with open(file_path, "a") as fp:
        np.savetxt(fp, dat, fmt='%s')
Beispiel #4
0
def modify_image_and_label(image, label, atlas, slice_thickness_this_subject):
    """Resample an image/label pair along the first axis and fit it to the atlas size.

    Both volumes are rescaled by ``slice_thickness_this_subject / 0.7`` along
    the first axis (the other two axes use a factor of 1.0) and are then
    cropped or padded along x to ``atlas.shape[0]``.

    Args:
        image: Image volume, rescaled as a single-channel array.
        label: Label volume; converted to one-hot with ``exp_config.nlabels``
            classes before it is rescaled.
        atlas: Only ``atlas.shape[0]`` is read, as the target size along x.
        slice_thickness_this_subject: Slice thickness of this subject, in the
            same unit as the 0.7 reference value.

    Returns:
        Tuple ``(image, label)`` of the rescaled and cropped/padded volumes,
        cast to ``np.float32`` and ``np.uint8`` respectively.
    """
    # ======================
    # rescale in 3d
    # ======================
    # First axis: the resolution was kept unchanged during the initial 2D data
    # preprocessing, whereas the atlas (made from hcp labels) has 0.7mm slice
    # thickness. Other two axes: already at the required resolution.
    zoom_factors = [slice_thickness_this_subject / 0.7, 1.0, 1.0]
    interpolation = dict(order=1, preserve_range=True, mode='constant')

    resampled_image = rescale(image,
                              zoom_factors,
                              multichannel=False,
                              **interpolation)

    # The label is interpolated per class channel in one-hot form and turned
    # back into class indices with an argmax over the channel axis.
    onehot = utils.make_onehot(label, exp_config.nlabels)
    resampled_onehot = rescale(onehot,
                               zoom_factors,
                               multichannel=True,
                               **interpolation)
    resampled_label = np.argmax(resampled_onehot, axis=-1)

    # =================
    # crop / pad
    # =================
    fitted_image = utils.crop_or_pad_volume_to_size_along_x(
        resampled_image, atlas.shape[0]).astype(np.float32)
    fitted_label = utils.crop_or_pad_volume_to_size_along_x(
        resampled_label, atlas.shape[0]).astype(np.uint8)

    return fitted_image, fitted_label
Beispiel #5
0
def run_training(log_dir,
                 image,
                 label,
                 atlas,
                 continue_run,
                 log_dir_first_TD_subject=''):

    # ============================
    # down sample the atlas - the losses will be evaluated in the downsampled space
    # ============================
    atlas_downsampled = rescale(atlas, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                order=1,
                                preserve_range=True,
                                multichannel=True,
                                mode='constant')
    atlas_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        atlas_downsampled, int(256 / exp_config.downsampling_factor_x))

    label_onehot = utils.make_onehot(label, exp_config.nlabels)
    label_onehot_downsampled = rescale(label_onehot, [
        1 / exp_config.downsampling_factor_x, 1 /
        exp_config.downsampling_factor_y, 1 / exp_config.downsampling_factor_z
    ],
                                       order=1,
                                       preserve_range=True,
                                       multichannel=True,
                                       mode='constant')
    label_onehot_downsampled = utils.crop_or_pad_volume_to_size_along_x_1hot(
        label_onehot_downsampled, int(256 / exp_config.downsampling_factor_x))

    # ============================
    # Initialize step number - this is number of mini-batch runs
    # ============================
    init_step = 0

    # ============================
    # if continue_run is set to True, load the model parameters saved earlier
    # else start training from scratch
    # ============================
    if continue_run:
        logging.info(
            '============================================================')
        logging.info('Continuing previous run')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'models/model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1])
            logging.info('Latest step was: %d' % init_step)
        except:
            logging.warning(
                'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '============================================================')

    # ================================================================
    # reset the graph built so far and build a new TF graph
    # ================================================================
    tf.reset_default_graph()
    with tf.Graph().as_default():

        # ============================
        # set random seed for reproducibility
        # ============================
        tf.random.set_random_seed(exp_config.run_number)
        np.random.seed(exp_config.run_number)

        # ================================================================
        # create placeholders - segmentation net
        # ================================================================
        images_pl = tf.placeholder(tf.float32,
                                   shape=[exp_config.batch_size] +
                                   list(exp_config.image_size) + [1],
                                   name='images')
        learning_rate_pl = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='learning_rate')
        training_pl = tf.placeholder(tf.bool,
                                     shape=[],
                                     name='training_or_testing')

        # ================================================================
        # insert a normalization module in front of the segmentation network
        # the normalization module is trained for each test image
        # ================================================================
        images_normalized, added_residual = model.normalize(
            images_pl, exp_config, training_pl)

        # ================================================================
        # build the graph that computes predictions from the inference model
        # By setting the 'training_pl' to false directly, the update ops for the moments in the BN layer are not created at all.
        # This allows grouping the update ops together with the optimizer training, while training the normalizer - in case the normalizer has BN.
        # ================================================================
        predicted_seg_logits, predicted_seg_softmax, predicted_seg = model.predict_i2l(
            images_normalized,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # 3d prior
        # ================================================================
        labels_3d_1hot_shape = [1] + list(
            exp_config.image_size_downsampled) + [exp_config.nlabels]
        # predict the current segmentation for the entire volume, downsample it and pass it through this placeholder
        predicted_seg_1hot_3d_pl = tf.placeholder(tf.float32,
                                                  shape=labels_3d_1hot_shape,
                                                  name='predicted_labels_3d')

        # denoise the noisy segmentation
        _, predicted_seg_softmax_3d_noisy_autoencoded_softmax, _ = model.predict_l2l(
            predicted_seg_1hot_3d_pl,
            exp_config,
            training_pl=tf.constant(False, dtype=tf.bool))

        # ================================================================
        # divide the vars into segmentation network and normalization network
        # ================================================================
        i2l_vars = []
        l2l_vars = []
        normalization_vars = []

        for v in tf.global_variables():
            var_name = v.name
            if 'image_normalizer' in var_name:
                normalization_vars.append(v)
                i2l_vars.append(
                    v
                )  # the normalization vars also need to be restored from the pre-trained i2l mapper
            elif 'i2l_mapper' in var_name:
                i2l_vars.append(v)
            elif 'l2l_mapper' in var_name:
                l2l_vars.append(v)

        # ================================================================
        # Make a list of trainable i2l vars. This will be used to compute gradients.
        # The other list contains trainable as well as non-trainable parameters. This is required for saving and loading all parameters, but runs into trouble when gradients are asked to be computed for non-trainable parameters
        # ================================================================
        i2l_vars_trainable = []
        for v in i2l_vars:
            if v.trainable is True:
                i2l_vars_trainable.append(v)

        # ================================================================
        # add ops for calculation of the prior loss - wrt an atlas or the outputs of the DAE
        # ================================================================
        prior_label_1hot_pl = tf.placeholder(
            tf.float32,
            shape=[exp_config.batch_size_downsampled] + list(
                (exp_config.image_size_downsampled[1],
                 exp_config.image_size_downsampled[2])) + [exp_config.nlabels],
            name='labels_prior')

        # down sample the predicted logits
        predicted_seg_logits_expanded = tf.expand_dims(predicted_seg_logits,
                                                       axis=0)
        # the 'upsample' function will actually downsample the predictions, as the scaling factors have been set appropriately
        predicted_seg_logits_downsampled = layers.bilinear_upsample3D_(
            predicted_seg_logits_expanded,
            name='downsampled_predictions',
            factor_x=1 / exp_config.downsampling_factor_x,
            factor_y=1 / exp_config.downsampling_factor_y,
            factor_z=1 / exp_config.downsampling_factor_z)
        predicted_seg_logits_downsampled = tf.squeeze(
            predicted_seg_logits_downsampled
        )  # the first axis was added only for the downsampling in 3d

        # compute the dice between the predictions and the prior in the downsampled space
        loss_op = model.loss(logits=predicted_seg_logits_downsampled,
                             labels=prior_label_1hot_pl,
                             nlabels=exp_config.nlabels,
                             loss_type=exp_config.loss_type_prior,
                             mask_for_loss_within_mask=None,
                             are_labels_1hot=True)

        tf.summary.scalar('tr_losses/loss', loss_op)

        # ================================================================
        # one of the two prior losses will be used in the following manner:
        # the atlas prior will be used when the current prediction is deemed to be very far away from a reasonable solution
        # once a reasonable solution is reached, the dae prior will be used.
        # these 3d computations will be done outside the graph and will be passed via placeholders for logging in tensorboard
        # ================================================================
        lambda_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='lambda_prior_atlas')
        lambda_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='lambda_prior_dae')
        tf.summary.scalar('lambdas/prior_atlas', lambda_prior_atlas_pl)
        tf.summary.scalar('lambdas/prior_dae', lambda_prior_dae_pl)

        dice3d_prior_atlas_pl = tf.placeholder(tf.float32,
                                               shape=[],
                                               name='dice3d_prior_atlas')
        dice3d_prior_dae_pl = tf.placeholder(tf.float32,
                                             shape=[],
                                             name='dice3d_prior_dae')
        dice3d_gt_pl = tf.placeholder(tf.float32, shape=[], name='dice3d_gt')
        tf.summary.scalar('dice3d/prior_atlas', dice3d_prior_atlas_pl)
        tf.summary.scalar('dice3d/prior_dae', dice3d_prior_dae_pl)
        tf.summary.scalar('dice3d/gt', dice3d_gt_pl)

        # ================================================================
        # add optimization ops
        # ================================================================
        if exp_config.debug: print('creating training op...')

        # create an instance of the required optimizer
        optimizer = exp_config.optimizer_handle(learning_rate=learning_rate_pl)

        # initialize variable holding the accumlated gradients and create a zero-initialisation op
        accumulated_gradients = [
            tf.Variable(tf.zeros_like(var.initialized_value()),
                        trainable=False) for var in i2l_vars_trainable
        ]

        # accumulated gradients init op
        accumulated_gradients_zero_op = [
            ac.assign(tf.zeros_like(ac)) for ac in accumulated_gradients
        ]

        # calculate gradients and define accumulation op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            gradients = optimizer.compute_gradients(
                loss_op, var_list=i2l_vars_trainable)
            # compute_gradients return a list of (gradient, variable) pairs.
        accumulate_gradients_op = [
            ac.assign_add(gg[0])
            for ac, gg in zip(accumulated_gradients, gradients)
        ]

        # define the gradient mean op
        num_accumulation_steps_pl = tf.placeholder(
            dtype=tf.float32, name='num_accumulation_steps')
        accumulated_gradients_mean_op = [
            ag.assign(tf.divide(ag, num_accumulation_steps_pl))
            for ag in accumulated_gradients
        ]

        # reassemble the gradients in the [value, var] format and do define train op
        final_gradients = [(ag, gg[1])
                           for ag, gg in zip(accumulated_gradients, gradients)]
        train_op = optimizer.apply_gradients(final_gradients)

        # ================================================================
        # sequence of running opt ops:
        # 1. at the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
        # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
        # 3. at the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
        # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
        # ================================================================
        # ================================================================
        # previous train_op without accumulation of gradients
        # ================================================================
        # train_op = model.training_step(loss_op, normalization_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars = True)

        # ================================================================
        # build the summary Tensor based on the TF collection of Summaries.
        # ================================================================
        if exp_config.debug: print('creating summary op...')

        # ================================================================
        # add init ops
        # ================================================================
        init_ops = tf.global_variables_initializer()

        # ================================================================
        # find if any vars are uninitialized
        # ================================================================
        if exp_config.debug:
            logging.info(
                'Adding the op to get a list of initialized variables...')
        uninit_vars = tf.report_uninitialized_variables()

        # ================================================================
        # create session
        # ================================================================
        sess = tf.Session()

        # ================================================================
        # create a summary writer
        # ================================================================
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # ================================================================
        # summaries of the training errors
        # ================================================================
        prior_dae_dice = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='prior_dae_dice')
        prior_dae_dice_summary = tf.summary.scalar('test_img/prior_dae_dice',
                                                   prior_dae_dice)
        prior_dae_output_dice_wrt_gt = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_output_dice_wrt_gt')
        prior_dae_output_dice_wrt_gt_summary = tf.summary.scalar(
            'test_img/prior_dae_output_dice_wrt_gt',
            prior_dae_output_dice_wrt_gt)
        prior_atlas_dice = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='prior_atlas_dice')
        prior_atlas_dice_summary = tf.summary.scalar(
            'test_img/prior_atlas_dice', prior_atlas_dice)
        prior_dae_atlas_dice_ratio = tf.placeholder(
            tf.float32, shape=[], name='prior_dae_atlas_dice_ratio')
        prior_dae_atlas_dice_ratio_summary = tf.summary.scalar(
            'test_img/prior_dae_atlas_dice_ratio', prior_dae_atlas_dice_ratio)
        prior_dice = tf.placeholder(tf.float32, shape=[], name='prior_dice')
        prior_dice_summary = tf.summary.scalar('test_img/prior_dice',
                                               prior_dice)
        gt_dice = tf.placeholder(tf.float32, shape=[], name='gt_dice')
        gt_dice_summary = tf.summary.scalar('test_img/gt_dice', gt_dice)

        # ================================================================
        # create savers
        # ================================================================
        saver_i2l = tf.train.Saver(var_list=i2l_vars)
        saver_l2l = tf.train.Saver(var_list=l2l_vars)
        saver_test_data = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)
        saver_best_loss = tf.train.Saver(var_list=i2l_vars, max_to_keep=3)

        # ================================================================
        # add operations to compute dice between two 3d volumes
        # ================================================================
        pred_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='pred_3d')
        labl_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='labl_3d')
        atls_3d_1hot_pl = tf.placeholder(
            tf.float32,
            shape=list(exp_config.image_size_downsampled) +
            [exp_config.nlabels],
            name='atls_3d')

        dice_3d_op_dae = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=labl_3d_1hot_pl)
        dice_3d_op_atlas = losses.compute_dice_3d_without_batch_axis(
            prediction=pred_3d_1hot_pl, labels=atls_3d_1hot_pl)

        # ================================================================
        # freeze the graph before execution
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('Freezing the graph now!')
        tf.get_default_graph().finalize()

        # ================================================================
        # Run the Op to initialize the variables.
        # ================================================================
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('initializing all variables...')
        sess.run(init_ops)

        # ================================================================
        # print names of uninitialized variables
        # ================================================================
        uninit_variables = sess.run(uninit_vars)
        if exp_config.debug:
            logging.info(
                '============================================================')
            logging.info('This is the list of uninitialized variables:')
            for v in uninit_variables:
                print(v)

        # ================================================================
        # Restore the segmentation network parameters and the pre-trained i2i mapper parameters
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_i2l.restore(sess, checkpoint_path)

        # ================================================================
        # Restore the prior network parameters
        # ================================================================
        logging.info(
            '============================================================')
        path_to_model = sys_config.log_root + 'l2l_mapper/' + exp_config.expname_l2l + '/models/'
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            path_to_model, 'best_dice.ckpt')
        logging.info('Restoring the trained parameters from %s...' %
                     checkpoint_path)
        saver_l2l.restore(sess, checkpoint_path)

        # ================================================================
        # After the adaptation for the 1st TD subject is done, start the adaptation for the subsequent subjects with those parameters
        # ================================================================
        if log_dir_first_TD_subject is not '':
            logging.info(
                '============================================================')
            path_to_model = log_dir_first_TD_subject + '/models/'
            checkpoint_path = utils.get_latest_model_checkpoint_path(
                path_to_model, 'best_score.ckpt')
            logging.info('Restoring the trained parameters from %s...' %
                         checkpoint_path)
            saver_test_data.restore(sess, checkpoint_path)
            max_steps_tta = exp_config.max_steps
        else:
            max_steps_tta = 5 * exp_config.max_steps  # run the adaptation of the 1st TD subject for longer

        # ================================================================
        # continue run from a saved checkpoint
        # ================================================================
        if continue_run:
            # Restore session
            logging.info(
                '============================================================')
            logging.info('Restroring normalization module from: %s' %
                         init_checkpoint_path)
            saver_test_data.restore(sess, init_checkpoint_path)

        # ================================================================
        # run training epochs
        # ================================================================
        step = init_step
        best_score = 0.0

        while (step < max_steps_tta):

            # ================================================
            # After every some epochs,
            # Get the prediction for the entire volume, evaluate it using the DAE.
            # Now, decide whether to use the DAE output or the atlas as the ground truth for the next update.
            # ================================================
            if (step == init_step) or (step % exp_config.check_ood_frequency is
                                       0):

                # ==================
                # 1. compute the current 3d segmentation prediction
                # ==================
                y_pred_soft = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred_soft.append(
                        sess.run(predicted_seg_softmax,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred_soft = np.squeeze(np.array(y_pred_soft)).astype(float)
                y_pred_soft = np.reshape(y_pred_soft, [
                    -1, y_pred_soft.shape[2], y_pred_soft.shape[3],
                    y_pred_soft.shape[4]
                ])

                # ==================
                # 2. downsample it. Let's call this guy 'A'
                # ==================
                y_pred_soft_downsampled = rescale(y_pred_soft, [
                    1 / exp_config.downsampling_factor_x,
                    1 / exp_config.downsampling_factor_y,
                    1 / exp_config.downsampling_factor_z
                ],
                                                  order=1,
                                                  preserve_range=True,
                                                  multichannel=True,
                                                  mode='constant').astype(
                                                      np.float32)

                # ==================
                # 3. pass the downsampled prediction through the DAE and get its output 'B'
                # ==================
                feed_dict = {
                    predicted_seg_1hot_3d_pl:
                    np.expand_dims(y_pred_soft_downsampled, axis=0)
                }
                y_pred_noisy_denoised_softmax = np.squeeze(
                    sess.run(
                        predicted_seg_softmax_3d_noisy_autoencoded_softmax,
                        feed_dict=feed_dict)).astype(np.float16)
                y_pred_noisy_denoised = np.argmax(
                    y_pred_noisy_denoised_softmax, axis=-1)

                # ==================
                # 4. compute the dice between:
                #       a. 'A' (seg network prediction downsampled) and 'B' (dae network output)
                #       b. 'B' (dae network output) and downsampled gt labels (for debugging, to see if the dae output is close to the gt.)
                #       c. 'A' (seg network prediction downsampled) and 'C' (downsampled atlas)
                # ==================
                dAB = sess.run(dice_3d_op_dae,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   labl_3d_1hot_pl:
                                   y_pred_noisy_denoised_softmax
                               })
                dBgt = sess.run(dice_3d_op_dae,
                                feed_dict={
                                    pred_3d_1hot_pl:
                                    y_pred_noisy_denoised_softmax,
                                    labl_3d_1hot_pl: label_onehot_downsampled
                                })
                dAC = sess.run(dice_3d_op_atlas,
                               feed_dict={
                                   pred_3d_1hot_pl: y_pred_soft_downsampled,
                                   atls_3d_1hot_pl: atlas_downsampled
                               })

                # ==================
                # 5. compute the ratio dice(AB) / dice(AC). pass the ratio through a threshold and decide whether to use the DAE or the atlas as the prior
                # ==================
                ratio_dice = dAB / (dAC + 1e-5)

                if exp_config.use_gt_for_tta is True:
                    target_labels_for_this_epoch = label_onehot_downsampled
                    prr = dBgt
                elif (ratio_dice > exp_config.dae_atlas_ratio_threshold) and (
                        dAC > exp_config.min_atlas_dice):
                    target_labels_for_this_epoch = y_pred_noisy_denoised_softmax
                    prr = dAB
                else:
                    target_labels_for_this_epoch = atlas_downsampled
                    prr = dAC

                # ==================
                # update losses on tensorboard
                # ==================
                summary_writer.add_summary(
                    sess.run(prior_dae_dice_summary,
                             feed_dict={prior_dae_dice: dAB}), step)
                summary_writer.add_summary(
                    sess.run(prior_dae_output_dice_wrt_gt_summary,
                             feed_dict={prior_dae_output_dice_wrt_gt: dBgt}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_atlas_dice_summary,
                             feed_dict={prior_atlas_dice: dAC}), step)
                summary_writer.add_summary(
                    sess.run(
                        prior_dae_atlas_dice_ratio_summary,
                        feed_dict={prior_dae_atlas_dice_ratio: ratio_dice}),
                    step)
                summary_writer.add_summary(
                    sess.run(prior_dice_summary, feed_dict={prior_dice: prr}),
                    step)

                # ==================
                # save best model so far
                # ==================
                if best_score < prr:
                    best_score = prr
                    best_file = os.path.join(log_dir, 'models/best_score.ckpt')
                    saver_best_loss.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best score (%f) at step %d -  Saving model.'
                        % (best_score, step))

                # ==================
                # dice wrt gt
                # ==================
                y_pred = []
                for batch in iterate_minibatches_images(
                        image, batch_size=exp_config.batch_size):
                    y_pred.append(
                        sess.run(predicted_seg,
                                 feed_dict={
                                     images_pl: batch,
                                     training_pl: False
                                 }))
                y_pred = np.squeeze(np.array(y_pred)).astype(float)
                y_pred = np.reshape(y_pred,
                                    [-1, y_pred.shape[2], y_pred.shape[3]])
                dice_wrt_gt = met.f1_score(label.flatten(),
                                           y_pred.flatten(),
                                           average=None)
                summary_writer.add_summary(
                    sess.run(gt_dice_summary,
                             feed_dict={gt_dice: np.mean(dice_wrt_gt[1:])}),
                    step)

                # ==================
                # visualize results
                # ==================
                if step % exp_config.vis_frequency is 0:

                    # ===========================
                    # save checkpoint
                    # ===========================
                    logging.info(
                        '=============== Saving checkkpoint at step %d ... ' %
                        step)
                    checkpoint_file = os.path.join(log_dir,
                                                   'models/model.ckpt')
                    saver_test_data.save(sess,
                                         checkpoint_file,
                                         global_step=step)

                    y_pred_noisy_denoised_upscaled = utils.crop_or_pad_volume_to_size_along_x(
                        rescale(y_pred_noisy_denoised, [
                            exp_config.downsampling_factor_x,
                            exp_config.downsampling_factor_y,
                            exp_config.downsampling_factor_z
                        ],
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant'),
                        image.shape[0]).astype(np.uint8)

                    x_norm = []
                    for batch in iterate_minibatches_images(
                            image, batch_size=exp_config.batch_size):
                        x = batch
                        x_norm.append(
                            sess.run(images_normalized,
                                     feed_dict={
                                         images_pl: x,
                                         training_pl: False
                                     }))
                    x_norm = np.squeeze(np.array(x_norm)).astype(float)
                    x_norm = np.reshape(x_norm,
                                        [-1, x_norm.shape[2], x_norm.shape[3]])

                    utils_vis.save_sample_results(
                        x=image,
                        x_norm=x_norm,
                        x_diff=x_norm - image,
                        y=y_pred,
                        y_pred_dae=y_pred_noisy_denoised_upscaled,
                        at=np.argmax(atlas, axis=-1),
                        gt=label,
                        savepath=log_dir + '/results/visualize_images/step' +
                        str(step) + '.png')

            # ================================================
            # Part of training ops sequence:
            # 1. At the start of each epoch, run accumulated_gradients_zero_op (no need to provide values for any placeholders)
            # ================================================
            sess.run(accumulated_gradients_zero_op)
            num_accumulation_steps = 0

            # ================================================
            # batches
            # ================================================
            for batch in iterate_minibatches_images_and_downsampled_labels(
                    images=image,
                    batch_size=exp_config.batch_size,
                    labels_downsampled=target_labels_for_this_epoch,
                    batch_size_downsampled=exp_config.batch_size_downsampled):

                x, y = batch

                # ===========================
                # define feed dict for this iteration
                # ===========================
                feed_dict = {
                    images_pl: x,
                    prior_label_1hot_pl: y,
                    learning_rate_pl: exp_config.learning_rate,
                    training_pl: True
                }

                # ================================================
                # Part of training ops sequence:
                # 2. in each training iteration, run accumulate_gradients_op with regular feed dict of inputs and outputs
                # ================================================
                sess.run(accumulate_gradients_op, feed_dict=feed_dict)
                num_accumulation_steps = num_accumulation_steps + 1

                step += 1

            # ================================================
            # Part of training ops sequence:
            # 3. At the end of the epoch (after all batches of the volume have been passed), run accumulated_gradients_mean_op, with a value for the placeholder num_accumulation_steps_pl
            # ================================================
            sess.run(
                accumulated_gradients_mean_op,
                feed_dict={num_accumulation_steps_pl: num_accumulation_steps})

            # ================================================================
            # sequence of running opt ops:
            # 4. finally, run the train_op. this also requires input output placeholders, as compute_gradients will be called again, but the returned gradient values will be replaced by the mean gradients.
            # ================================================================
            sess.run(train_op, feed_dict=feed_dict)

        # ================================================================
        # ================================================================
        sess.close()

    # ================================================================
    # ================================================================
    gc.collect()

    return 0
Beispiel #6
0
# stochastic gradient descent
# Mini-batch training loop: every outer iteration reshuffles the training set
# and walks through it in chunks of `batchsize` samples.
batchsize = 100
trainloss = []
validloss = []
snapshot = []

for i in range(n_iter):
    # Before each pass, draw a fresh permutation of the sample indices
    # (the purpose is to shuffle the data).
    idxs = np.random.permutation(trainX.shape[0])

    for j in range(0, trainX.shape[0], batchsize):
        batchX = trainX[idxs[j:j + batchsize]]

        # Task 3: implement the make_onehot function in utils.
        # NOTE(review): the literal 10 is presumably the number of classes - confirm.

        batchy = utils.make_onehot(trainy[idxs[j:j + batchsize]], 10)
        # Feed-forward of the data and back-propagation of the error are the
        # two directions of data flow in an artificial neural network. The two
        # methods below are named forward and backward to stay consistent with
        # how other large machine-learning frameworks name these operations.

        y_hat = model.forward(batchX)

        # Task 4: understand the implementation of cross_entropy in utils.

        loss1 = utils.cross_entropy(y_hat, batchy)
        trainloss.append(loss1)
        # Difference between prediction and one-hot target, handed to backward().
        error = y_hat - batchy
        model.backward(error)

        # Evaluate model performance.
        loss2 = utils.cross_entropy(model.forward(validX),
Beispiel #7
0
def conversion(args, net, device='cuda'):
    """Run voice conversion between every pair of speakers' test samples.

    For each sub-directory of ``args.data_dir`` (one per speaker) the first
    ``test/*.npz`` file returned by ``glob`` is used as that speaker's sample.
    Every ordered (source, target) pair of samples, including a sample with
    itself, is passed through ``net`` and three ``.npy`` files are written to
    ``args.out_dir``: the source mel, the target mel and the converted mel.

    Args:
        args: namespace providing ``data_dir``, ``out_dir`` and ``model``
            (``'autovc'`` or ``'autovc-f0'``).
        net: the conversion network; called with
            ``(mel, src_emb, trg_emb)`` for ``'autovc'`` and
            ``(mel, quantized_f0, src_emb, trg_emb)`` for ``'autovc-f0'``.
        device: torch device the input tensors are moved to.

    Raises:
        AssertionError: if ``args.data_dir`` is not a directory.
        ValueError: if ``args.model`` is not a supported model name.
        IndexError: if a speaker directory has no ``test/*.npz`` file.
    """
    assert os.path.isdir(args.data_dir), 'Cannot found data dir : {}'.format(
        args.data_dir)

    # Fail fast on an unknown model. Previously this was only printed inside
    # the loop, after which the code went on to use an undefined (or stale,
    # from the previous pair) network output.
    if args.model not in ('autovc', 'autovc-f0'):
        raise ValueError("Wrong model name : {}".format(args.model))

    all_spk_path = [
        p for p in glob.glob(os.path.join(args.data_dir, '*'))
        if os.path.isdir(p)
    ]
    all_test_samples = [
        glob.glob(os.path.join(p, 'test', '*.npz'))[0] for p in all_spk_path
    ]
    os.makedirs(args.out_dir, exist_ok=True)

    all_pair = itertools.product(all_test_samples, all_test_samples)
    for src, trg in tqdm(all_pair, desc="converting voices"):
        # Layout is <speaker>/test/<utterance>.npz, so the speaker name is the
        # directory two levels above the file.
        src_name = os.path.basename(os.path.dirname(os.path.dirname(src)))
        trg_name = os.path.basename(os.path.dirname(os.path.dirname(trg)))
        src_stem = os.path.splitext(os.path.basename(src))[0]
        trg_stem = os.path.splitext(os.path.basename(trg))[0]

        src_npz = np.load(src)
        trg_npz = np.load(trg)

        x = src_npz['mel']
        p = src_npz['f0'][:, np.newaxis]
        emb_src_np = make_onehot(src_npz['spk_label'].item(),
                                 hparams.n_speakers)
        emb_trg_np = make_onehot(trg_npz['spk_label'].item(),
                                 hparams.n_speakers)

        x_padded, _ = pad_seq(x, base=hparams.freq, constant_values=None)
        # pad_len is taken from the f0 padding, as before.
        # NOTE(review): presumably identical to the mel padding length since
        # both are padded to the same base - confirm.
        p_padded, pad_len = pad_seq(p,
                                    base=hparams.freq,
                                    constant_values=-1e10)

        quantized_p, _ = quantize_f0_numpy(p_padded[:, 0],
                                           num_bins=hparams.pitch_bin)

        x_src = torch.from_numpy(x_padded).unsqueeze(0).to(device)
        p_src = torch.from_numpy(quantized_p).unsqueeze(0).to(device)
        emb_src = torch.from_numpy(emb_src_np).unsqueeze(0).to(device)
        emb_trg = torch.from_numpy(emb_trg_np).unsqueeze(0).to(device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            if args.model == 'autovc':
                _, out_psnt, _ = net(x_src, emb_src, emb_trg)
            else:  # 'autovc-f0', guaranteed by the validation above
                _, out_psnt, _ = net(x_src, p_src, emb_src, emb_trg)

        out_mel = out_psnt.squeeze().detach().cpu().numpy()
        if pad_len != 0:
            # Drop the frames that were only added as padding.
            out_mel = out_mel[:-pad_len, :]

        np.save(
            os.path.join(args.out_dir,
                         '{}-{}-feats.npy'.format(src_name, src_stem)),
            src_npz['mel'])
        np.save(
            os.path.join(args.out_dir,
                         '{}-{}-feats.npy'.format(trg_name, trg_stem)),
            trg_npz['mel'])
        np.save(
            os.path.join(
                args.out_dir,
                '{}-to-{}-{}.npy'.format(src_name, trg_name, src_stem)),
            out_mel)
Beispiel #8
0
def plot_multiplicity(model, motif, outfile, no_of_repeats):
    """Plot model scores as 0 to 5 copies of `motif` are embedded in sequences.

    For each of the `no_of_repeats` sequences returned by
    ``random_sequences`` the model is scored on the unmodified sequence and
    then again after each of five successive motif insertions. The resulting
    (no_of_repeats, 6) score table is drawn as a box plot, one box per number
    of embedded motifs, and saved to `outfile`.

    Args:
        model: model passed through to ``get_embeddings_low_mem``.
        motif: motif string written into a 6-position window of the sequence.
            NOTE(review): assumes a 6-character motif - confirm.
        outfile: path the figure is saved to.
        no_of_repeats: number of random sequences to simulate.
    """
    n_conditions = 6  # columns of `data`: 0, 1, ..., 5 embedded motifs

    sequences = random_sequences(no_of_repeats)
    # All-zero second model input, shape (1, 130).
    C = np.zeros(shape=(1, 130))
    data = np.zeros(shape=(no_of_repeats, n_conditions))

    def _score(seq):
        # One-hot encode the sequence and read the first score the model returns.
        return get_embeddings_low_mem(model, (make_onehot(seq, 500), C))[0][0]

    for seq_id, sequence in enumerate(sequences):
        idx = 0
        # A single insertion location for now. The 5..495 range keeps the
        # 6-wide window inside the sequence bounds.
        locations = np.random.randint(5, 495, 1)
        for loc in locations:
            # Column 0: baseline score with no motif embedded.
            data[seq_id, idx] = _score(sequence)
            idx = idx + 1

            # Column 1: first motif at the sampled location.
            sequence[loc - 3:loc + 3] = list(motif)
            sequence = list(sequence)
            data[seq_id, idx] = _score(sequence)
            idx = idx + 1

            # Columns 2..5: keep adding motifs at random in-bounds positions,
            # re-drawing the offset whenever the position falls outside.
            while idx < n_conditions:
                print(idx)
                offset = np.random.randint(5, 500)
                curr = idx + offset
                if not 5 < curr < 495:
                    continue
                sequence[curr - 3:curr + 3] = list(motif)
                sequence = list(sequence)
                data[seq_id, idx] = _score(sequence)
                idx += 1

    # Converting to a tidy data
    data = pd.DataFrame(data)
    data = pd.melt(data)
    # Plotting
    fig, ax = plt.subplots()
    fig.subplots_adjust(left=.15, bottom=.15, right=.95, top=.95)
    # Keyword x/y: positional data arguments are rejected by newer seaborn.
    sns.boxplot(x=data['variable'], y=data['value'], color='#D2B4DE')
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(1.5)
    plt.xticks(range(6), [0, 1, 2, 3, 4, 5])
    plt.xlabel('Number of embedded CAGSTG motifs in simulated data',
               fontsize=10)
    plt.ylabel('Sequence sub-network activations', fontsize=10)
    fig.set_size_inches(3.5, 4)
    plt.savefig(outfile)
Beispiel #9
0
def iterate_minibatches(labels, batch_size, train_or_eval='train'):
    """Yield augmented label batches with noise masks for the autoencoder.

    The subjects are shuffled once and then split into
    ``n_labels // batch_size`` full batches; leftover subjects are dropped.
    Each batch is augmented (in both modes), converted to one-hot, and a pair
    of noise masks is generated from the one-hot labels.

    Args:
        labels: array-like indexed along axis 0 by subject.
        batch_size: number of subjects per batch.
        train_or_eval: ``'train'`` or ``'eval'``. In ``'eval'`` mode the
            number and size of the masks are fixed so runs are comparable.

    Yields:
        Tuple ``(labels_this_batch, blank_masks_this_batch,
        wrong_labels_this_batch)``; the labels are the augmented, non-one-hot
        labels.

    Raises:
        ValueError: if `train_or_eval` is neither ``'train'`` nor ``'eval'``.
    """
    # Validate up front. The mode used to be compared with `is`, which tests
    # object identity rather than string equality, and an unmatched mode then
    # reached the `yield` with the mask variables unbound.
    if train_or_eval not in ('train', 'eval'):
        raise ValueError(
            "train_or_eval must be 'train' or 'eval', got {!r}".format(
                train_or_eval))
    fixed_noise = train_or_eval == 'eval'

    # ===========================
    # generate indices to randomly select subjects in each minibatch
    # ===========================
    n_labels = labels.shape[0]
    random_indices = np.random.permutation(n_labels)

    # ===========================
    # only full batches are produced, so every slice below is in range
    # ===========================
    for b_i in range(n_labels // batch_size):

        batch_indices = np.sort(random_indices[b_i * batch_size:(b_i + 1) *
                                               batch_size])

        labels_this_batch = labels[batch_indices, ...]

        # ===========================
        # data augmentation (random elastic transformations, translations, rotations, scaling)
        # done during training as well as during evaluation on the validation set (used for model selection)
        # ===========================
        labels_this_batch = do_data_augmentation(
            labels=labels_this_batch,
            data_aug_ratio=exp_config.da_ratio,
            sigma=exp_config.sigma,
            alpha=exp_config.alpha,
            trans_min=exp_config.trans_min,
            trans_max=exp_config.trans_max,
            rot_min=exp_config.rot_min,
            rot_max=exp_config.rot_max,
            scale_min=exp_config.scale_min,
            scale_max=exp_config.scale_max)

        # ==================
        # make labels 1-hot
        # ==================
        labels_this_batch_1hot = utils.make_onehot(labels_this_batch,
                                                   exp_config.nlabels)

        # ===========================
        # make noise masks that the autoencoder will try to denoise
        # ===========================
        # NOTE(review): the mask shape uses exp_config.batch_size rather than
        # the `batch_size` argument - confirm callers always pass the same value.
        mask_kwargs = dict(
            shape=[exp_config.batch_size] + list(exp_config.image_size) +
            [exp_config.nlabels],
            mask_type=exp_config.mask_type,
            mask_params=[exp_config.mask_radius, exp_config.num_squares],
            nlabels=exp_config.nlabels,
            labels_1hot=labels_this_batch_1hot)
        if fixed_noise:
            # fixing amount of noise in order to get comparable runs during evaluation
            mask_kwargs.update(is_num_masks_fixed=True,
                               is_size_masks_fixed=True)

        blank_masks_this_batch, wrong_labels_this_batch = (
            utils_masks.make_noise_masks_3d(**mask_kwargs))

        yield labels_this_batch, blank_masks_this_batch, wrong_labels_this_batch
Beispiel #10
0
def prepare_data(input_folder, output_file, idx_start, idx_end, protocol, size,
                 target_resolution, preprocessing_folder):
    """Preprocess subjects ``idx_start..idx_end-1`` and store them in HDF5.

    For every selected ``*.zip`` entry the image and label volumes are loaded,
    cropped or padded to `size`, the image is normalised, and both are
    rescaled to `target_resolution` before being written to the ``images`` and
    ``labels`` datasets. The original volume shapes (``nx``, ``ny``, ``nz``),
    voxel sizes (``px``, ``py``, ``pz``) and patient names (``patnames``) are
    stored as additional small datasets.

    Args:
        input_folder: prefix that ``'*.zip'`` is appended to when globbing, so
            it needs a trailing path separator to list a directory.
        output_file: path of the HDF5 file to create (overwritten).
        idx_start: first index (inclusive) into the sorted file list.
        idx_end: last index (exclusive) into the sorted file list.
        protocol: forwarded to ``get_image_and_label_paths``.
        size: three-element target volume size used for cropping/padding and
            for the dataset shape.
        target_resolution: three-element target voxel size.
        preprocessing_folder: forwarded to ``get_image_and_label_paths``.
    """
    # ========================
    # read the filenames
    # ========================
    filenames = sorted(glob.glob(input_folder + '*.zip'))
    logging.info('Number of images in the dataset: %s' % str(len(filenames)))

    num_subjects = idx_end - idx_start

    # =======================
    # create a hdf5 file; the context manager closes it even when a subject
    # fails part-way through, which previously leaked the handle
    # =======================
    with h5py.File(output_file, "w") as hdf5_file:

        # ===============================
        # Create datasets for images and labels
        # ===============================
        data = {}
        data['images'] = hdf5_file.create_dataset("images",
                                                  [num_subjects] + list(size),
                                                  dtype=np.float32)
        data['labels'] = hdf5_file.create_dataset("labels",
                                                  [num_subjects] + list(size),
                                                  dtype=np.uint8)

        # ===============================
        # initialize lists
        # ===============================
        label_list = []
        image_list = []
        nx_list = []
        ny_list = []
        nz_list = []
        px_list = []
        py_list = []
        pz_list = []
        pat_names_list = []

        # ===============================
        # iterate through the requested indices; patient_counter is the row
        # in the output datasets
        # ===============================
        for patient_counter, idx in enumerate(range(idx_start, idx_end)):

            # ==================
            # get file paths
            # ==================
            patient_name, image_path, label_path = get_image_and_label_paths(
                filenames[idx], protocol, preprocessing_folder)

            # ============
            # read the image (it is normalised further below, after cropping)
            # ============
            image, _, image_hdr = utils.load_nii(image_path)

            # ==================
            # read the label file
            # ==================
            label, _, _ = utils.load_nii(label_path)
            # group the segmentation classes as required
            label = utils.group_segmentation_classes(label)

            # ==================
            # collect some header info.
            # ==================
            zooms = image_hdr.get_zooms()
            px_list.append(float(zooms[0]))
            py_list.append(float(zooms[1]))
            pz_list.append(float(zooms[2]))
            nx_list.append(image.shape[0])
            ny_list.append(image.shape[1])
            nz_list.append(image.shape[2])
            pat_names_list.append(patient_name)

            print(image.shape)
            print(label.shape)

            # ==================
            # crop volume along all axes from the ends (as there are several zeros towards the ends)
            # ==================
            image = utils.crop_or_pad_volume_to_size_along_x(image, size[0])
            label = utils.crop_or_pad_volume_to_size_along_x(label, size[0])
            image = utils.crop_or_pad_volume_to_size_along_y(image, size[1])
            label = utils.crop_or_pad_volume_to_size_along_y(label, size[1])
            image = utils.crop_or_pad_volume_to_size_along_z(image, size[2])
            label = utils.crop_or_pad_volume_to_size_along_z(label, size[2])

            print(image.shape)
            print(label.shape)

            # ==================
            # normalize the image
            # ==================
            image_normalized = utils.normalise_image(image,
                                                     norm_type='div_by_max')

            # ======================================================
            # rescale to the required resolution
            # ======================================================
            scale_vector = [
                zooms[0] / target_resolution[0],
                zooms[1] / target_resolution[1],
                zooms[2] / target_resolution[2]
            ]

            image_rescaled = transform.rescale(image_normalized,
                                               scale_vector,
                                               order=1,
                                               preserve_range=True,
                                               multichannel=False,
                                               mode='constant')

            # Labels are interpolated per class in one-hot form and then
            # collapsed back with argmax, rather than interpolating class ids.
            label_onehot = utils.make_onehot(label, nlabels=15)

            label_onehot_rescaled = transform.rescale(label_onehot,
                                                      scale_vector,
                                                      order=1,
                                                      preserve_range=True,
                                                      multichannel=True,
                                                      mode='constant')

            label_rescaled = np.argmax(label_onehot_rescaled, axis=-1)

            # ============
            # the images and labels have been rescaled to the desired resolution.
            # write them to the hdf5 file now.
            # ============
            image_list.append(image_rescaled)
            label_list.append(label_rescaled)

            _write_range_to_hdf5(data, image_list, label_list, patient_counter,
                                 patient_counter + 1)

            # NOTE(review): presumably empties the two lists so only one
            # subject is held in memory at a time - confirm in the helper.
            _release_tmp_memory(image_list, label_list)

        # Write the small datasets
        hdf5_file.create_dataset('nx',
                                 data=np.asarray(nx_list, dtype=np.uint16))
        hdf5_file.create_dataset('ny',
                                 data=np.asarray(ny_list, dtype=np.uint16))
        hdf5_file.create_dataset('nz',
                                 data=np.asarray(nz_list, dtype=np.uint16))
        hdf5_file.create_dataset('px',
                                 data=np.asarray(px_list, dtype=np.float32))
        hdf5_file.create_dataset('py',
                                 data=np.asarray(py_list, dtype=np.float32))
        hdf5_file.create_dataset('pz',
                                 data=np.asarray(pz_list, dtype=np.float32))
        hdf5_file.create_dataset('patnames',
                                 data=np.asarray(pat_names_list, dtype="S10"))
# ===========================
# turn every subject's label volume into 1-hot form and collect it in `atlas`
# ===========================
for subject_num in range(labels.shape[0]):

    label_this_subject = labels[subject_num, ...]

    # save a picture of every 8th slice of this subject's labels
    vis_path = (sys_config.preproc_folder_hcp + '/training_image_' +
                str(subject_num + 1) + '_for_making_atlas.png')
    utils_vis.save_samples_downsampled(label_this_subject[::8, :, :],
                                       vis_path)

    # work on a copy and plant one voxel of each of the 15 label values, so
    # that the 1hot function outputs everything of the same shape
    label_this_subject_ = np.copy(label_this_subject)
    for j in range(15):
        label_this_subject_[j, 0, 0] = j

    label_this_subject_1hot = utils.make_onehot(label_this_subject_)
    atlas.append(label_this_subject_1hot)

# ===========================
# average the 1-hot volumes over subjects and store the result as float16
# ===========================
atlas_mean = np.array(atlas).mean(axis=0).astype(np.float16)
np.save(sys_config.preproc_folder_hcp + 'hcp_atlas.npy', atlas_mean)

# one grayscale picture per label channel, scaled to 0..255
atlas_mean_vis = (atlas_mean * 255).astype(np.uint8)
for l in range(atlas_mean_vis.shape[-1]):
    atlas_label_path = (sys_config.preproc_folder_hcp + '/hcp_atlas_label' +
                        str(l) + '.png')
    utils_vis.save_samples_downsampled(atlas_mean_vis[::8, :, :, l],
                                       atlas_label_path,
                                       add_pixel_each_label=False,
                                       cmap='gray')