Example #1
                  # (snippet begins mid-call: the lines below are keyword
                  # arguments to a GAN model constructor)
                  cifar10_input_fn,
                  filenames=glob.glob(args.filenames),
                  batch_size=args.batch_size,
                  num_epochs=args.num_epochs if args.train else 1,
                  shuffle=args.train,
              ),
              fake_input_fn=lambda: tf.random_normal([args.batch_size, 100]),
              hyper_params=Struct(
                  generator_learning_rate=2e-4,
                  generator_beta1=0.5,
                  generator_beta2=0.999,
                  discriminator_learning_rate=2e-4,
                  discriminator_beta1=0.5,
                  discriminator_beta2=0.999,
                  mode_seeking_loss_weight=0.1,
              ))

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(
        visible_device_list=args.gpu, allow_growth=True))

    if args.train:
        gan.train(model_dir=args.model_dir,
                  total_steps=args.total_steps,
                  save_checkpoint_steps=1000,
                  save_summary_steps=100,
                  log_tensor_steps=100,
                  config=config)

    if args.evaluate:
        gan.evaluate(model_dir=args.model_dir, config=config)
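
Struct above appears to be a small attribute-access container for the hyper-parameters; a minimal sketch of such a helper (an assumption, not necessarily this repo's implementation):

class Struct(object):
    # Expose keyword arguments as attributes, e.g. Struct(lr=2e-4).lr == 2e-4.
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)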
Example #2
    # train G (freeze discriminator)
    utils.make_trainable(d, False)
    for i in range(scheduler.get_gsteps()):
        real_imgs, real_vessels = next(train_batch_fetcher)
        g_x_batch, g_y_batch = utils.input2gan(real_imgs, real_vessels, d_out_shape)
        loss, acc = gan.train_on_batch(g_x_batch, g_y_batch)
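    # make_trainable is presumably the usual Keras freezing helper; a minimal
    # sketch under that assumption (not necessarily this repo's exact code):
    #
    #   def make_trainable(net, val):
    #       net.trainable = val
    #       for layer in net.layers:
    #           layer.trainable = val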
 
   # evaluate on validation set
   if n_round in rounds_for_evaluation:
        # D
        d_x_test, d_y_test = utils.input2discriminator(
            val_imgs, val_vessels,
            g.predict(val_imgs, batch_size=batch_size), d_out_shape)
        loss, acc = d.evaluate(d_x_test, d_y_test, batch_size=batch_size, verbose=0)
        utils.print_metrics(n_round + 1, loss=loss, acc=acc, type='D')
        # G
        gan_x_test, gan_y_test = utils.input2gan(val_imgs, val_vessels, d_out_shape)
        loss, acc = gan.evaluate(gan_x_test, gan_y_test, batch_size=batch_size, verbose=0)
        utils.print_metrics(n_round + 1, acc=acc, loss=loss, type='GAN')
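        # input2discriminator / input2gan presumably assemble (inputs, labels)
        # pairs: for D, real image+vessel pairs labeled 1 and image+generated
        # pairs labeled 0; for the combined GAN, all-ones labels so G is
        # rewarded for fooling D. A hypothetical sketch of input2gan:
        #
        #   def input2gan(imgs, vessels, d_out_shape):
        #       g_x_batch = [imgs, vessels]
        #       g_y_batch = np.ones((imgs.shape[0],) + d_out_shape[1:])
        #       return g_x_batch, g_y_batch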
       
        # save the generator weights
        g.save_weights(os.path.join(
            model_out_dir,
            "g_{}_{}_{}.h5".format(n_round, FLAGS.discriminator, FLAGS.ratio_gan2seg)))
      
   # update step sizes, learning rates
   scheduler.update_steps(n_round)
   K.set_value(d.optimizer.lr, scheduler.get_lr())    
   K.set_value(gan.optimizer.lr, scheduler.get_lr())    
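    # The scheduler presumably alternates fixed D/G step budgets and decays
    # the learning rate on a round schedule; a hypothetical sketch:
    #
    #   class Scheduler(object):
    #       def __init__(self, init_lr, lr_schedule, g_steps):
    #           self.lr, self.lr_schedule, self.g_steps = init_lr, lr_schedule, g_steps
    #       def get_gsteps(self):
    #           return self.g_steps
    #       def update_steps(self, n_round):
    #           if n_round in self.lr_schedule:   # e.g. {50: 2e-5, 100: 2e-6}
    #               self.lr = self.lr_schedule[n_round]
    #       def get_lr(self):
    #           return self.lr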
   
   # evaluate on test images
   if n_round in rounds_for_evaluation:    
        generated = g.predict(test_imgs, batch_size=batch_size)
        generated = np.squeeze(generated, axis=3)
        vessels_in_mask, generated_in_mask = utils.pixel_values_in_mask(
            test_vessels, generated, test_masks)
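
pixel_values_in_mask presumably restricts the predicted and ground-truth vessel maps to pixels inside the field-of-view masks before any metric is computed; a minimal sketch under that assumption:

import numpy as np

def pixel_values_in_mask(true_vessels, pred_vessels, masks):
    # Keep only pixels where the binary FOV mask is on, flattened for metrics.
    return true_vessels[masks == 1].flatten(), pred_vessels[masks == 1].flatten()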
Example #3
        loss1, acc1 = d1.evaluate(d1_x_test,
                                  d1_y_test,
                                  batch_size=batch_size,
                                  verbose=0)
        # loss2, acc2=d2.evaluate(d2_x_test,d2_y_test, batch_size=batch_size, verbose=0)
        # loss = (loss1 + loss2)/2
        # acc = (acc1 + acc2) / 2
        utils.print_metrics(n_round + 1, loss=loss1, acc=acc1, type='D')
        # G
        gan1_x_test, gan1_y_test = utils.input2gan(val_imgs,
                                                   val_vessels,
                                                   d_out_shape,
                                                   train_real=True)
        # gan2_x_test, gan2_y_test=utils.input2gan(val_imgs, val_vessels, d_out_shape, train_real=False)
        loss1, acc1 = gan1.evaluate(gan1_x_test,
                                    gan1_y_test,
                                    batch_size=batch_size,
                                    verbose=0)
        # loss2,acc2=gan2.evaluate(gan2_x_test, gan2_y_test, batch_size=batch_size, verbose=0)
        # loss = (loss2 + loss1) / 2
        # acc = (acc1 + acc2) / 2
        utils.print_metrics(n_round + 1, acc=acc1, loss=loss1, type='GAN')
        # save the model and weights with the best validation loss

        with open(
                os.path.join(
                    model_out_dir,
                    "g_{}_{}_{}.json".format(n_round, dataset,
                                             FLAGS.ratio_gan2seg)), 'w') as f:
            f.write(g.to_json())
        g.save_weights(
            os.path.join(
                model_out_dir,
                "g_{}_{}_{}.h5".format(n_round, dataset, FLAGS.ratio_gan2seg)))
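
To restore a generator saved this way, Keras can rebuild the architecture from the JSON file and then load the matching weights; a minimal sketch (the file names are hypothetical):

from keras.models import model_from_json

# Rebuild the architecture from the saved JSON, then restore its weights.
with open("g_0_DRIVE_10.json") as f:
    g = model_from_json(f.read())
g.load_weights("g_0_DRIVE_10.h5")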