def main():
    if opt.seed is None:
        opt.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir)

    if opt.mode == "test":
        if opt.checkpoint is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        options = {"ngf", "ndf", "lab_colorization"}
        with open(os.path.join(opt.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in options:
                    print("loaded", key, "=", val)
                    setattr(opt, key, val)
        
    for k, v in opt._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(opt.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(opt), sort_keys=True, indent=4))
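    # options.json lets test mode restore the architecture-related flags (ngf, ndf, lab_colorization) later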
    
    
    is_video = True
    if is_video:
        # create Video-Reader 
        roisize = 512
        xcenter, ycenter = 1080/2, 1920/2
        video_reader = data.VideoReader(opt.input_dir, opt.scale_size, roisize, xcenter, ycenter)
        examples = video_reader.loadDummy()  # bad workaround

    else:
        examples = data.load_examples(opt.input_dir, opt.scale_size, opt.batch_size, opt.mode)

    print("examples count = %d" % examples.count)

    # inputs and targets are [batch_size, height, width, channels]
    C2Pmodel = model.create_model(examples.inputs, examples.targets, opt.ndf, opt.ngf, EPS, opt.gan_weight, opt.l1_weight, opt.lr, opt.beta1)
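    # C2Pmodel bundles the generator/discriminator outputs, losses, gradients and the train op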

    # reverse any processing on images so they can be written to disk or displayed to the user
    inputs = data.deprocess(examples.inputs)
    targets = data.deprocess(examples.targets)
    outputs = data.deprocess(C2Pmodel.outputs)
    outputs_psf = data.deprocess(C2Pmodel.outputs_psf)

    def convert(image):
        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("convert_outputspsf"):
        converted_outputs_psf = convert(outputs_psf)


    with tf.name_scope("encode_images"):
        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
            "outputs_psf": tf.map_fn(tf.image.encode_png, converted_outputs_psf, dtype=tf.string, name="outputpsf_pngs"),
        }
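    # each entry in display_fetches is a batch of PNG-encoded byte strings, ready to be written to disk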

    # summaries
    with tf.name_scope("inputs_summary"):
        tf.summary.image("inputs", converted_inputs)

    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)

    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)

    with tf.name_scope("outputspsf_summary"):
        tf.summary.image("outputs_psf", converted_outputs_psf)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image("predict_real", tf.image.convert_image_dtype(C2Pmodel.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image("predict_fake", tf.image.convert_image_dtype(C2Pmodel.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", C2Pmodel.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", C2Pmodel.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", C2Pmodel.gen_loss_L1)

    # add a histogram summary for every trainable variable
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)
    
    # add a histogram summary for the gradients
    for grad, var in C2Pmodel.discrim_grads_and_vars + C2Pmodel.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=1)

    # set the logdir for TensorBoard logging
    logdir = opt.output_dir if (opt.trace_freq > 0 or opt.summary_freq > 0) else None
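    # the Supervisor manages session initialisation; automatic summary/checkpoint saving is disabled
    # (save_summaries_secs=0, saver=None) and handled manually below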
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if opt.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(opt.checkpoint)
            saver.restore(sess, checkpoint)

        max_steps = 2**32
        if opt.max_epochs is not None:
            max_steps = examples.steps_per_epoch * opt.max_epochs
        if opt.max_steps is not None:
            max_steps = opt.max_steps
            
        if is_video:
            max_steps = len(video_reader)

        if opt.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            experiment_name = opt.input_dir.split("/")[-2]
            network_name = opt.checkpoint
            
            for step in range(max_steps):
                
                if is_video:
                    input_frame = video_reader[step]

                # evaluate result for one frame at a time
                outputs_np, outputs_psf_np = sess.run([outputs, outputs_psf], feed_dict={inputs: input_frame})
                # hacky workaround to keep model as is
                outputs_np = np.squeeze(np.array(outputs_np))
                inputs_np = np.squeeze(np.array(input_frame))
                outputs_psf_np = np.squeeze(np.array(outputs_psf_np))
                
                # Deprocess: map values from [-1, 1] to [0, 1]
                outputs_np = (outputs_np + 1) / 2
                inputs_np = (inputs_np + 1) / 2
                outputs_psf_np = (outputs_psf_np + 1) / 2
                
                # save frames to TIF 
                data.save_as_tif(inputs_np, outputs_np, outputs_psf_np, experiment_name, network_name)
                print("evaluated image " + str(step))
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):
                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(opt.trace_freq):
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": C2Pmodel.train,
                    "global_step": sv.global_step,
                }

                if should(opt.progress_freq):
                    fetches["discrim_loss"] = C2Pmodel.discrim_loss
                    fetches["gen_loss_GAN"] = C2Pmodel.gen_loss_GAN
                    fetches["gen_loss_L1"] = C2Pmodel.gen_loss_L1

                if should(opt.summary_freq):
                    fetches["summary"] = sv.summary_op

                if should(opt.display_freq):
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(opt.summary_freq):
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"], results["global_step"])


                if should(opt.trace_freq):
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])

                if should(opt.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch)
                    train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * opt.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * opt.batch_size / rate
                    print("progress  epoch %d  step %d  image/sec %0.1f  remaining %dm" % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(opt.save_freq):
                    print("saving model")
                    saver.save(sess, os.path.join(opt.output_dir, "C2Pmodel"), global_step=sv.global_step)

                if sv.should_stop():
                    break
Example #2
                outputs_np = data.deprocess(outputs_np)
            else:
                # convert back to uint8
                outputs_np = np.uint8(2**8 * outputs_np)

            # sum each frame to get the resulting high-resolution image
            out_sum = out_sum + outputs_np # sum all values 
            if opt.is_frc: # if frc is true, keep two separate sums to compute the FRC (see the sketch at the end of this example)
                if np.mod(step_i, 2): # odd frames
                    out_sum_frc2 = out_sum_frc2 + outputs_np
                else:
                    out_sum_frc1 = out_sum_frc1 + outputs_np # even frames

            # always export the inputs as tif!
            data.save_as_tif(inputs_np, experiment_name, network_name, 'inputs')
                
            if opt.is_tif:
                # save frames to TIF if necessary
                data.save_as_tif(outputs_np, experiment_name, network_name, 'outputs')
                
            if opt.is_csv:
                # get a list with all emitters greater than a certain per-frame intensity threshold value
                loc_list = np.int32(np.column_stack(np.where(outputs_np > np.mean(outputs_np) * 0.85)))
                int_list = outputs_np[loc_list[:, 0], loc_list[:, 1]]
                # convert pixel indices to nm: 80 nm camera pixel size divided by the upsampling factor
                # (e.g. 80/5 nm for an upsampling factor of 5)
                upsampling = opt.scale_size / opt.roi_size
                loc_list = loc_list * (80 / upsampling)
            
                loc_count = loc_list.shape[0]
                id_list = last_index + np.arange(loc_count)
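
# A minimal sketch (not part of the original listing) of how the two half-sums
# out_sum_frc1 / out_sum_frc2 accumulated above could be compared with a Fourier
# Ring Correlation once the loop is done; the ring width and the lack of windowing are assumptions.
def fourier_ring_correlation(img1, img2, ring_width=1):
    f1 = np.fft.fftshift(np.fft.fft2(img1))
    f2 = np.fft.fftshift(np.fft.fft2(img2))
    h, w = img1.shape
    yy, xx = np.indices((h, w))
    # radial distance of every frequency bin from the spectrum centre
    r = np.sqrt((yy - h // 2) ** 2 + (xx - w // 2) ** 2)
    rings = (r / ring_width).astype(np.int32)
    n_rings = int(rings.max()) + 1
    num = np.zeros(n_rings)
    den1 = np.zeros(n_rings)
    den2 = np.zeros(n_rings)
    # correlate the two spectra ring by ring: FRC(r) = Re(sum F1*conj(F2)) / sqrt(sum|F1|^2 * sum|F2|^2)
    np.add.at(num, rings.ravel(), np.real(f1 * np.conj(f2)).ravel())
    np.add.at(den1, rings.ravel(), np.abs(f1).ravel() ** 2)
    np.add.at(den2, rings.ravel(), np.abs(f2).ravel() ** 2)
    return num / np.sqrt(den1 * den2 + 1e-12)

# hypothetical usage after the loop has finished:
# frc_curve = fourier_ring_correlation(out_sum_frc1, out_sum_frc2)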
Example #3
        else:
            # convert back to uint8
            outputs_np = np.uint8(2**8 * outputs_np)

        # sum each frame to get the resulting high-resolution image
        if opt.is_frc:  # if frc is true, keep two separate sums to compute the FRC
            if np.mod(step_i, 2):  # odd frames
                out_sum_2 = out_sum_2 + outputs_np
            else:
                out_sum = out_sum + outputs_np  # even frames
        else:
            out_sum = out_sum + outputs_np  # sum all values

        if opt.is_tif:
            # save frames to TIF
            data.save_as_tif(inputs_np, outputs_np, experiment_name,
                             network_name)

        if opt.is_csv:
            # get a list with all emitters greater than a certain per-frame intensity threshold value
            loc_list = np.int32(np.column_stack(np.where(outputs_np > np.mean(outputs_np) * 0.85)))
            int_list = outputs_np[loc_list[:, 0], loc_list[:, 1]]
            # convert pixel indices to nm: 80 nm camera pixel size divided by the upsampling factor
            # (e.g. 80/5 nm for an upsampling factor of 5)
            upsampling = opt.scale_size / opt.roi_size
            loc_list = loc_list * (80 / upsampling)

            loc_count = loc_list.shape[0]
            id_list = last_index + np.arange(loc_count)
            frame_list = np.ones(id_list.shape) * step
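
# A minimal sketch (not part of the original listing) of how the per-frame arrays
# id_list, frame_list, loc_list and int_list built above might be appended to a
# localisation CSV; the file name, column layout and the last_index bookkeeping are assumptions.
import csv

def append_localizations_csv(csv_path, id_list, frame_list, loc_list, int_list):
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        for loc_id, frame, (y_nm, x_nm), intensity in zip(id_list, frame_list, loc_list, int_list):
            # one row per detected emitter: id, frame, y [nm], x [nm], intensity
            writer.writerow([int(loc_id), int(frame), float(y_nm), float(x_nm), float(intensity)])

# hypothetical usage inside the loop above:
# append_localizations_csv(os.path.join(opt.output_dir, "localizations.csv"),
#                          id_list, frame_list, loc_list, int_list)
# last_index = last_index + loc_count  # keep ids unique across frames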