# Assumed imports for this example (not shown in the original snippet):
# `data` and `model` are project-local modules; the TensorFlow 1.x API is used.
import json
import math
import os
import random
import time

import numpy as np
import tensorflow as tf

import data
import model

EPS = 1e-12  # small constant used by the model losses (value assumed)


def main():
    if opt.seed is None:
        opt.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    if opt.mode == "test":
        if opt.checkpoint is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        options = {"ngf", "ndf", "lab_colorization"}
        with open(os.path.join(opt.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in options:
                    print("loaded", key, "=", val)
                    setattr(opt, key, val)
        
    for k, v in opt._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(opt.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(opt), sort_keys=True, indent=4))

    is_video = True  # hard-coded switch: read frames from a video instead of an image folder
    if is_video:
        # create a VideoReader over a fixed region of interest (ROI)
        roisize = 512
        xcenter, ycenter = 1080 / 2, 1920 / 2  # centre of the full frame
        VideoReader = data.VideoReader(opt.input_dir, opt.scale_size, roisize, xcenter, ycenter)
        examples = VideoReader.loadDummy()  # workaround: builds graph tensors; real frames are fed via feed_dict below

    else:
        examples = data.load_examples(opt.input_dir, opt.scale_size, opt.batch_size, opt.mode)

    print("examples count = %d" % examples.count)

    # inputs and targets are [batch_size, height, width, channels]
    C2Pmodel = model.create_model(examples.inputs, examples.targets, opt.ndf, opt.ngf, EPS, opt.gan_weight, opt.l1_weight, opt.lr, opt.beta1)

    # reverse any processing on images so they can be written to disk or displayed to the user
    inputs = data.deprocess(examples.inputs)
    targets = data.deprocess(examples.targets)
    outputs = data.deprocess(C2Pmodel.outputs)
    outputs_psf = data.deprocess(C2Pmodel.outputs_psf)

    def convert(image):
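        # convert_image_dtype rescales float images in [0, 1] to uint8 in [0, 255];
        # saturate=True clips out-of-range values before casting to avoid wrap-around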
        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("convert_outputspsf"):
        converted_outputs_psf = convert(outputs_psf)


    with tf.name_scope("encode_images"):
        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
            "outputs_psf": tf.map_fn(tf.image.encode_png, converted_outputs_psf, dtype=tf.string, name="outputpsf_pngs"),
        }

    # summaries
    with tf.name_scope("inputs_summary"):
        tf.summary.image("inputs", converted_inputs)

    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)

    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)

    with tf.name_scope("outputspsf_summary"):
        tf.summary.image("outputs_psf", converted_outputs_psf)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image("predict_real", tf.image.convert_image_dtype(C2Pmodel.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image("predict_fake", tf.image.convert_image_dtype(C2Pmodel.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", C2Pmodel.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", C2Pmodel.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", C2Pmodel.gen_loss_L1)

    # add histogram summary for all trainable variables
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)
    
    # add histogram summary for gradients
    for grad, var in C2Pmodel.discrim_grads_and_vars + C2Pmodel.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=1)

    # set up the logdir for TensorBoard logging
    logdir = opt.output_dir if (opt.trace_freq > 0 or opt.summary_freq > 0) else None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if opt.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(opt.checkpoint)
            saver.restore(sess, checkpoint)

        max_steps = 2**32
        if opt.max_epochs is not None:
            max_steps = examples.steps_per_epoch * opt.max_epochs
        if opt.max_steps is not None:
            max_steps = opt.max_steps
            
        if is_video:
            max_steps = len(VideoReader)

        if opt.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            experiment_name = opt.input_dir.split("/")[-2]
            network_name = opt.checkpoint
            
            for step in range(max_steps):
                if is_video:
                    input_frame = VideoReader[step]

                # evaluate the result for one frame at a time
                outputs_np, outputs_psf_np = sess.run([outputs, outputs_psf], feed_dict={inputs: input_frame})
                # hacky workaround to keep model as is
                outputs_np = np.squeeze(np.array(outputs_np))
                inputs_np = np.squeeze(np.array(input_frame))
                outputs_psf_np = np.squeeze(np.array(outputs_psf_np))
                
                # deprocess: map tanh outputs from [-1, 1] back to [0, 1]
                outputs_np = (outputs_np + 1) / 2
                inputs_np = (inputs_np + 1) / 2
                outputs_psf_np = (outputs_psf_np + 1) / 2
                
                # save frames to TIF 
                data.save_as_tif(inputs_np, outputs_np, outputs_psf_np, experiment_name, network_name)
                print("evaluated image " + str(step))
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(opt.trace_freq):
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": C2Pmodel.train,
                    "global_step": sv.global_step,
                }

                if should(opt.progress_freq):
                    fetches["discrim_loss"] = C2Pmodel.discrim_loss
                    fetches["gen_loss_GAN"] = C2Pmodel.gen_loss_GAN
                    fetches["gen_loss_L1"] = C2Pmodel.gen_loss_L1

                if should(opt.summary_freq):
                    fetches["summary"] = sv.summary_op

                if should(opt.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(opt.summary_freq):
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"], results["global_step"])


                if should(opt.trace_freq):
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])

                if should(opt.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * opt.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * opt.batch_size / rate
                    print("progress  epoch %d  step %d  image/sec %0.1f  remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(opt.save_freq):
                    print("saving model")
                    saver.save(sess, os.path.join(opt.output_dir, "C2Pmodel"), global_step=sv.global_step)

                if sv.should_stop():
                    break
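
# `opt` in the example above is assumed to be an argparse namespace
# (`_get_kwargs()` is argparse's Namespace API). A minimal sketch of a
# matching parser, covering only the attributes referenced above; the option
# names come from the code, the defaults here are illustrative assumptions:
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--input_dir", help="path to input images or video")
#     parser.add_argument("--output_dir", help="where checkpoints, summaries and options.json go")
#     parser.add_argument("--mode", choices=["train", "test"])
#     parser.add_argument("--checkpoint", default=None)
#     parser.add_argument("--seed", type=int, default=None)
#     parser.add_argument("--max_epochs", type=int, default=None)
#     parser.add_argument("--max_steps", type=int, default=None)
#     parser.add_argument("--batch_size", type=int, default=1)
#     parser.add_argument("--scale_size", type=int, default=256)
#     parser.add_argument("--ngf", type=int, default=64)
#     parser.add_argument("--ndf", type=int, default=64)
#     parser.add_argument("--lr", type=float, default=0.0002)
#     parser.add_argument("--beta1", type=float, default=0.5)
#     parser.add_argument("--l1_weight", type=float, default=100.0)
#     parser.add_argument("--gan_weight", type=float, default=1.0)
#     for name in ("summary", "progress", "trace", "display", "save"):
#         parser.add_argument("--%s_freq" % name, type=int, default=0)
#     opt = parser.parse_args()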
# Example #2 (truncated fragment; begins mid-function)
        spikes = normalize_max_std_mean(spikes)

    else:
        # preprocess data: map 8-bit values 0..255 to -1..1  # TODO: alternatively use whitening?
        patches = (2 * patches / 255.) - 1
        heatmaps = (2 * heatmaps / 255.) - 1
        spikes = (2 * spikes / 255.) - 1

    count = patches.shape[0]
    # randomize the order of the data
    # assuming data in order: [N_smaples, Width, Height, Color-channels]
    shuffle_order = np.arange(count)
    np.random.shuffle(shuffle_order)  # shuffles in place; its return value is None, so don't assign it

    patches = patches[shuffle_order, :, :]
    heatmaps = heatmaps[shuffle_order, :, :]
    spikes = spikes[shuffle_order, :, :]

    print('Reading finished.')

    print('Number of Training Examples: %d' % count)

if False:  # disabled debug snippet
    input_dir = '/home/diederich/Documents/STORM/DATASET_NN/04_UNPROCESSED_RAW_HW/MOV_2018_02_16_09_09_49_ISO3200_texp_1_85_lines_combined/test'
    #input_dir = '/home/useradmin/Dropbox/Dokumente/Promotion/PROJECTS/STORM/MATLAB/cellSTORM-KERAS/images/2017-12-18_18.29.45.mp4_256_Tirf_v2_from_video_v2_fakeB_2k.tif_bilinear_smallpatch'
    scale_size = 256
    batch_size = 4
    mode = 'train'
    reload(data)  # Python 2 builtin; on Python 3 use importlib.reload
    examples = data.load_examples(input_dir, scale_size, batch_size, mode)
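
# The `normalize_max_std_mean` helper called at the top of this fragment is
# not shown anywhere in the listing; a plausible sketch (an assumption, not
# the project's actual implementation):
#
#     def normalize_max_std_mean(x):
#         x = x / np.max(x)                    # scale by the global maximum
#         return (x - np.mean(x)) / np.std(x)  # zero mean, unit variance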
# Example #3 (truncated fragment; begins mid-function)
    if opt.y_center == -1:
        print("Please click in the center of the ROI you wish to process:")
        VideoReader.select_ROI()  # select the ROI coordinates

        # reassign it for saving the data
        opt.y_center = VideoReader.ycenter
        opt.x_center = VideoReader.xcenter

    if opt.max_steps is None:
        max_steps = len(VideoReader)
    else:
        max_steps = opt.max_steps

else:
    if False:  # disabled alternative: load examples from an image folder
        examples = data.load_examples(opt.input_dir, opt.scale_size,
                                      opt.batch_size, opt.mode)
    else:
        # define the path to the datafiles
        data_file_1 = './data/MOV_2018_05_09_14_15_21_ISO3200_texp_1_30_newsample.mp4_unet256_13.csv.h5'
        data_file_2 = './data/2017-12-18_18.29.45.mp4_256_Tirf_v2_from_video_v2_fakeB_2k.csv.h5'
        data_file_3 = './data/2018-01-23_18.20.10_oldSample_ISO3200_10xEypiece_texp_1_30.mp4_256.csv.h5'
        data_file_4 = './data/gt_1000k_density_2.csv.h5'

        # load each dataset separately
        print('Load Dataset #1')
        examples_1 = data.load_examples_h5(data_file_1,
                                           batch_size=opt.batch_size)
        print('Load Dataset #2')
        examples_2 = data.load_examples_h5(data_file_2,
                                           batch_size=opt.batch_size)
        print('Load Dataset #3')
# Example #4
def main(_):
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    if a.mode == "export":
        if a.checkpoint is None:
            raise Exception("checkpoint required for test mode")

    print("Notice the arguments of this run is as following : ")
    for k, v in a._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    if a.mode == "export":
        batch_input = tf.placeholder("float",
                                     shape=[1, None, None, input_chanel_num],
                                     name='Input_without_preprocess')

        with tf.variable_scope("generator"):
            batch_output, _, _ = model.create_generator(batch_input, 1, a)
            output = function.deprocess(batch_output)
            real_output = tf.identity(output, name='Final_output')

        init_op = tf.global_variables_initializer()
        restore_saver = tf.train.Saver()
        export_saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            restore_saver.restore(sess, checkpoint)
            print("exporting model")
            export_saver.export_meta_graph(
                filename=os.path.join(a.output_dir, "export.meta"))
            export_saver.save(sess,
                              os.path.join(a.output_dir, "export"),
                              write_meta_graph=False)

        return

    dataf0, dataf1, dataf2, labels = data.load_examples(a.input_dir)

    batch_input = tf.placeholder("float",
                                 shape=[1, None, None, input_chanel_num],
                                 name='Input')
    batch_target = tf.placeholder("float",
                                  shape=[1, None, None, 1],
                                  name='Target')
    learning_rate = tf.placeholder("float", shape=[])

    model_using = model.create_model(batch_input, batch_target, a,
                                     learning_rate)

    targets = function.deprocess(batch_target)
    outputs = function.deprocess(model_using.outputs)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=100)

    logdir = a.output_dir
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)

        start = time.time()
        logfile = open(os.path.join(a.output_dir, 'logfile'), 'a')

        for epoch in range(a.max_epochs):
            if epoch < 21:
                lr_using = a.lr
            elif epoch < 51:
                lr_using = a.lr * 0.1
            else:
                lr_using = a.lr * 0.01

            logfile.write('Epoch ' + str(epoch) + ': learning_rate: ' +
                          str(lr_using) + '\n')
            print('Epoch ' + str(epoch))
            discrim_loss_train = []
            gen_loss_GAN_train = []
            gen_loss_L1_train = []

            fetches = {
                "train": model_using.train,
                "global_step": sv.global_step,
            }
            fetches["discrim_loss"] = model_using.discrim_loss
            fetches["gen_loss_GAN"] = model_using.gen_loss_GAN
            fetches["gen_loss_L1"] = model_using.gen_loss_L1

            for index in range(training_protein_nums):
                length = dataf2[index].shape[1]
                if length < max_length:
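                    # assemble an L x L x C input map for this protein:
                    # pairwise features (dataf2), per-position features
                    # (dataf1) tiled along rows and columns, and global
                    # features (dataf0) broadcast over the whole map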
                    input_data = np.concatenate([
                        dataf2[index],
                        np.tile(dataf1[index][np.newaxis], [length, 1, 1]),
                        np.tile(dataf1[index][:, np.newaxis], [1, length, 1]),
                        np.tile(dataf0[index][np.newaxis, np.newaxis],
                                [length, length, 1]),
                    ],
                                                axis=2)[np.newaxis]
                    label = labels[index][np.newaxis, :, :, np.newaxis]
                    print(epoch, index)
                    if index % 1000 == 0:
                        print(a.output_dir)
                    results = sess.run(fetches,
                                       feed_dict={
                                           batch_input: input_data,
                                           batch_target: label,
                                           learning_rate: lr_using
                                       })

                    discrim_loss_train.append(results["discrim_loss"])
                    gen_loss_GAN_train.append(results["gen_loss_GAN"])
                    gen_loss_L1_train.append(results["gen_loss_L1"])
            logfile.write('  training:\n')
            logfile.write('    discrim_loss:{}\n'.format(
                sum(discrim_loss_train) / len(discrim_loss_train)))
            logfile.write('    gen_loss_GAN:{}\n'.format(
                sum(gen_loss_GAN_train) / len(gen_loss_GAN_train)))
            logfile.write('    gen_loss_L1:{}\n'.format(
                sum(gen_loss_L1_train) / len(gen_loss_L1_train)))
            logfile.flush()

            print("saving model of epoch" + str(epoch))
            saver.save(sess, os.path.join(a.output_dir, "model" + str(epoch)))

            if sv.should_stop():
                break
        logfile.close()
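
# A hedged sketch (not part of the original) of how the graph exported by
# the "export" branch above could be reloaded for inference. The tensor names
# come from that branch; `output_dir` and `frame` are hypothetical stand-ins
# for the export directory and a [1, H, W, input_chanel_num] input array:
#
#     import os
#     import tensorflow as tf
#
#     with tf.Session(graph=tf.Graph()) as sess:
#         loader = tf.train.import_meta_graph(os.path.join(output_dir, "export.meta"))
#         loader.restore(sess, os.path.join(output_dir, "export"))
#         prediction = sess.run("Final_output:0",
#                               feed_dict={"Input_without_preprocess:0": frame})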