Example #1
def main():
	"""Create the model and start training.
	"""
	# Read CL arguments and snapshot the arguments into text file.
	args = get_arguments()
	utils.general.snapshot_arg(args)
	
	# The segmentation network is stride 8 by default.
	h, w = map(int, args.input_size.split(','))
	input_size = (h, w)
	innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))
	
	# Initialize the random seed.
	tf.set_random_seed(args.random_seed)
	
	# Create queue coordinator.
	coord = tf.train.Coordinator()
	
	# current step
	step_ph = tf.placeholder(dtype=tf.float32, shape=())
	
	# Set up tf session and initialize variables.
	config = tf.ConfigProto()
	config.gpu_options.allow_growth = True
	sess = tf.Session(config=config)
	
	# Load the data reader.
	with tf.device('/cpu:0'):
		with tf.name_scope('create_inputs'):
			reader = ImageReader(
				args.data_dir,
				args.data_list,
				input_size,
				args.random_scale,
				args.random_mirror,
				args.random_crop,
				args.ignore_label,
				IMG_MEAN)
			image_batch, label_batch = reader.dequeue(args.batch_size)
	
	'''
	image_batch => (N,H,W,C=3)
	label_batch => (N,H,W,1)
	'''
	# Shrink labels to the size of the network output.
	labels = tf.image.resize_nearest_neighbor(
		label_batch, innet_size, name='label_shrink')
	
	labels_flat = tf.reshape(labels, [-1, ])
	
	# Ignore the location where the label value is larger than args.num_classes.
	not_ignore_pixel = tf.less_equal(labels_flat, args.num_classes - 1)
	
	# Extract the indices of pixels where the gradients are propagated.
	pixel_inds = tf.squeeze(tf.where(not_ignore_pixel), 1)
	
	# Create network and predictions.
	outputs = model(image_batch,
					args.num_classes,
					args.is_training,
					args.use_global_status)
	
	# Grab variable names which should be restored from checkpoints.
	restore_var = [
		v for v in tf.global_variables()
		if 'block5' not in v.name or not args.not_restore_classifier
	]
	
	# Sum the losses from output branches.
	labels_gather = tf.to_int32(tf.gather(labels_flat, pixel_inds))
	seg_losses = []
	aff_losses = []
	for i, output in enumerate(outputs):   # outputs: list of (N,H,W,C) tensors
		# Define softmax loss.
		output_2d = tf.reshape(output, [-1, args.num_classes])
		output_gather = tf.gather(output_2d, pixel_inds)
		seg_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
			logits=output_gather, labels=labels_gather)
		seg_loss = tf.reduce_mean(seg_loss)
		seg_losses.append(seg_loss)
		
		# Define AFF loss.
		prob = tf.nn.softmax(output, axis=-1)
		edge_loss, not_edge_loss = lossx.affinity_loss(labels,
													   prob,
													   args.num_classes,
													   args.kld_margin)
		
		# Apply exponential decay to the AFF loss.
		dec = tf.pow(10.0, -step_ph / args.num_steps)
		aff_loss = tf.reduce_mean(edge_loss) * args.kld_lambda_1 * dec
		aff_loss += tf.reduce_mean(not_edge_loss) * args.kld_lambda_2 * dec
		aff_losses.append(aff_loss)
	
	# Define weight regularization loss.
	w = args.weight_decay
	l2_losses = [w * tf.nn.l2_loss(v) for v in tf.trainable_variables()
				 if 'weights' in v.name]
	
	# Sum all loss terms.
	mean_seg_loss = tf.add_n(seg_losses)
	mean_aff_loss = tf.add_n(aff_losses)
	mean_l2_loss = tf.add_n(l2_losses)
	reduced_loss = mean_seg_loss + mean_l2_loss + mean_aff_loss
	
	# Grab variable names which are used for training.
	all_trainable = tf.trainable_variables()
	fc_trainable = [v for v in all_trainable if 'block5' in v.name]  # lr*10
	base_trainable = [v for v in all_trainable if 'block5' not in v.name]  # lr*1
	
	# Computes gradients per iteration.
	grads = tf.gradients(reduced_loss, base_trainable + fc_trainable)
	grads_base = grads[:len(base_trainable)]
	grads_fc = grads[len(base_trainable):]
	
	# Define optimisation parameters.
	base_lr = tf.constant(args.learning_rate)
	learning_rate = tf.scalar_mul(
		base_lr,
		tf.pow((1 - step_ph / args.num_steps), args.power))
	
	opt_base = tf.train.MomentumOptimizer(learning_rate * 1.0, args.momentum)
	opt_fc = tf.train.MomentumOptimizer(learning_rate * 10.0, args.momentum)
	
	# Define tensorflow operations which apply gradients to update variables.
	train_op_base = opt_base.apply_gradients(
		zip(grads_base, base_trainable))
	train_op_fc = opt_fc.apply_gradients(
		zip(grads_fc, fc_trainable))
	train_op = tf.group(train_op_base, train_op_fc)
	
	# Process for visualisation.
	with tf.device('/cpu:0'):
		# Image summary for input image, ground-truth label and prediction.
		output_vis = tf.image.resize_nearest_neighbor(
			outputs[-1], tf.shape(image_batch)[1:3])
		output_vis = tf.argmax(output_vis, axis=3)
		output_vis = tf.expand_dims(output_vis, axis=3)
		output_vis = tf.cast(output_vis, dtype=tf.uint8)
		
		labels_vis = tf.cast(label_batch, dtype=tf.uint8)
		
		in_summary = tf.py_func(
			utils.general.inv_preprocess,
			[image_batch, IMG_MEAN],
			tf.uint8)
		gt_summary = tf.py_func(
			utils.general.decode_labels,
			[labels_vis, args.num_classes],
			tf.uint8)
		out_summary = tf.py_func(
			utils.general.decode_labels,
			[output_vis, args.num_classes],
			tf.uint8)
		# Concatenate image summaries in a row.
		total_summary = tf.summary.image(
			'images',
			tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]),
			max_outputs=args.batch_size)
		
		# Scalar summary for different loss terms.
		seg_loss_summary = tf.summary.scalar(
			'seg_loss', mean_seg_loss)
		aff_loss_summary = tf.summary.scalar(
			'aff_loss', mean_aff_loss)
		total_summary = tf.summary.merge_all()
		
		summary_writer = tf.summary.FileWriter(args.snapshot_dir,
											   graph=tf.get_default_graph())
	
	
	init = tf.global_variables_initializer()
	sess.run(init)
	
	# Saver for storing checkpoints of the model.
	saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
	
	# Load variables if the checkpoint is provided.
	if args.restore_from is not None:
		loader = tf.train.Saver(var_list=restore_var)
		load(loader, sess, args.restore_from)
	
	# Start queue threads.
	threads = tf.train.start_queue_runners(coord=coord, sess=sess)
	
	# Iterate over training steps.
	pbar = tqdm(range(args.num_steps))
	for step in pbar:
		start_time = time.time()
		feed_dict = {step_ph: step}
		
		step_loss = 0
		for it in range(args.iter_size):
			# Update summary periodically.
			if it == args.iter_size - 1 and step % args.update_tb_every == 0:
				sess_outs = [reduced_loss, total_summary, train_op]
				loss_value, summary, _ = sess.run(sess_outs,
												  feed_dict=feed_dict)
				summary_writer.add_summary(summary, step)
			else:
				sess_outs = [reduced_loss, train_op]
				loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict)
			
			step_loss += loss_value
		
		step_loss /= args.iter_size
		
		lr = sess.run(learning_rate, feed_dict=feed_dict)
		
		# Save trained model periodically.
		if step % args.save_pred_every == 0 and step > 0:
			save(saver, sess, args.snapshot_dir, step)
		
		duration = time.time() - start_time
		desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
		pbar.set_description(desc)
	
	coord.request_stop()
	coord.join(threads)
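
The loop above drives two schedules off step_ph: a polynomial decay on the
learning rate and an exponential decay on the affinity (AFF) loss weight. A
minimal pure-Python sketch of both; the values for base_lr, power and
num_steps are illustrative assumptions:

def poly_lr(step, base_lr=1e-3, num_steps=30000, power=0.9):
    # Matches tf.scalar_mul(base_lr, tf.pow(1 - step_ph / num_steps, power)).
    return base_lr * (1.0 - float(step) / num_steps) ** power

def aff_decay(step, num_steps=30000):
    # Matches tf.pow(10.0, -step_ph / num_steps): 1.0 at step 0, 0.1 at the end.
    return 10.0 ** (-float(step) / num_steps)

for step in (0, 15000, 29999):
    print(step, poly_lr(step), aff_decay(step))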
Example #2
def main():
    """Create the model and start the inference process.
	"""
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    h, w = input_size
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image
        image_list = reader.image_list
    image_batch = tf.expand_dims(image, axis=0)

    # Create multi-scale augmented data.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented data.
    if args.flip_aug:
        for i in range(len(scales)):
            img = rescale_image_batches[i]
            is_flip = is_flipped[i]
            img = tf.squeeze(img, axis=0)
            flip_img = tf.image.flip_left_right(img)
            flip_img = tf.expand_dims(flip_img, axis=0)
            rescale_image_batches.append(flip_img)
            is_flipped.append(True)

    # Create input tensor to the Network
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.num_classes, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[-1]
    output = tf.image.resize_bilinear(output,
                                      tf.shape(crop_image_batch)[1:3])
    output = tf.nn.softmax(output, axis=3)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    for step in range(num_steps):
        rescale_img_batches = sess.run(rescale_image_batches)
        # Final segmentation results (average across multiple scales).
        scale_ind = 2 if args.scale_aug else 0
        final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
        final_lab_size[-1] = args.num_classes
        final_lab_batch = np.zeros(final_lab_size)

        # Iterate over multiple scales.
        for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
            img_size = img_batch.shape
            padimg_size = list(img_size)  # copy of img_size

            padimg_h, padimg_w = padimg_size[1:3]
            input_h, input_w = input_size

            if input_h > padimg_h:
                padimg_h = input_h
            if input_w > padimg_w:
                padimg_w = input_w
            # Update padded image size.
            padimg_size[1] = padimg_h
            padimg_size[2] = padimg_w
            padimg_batch = np.zeros(padimg_size, dtype=np.float32)
            img_h, img_w = img_size[1:3]
            padimg_batch[:, :img_h, :img_w, :] = img_batch

            # Create padded label array.
            lab_size = list(padimg_size)
            lab_size[-1] = args.num_classes
            lab_batch = np.zeros(lab_size, dtype=np.float32)
            lab_batch.fill(args.ignore_label)
            num_batch = np.zeros(lab_size[:-1], dtype=np.float32)

            stride_h, stride_w = strides
            npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
            npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1

            # Create the ending index of each patch.
            patch_indh = np.linspace(input_h,
                                     padimg_h,
                                     npatches_h,
                                     dtype=np.int32)
            patch_indw = np.linspace(input_w,
                                     padimg_w,
                                     npatches_w,
                                     dtype=np.int32)

            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start&end ind of H
                    sw, ew = indw - input_w, indw  # start&end ind of W
                    cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                    feed_dict = {crop_image_batch: cropimg_batch}

                    out = sess.run(output, feed_dict=feed_dict)
                    lab_batch[:, sh:eh, sw:ew, :] += out
                    num_batch[:, sh:eh, sw:ew] += 1

            lab_batch /= num_batch[..., np.newaxis]
            lab_batch = lab_batch[0, :img_h, :img_w, :]
            # Rescale prediction back to original resolution.
            lab_batch = cv2.resize(lab_batch,
                                   (final_lab_size[1], final_lab_size[0]),
                                   interpolation=cv2.INTER_LINEAR)
            if is_flip:
                # Flipped prediction back to original orientation.
                lab_batch = lab_batch[:, ::-1, :]
            final_lab_batch += lab_batch

        final_lab_ind = np.argmax(final_lab_batch, axis=-1)
        final_lab_ind = final_lab_ind.astype(np.uint8)

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(final_lab_ind, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[final_lab_ind]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
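
Example #2 derives the ending index of every crop with np.linspace: the first
patch ends at the input size, the last at the padded size, and overlapping
regions are averaged through num_batch. A standalone sketch of the indexing,
with illustrative sizes:

import math
import numpy as np

input_h, padimg_h, stride_h = 480, 720, 320
npatches_h = int(math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
patch_indh = np.linspace(input_h, padimg_h, npatches_h, dtype=np.int32)
print(patch_indh)  # [480 720] -> patches cover rows [0:480] and [240:720]

Example #3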
def main():
  """Create the model and start the Inference process.
  """
  args = get_arguments()
    
  # Parse image processing arguments.
  input_size = parse_commastr(args.input_size)
  strides = parse_commastr(args.strides)
  assert(input_size is not None and strides is not None)
  h, w = input_size
  innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))


  # Create queue coordinator.
  coord = tf.train.Coordinator()

  # Load the data reader.
  with tf.name_scope('create_inputs'):
    reader = VMFImageReader(
        args.data_dir,
        args.data_list,
        None,
        False, # No random scale.
        False, # No random mirror.
        False, # No random crop, center crop instead
        args.ignore_label,
        IMG_MEAN)

    image = reader.image
    label = reader.label
    image_list = reader.image_list
  image_batch = tf.expand_dims(image, axis=0)
  label_batch = tf.expand_dims(label, axis=0)

  # Create input tensor to the Network
  crop_image_batch = tf.placeholder(
      name='crop_image_batch',
      shape=[1,input_size[0],input_size[1],3],
      dtype=tf.float32)

  # Create network and output prediction.
  outputs = model(crop_image_batch,
                  args.embedding_dim,
                  False,
                  True)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables() if 'crop_image_batch' not in v.name]
    
  # Output predictions.
  output = outputs[0]
  output = tf.image.resize_bilinear(
      output,
      [input_size[0], input_size[1]])

  # Input full-sized embedding
  label_input = tf.placeholder(
      tf.int32, shape=[1, None, None, 1])
  embedding_input = tf.placeholder(
      tf.float32, shape=[1, None, None, args.embedding_dim])
  embedding = common_utils.normalize_embedding(embedding_input)
  loc_feature = tf.placeholder(
      tf.float32, shape=[1, None, None, 2])
  rgb_feature = tf.placeholder(
      tf.float32, shape=[1, None, None, 3])

  # Combine embedding with location features and kmeans
  shape = tf.shape(embedding)
  cluster_labels = common_utils.initialize_cluster_labels(
      [args.num_clusters, args.num_clusters],
      [shape[1], shape[2]])
  embedding = tf.reshape(embedding, [-1, args.embedding_dim])
  labels = tf.reshape(label_input, [-1])
  cluster_labels = tf.reshape(cluster_labels, [-1])
  location_features = tf.reshape(loc_feature, [-1, 2])
  rgb_features = common_utils.normalize_embedding(
      tf.reshape(rgb_feature, [-1, 3])) / args.embedding_dim

  # Collect pixels of valid semantic classes.
  valid_pixels = tf.where(
      tf.not_equal(labels, args.ignore_label))
  labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1)
  cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
  location_features = tf.squeeze(
      tf.gather(location_features, valid_pixels), axis=1)
  rgb_features = tf.squeeze(tf.gather(rgb_features, valid_pixels), axis=1)

  # Generate cluster labels via kmeans clustering.
  embedding_with_location = tf.concat(
      [embedding, location_features, rgb_features], 1)
  embedding_with_location = common_utils.normalize_embedding(
      embedding_with_location)
  cluster_labels = common_utils.kmeans_with_initial_labels(
      embedding_with_location,
      cluster_labels,
      args.num_clusters * args.num_clusters,
      args.kmeans_iterations)
  _, cluster_labels = tf.unique(cluster_labels)

  # Find pixels of majority semantic classes.
  select_pixels, prototype_labels = eval_utils.find_majority_label_index(
      labels, cluster_labels)

  # Calculate the prototype features.
  cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)

  prototype_features = common_utils.calculate_prototypes_from_labels(
      embedding, cluster_labels)


  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
  sess.run(tf.local_variables_initializer())
    
  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    load(loader, sess, args.restore_from)
    
  # Start queue threads.
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  # Create directory for saving prototypes.
  save_dir = os.path.join(args.save_dir, 'prototypes')
  if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
    
  # Iterate over testing steps.
  with open(args.data_list, 'r') as listf:
    num_steps = len(listf.read().split('\n'))-1


  pbar = tqdm(range(num_steps))
  for step in pbar:
    image_batch_np, label_batch_np = sess.run(
        [image_batch, label_batch])

    img_size = image_batch_np.shape
    padded_img_size = list(img_size)  # deep copy of img_size

    if input_size[0] > padded_img_size[1]:
      padded_img_size[1] = input_size[0]
    if input_size[1] > padded_img_size[2]:
      padded_img_size[2] = input_size[1]
    padded_img_batch = np.zeros(padded_img_size,
                                dtype=np.float32)
    img_h, img_w = img_size[1:3]
    padded_img_batch[:, :img_h, :img_w, :] = image_batch_np

    stride_h, stride_w = strides
    npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1
    npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1

    # Create the ending index of each patch.
    patch_indh = np.linspace(
        input_size[0], padded_img_size[1], npatches_h, dtype=np.int32)
    patch_indw = np.linspace(
        input_size[1], padded_img_size[2], npatches_w, dtype=np.int32)
    
    # Create embedding holder.
    padded_img_size[-1] = args.embedding_dim
    embedding_all_np = np.zeros(padded_img_size,
                                dtype=np.float32)
    for indh in patch_indh:
      for indw in patch_indw:
        sh, eh = indh-input_size[0], indh  # start & end ind of H
        sw, ew = indw-input_size[1], indw  # start & end ind of W
        cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :]

        embedding_np = sess.run(output, feed_dict={
            crop_image_batch: cropimg_batch})
        embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np

    embedding_all_np = embedding_all_np[:, :img_h, :img_w, :]
    loc_feature_np = common_utils.generate_location_features_np(
        [padded_img_size[1], padded_img_size[2]])
    feed_dict = {label_input: label_batch_np,
                 embedding_input: embedding_all_np,
                 loc_feature: loc_feature_np,
                 rgb_feature: padded_img_batch}

    (batch_prototype_features_np,
     batch_prototype_labels_np) = sess.run(
      [prototype_features, prototype_labels],
      feed_dict=feed_dict)

    if step == 0:
      prototype_features_np = batch_prototype_features_np
      prototype_labels_np = batch_prototype_labels_np
    else:
      prototype_features_np = np.concatenate(
          [prototype_features_np, batch_prototype_features_np], axis=0)
      prototype_labels_np = np.concatenate(
          [prototype_labels_np,
           batch_prototype_labels_np], axis=0)


  print('Total number of prototypes extracted: ',
        len(prototype_labels_np))
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                     mode='w'), prototype_features_np)
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                     mode='w'), prototype_labels_np)


  coord.request_stop()
  coord.join(threads)
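
calculate_prototypes_from_labels is imported from common_utils and not shown
here; a plausible NumPy equivalent (an assumption, the real helper may differ)
averages all embeddings sharing a cluster label and renormalizes the means:

import numpy as np

def prototypes_from_labels(embedding, labels):
    """embedding: (N, D) float array; labels: (N,) int array in [0, K)."""
    num_protos = labels.max() + 1
    sums = np.zeros((num_protos, embedding.shape[1]), embedding.dtype)
    np.add.at(sums, labels, embedding)  # scatter-add rows per cluster
    counts = np.bincount(labels, minlength=num_protos)[:, None]
    protos = sums / np.maximum(counts, 1)
    # Unit-normalize, mirroring normalize_embedding upstream.
    return protos / np.linalg.norm(protos, axis=1, keepdims=True)

protos = prototypes_from_labels(
    np.random.rand(100, 32).astype(np.float32),
    np.random.randint(0, 10, size=100))
print(protos.shape)  # (10, 32)

Example #4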
def main():
    """Create the model and start the Inference process.
  """
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = VMFImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image = reader.image
        cluster_label = reader.cluster_label
        loc_feature = reader.loc_feature
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(tf.expand_dims(image, axis=0), args.embedding_dim, False,
                    True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [v for v in tf.global_variables()]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image)[:2])
    embedding = common_utils.normalize_embedding(output)
    embedding = tf.squeeze(embedding, axis=0)

    image = image[:height, :width]
    embedding = tf.reshape(embedding[:height, :width],
                           [-1, args.embedding_dim])
    cluster_label = tf.reshape(cluster_label[:height, :width], [-1])
    loc_feature = tf.reshape(loc_feature[:height, :width], [-1, 2])

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans
    embedding_with_location = tf.concat([embedding, loc_feature], 1)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)
    cluster_label = common_utils.kmeans_with_initial_labels(
        embedding_with_location, cluster_label,
        args.num_clusters * args.num_clusters, args.kmeans_iterations)
    _, cluster_labels = tf.unique(cluster_label)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    cluster_labels = tf.reshape(cluster_labels, [height, width])

    # Predict semantic labels.
    similarities = tf.matmul(test_prototypes,
                             prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)

    prototype_semantic_predictions = eval_utils.k_nearest_neighbors(
        k_predictions, prototype_labels)
    semantic_predictions = tf.gather(prototype_semantic_predictions,
                                     cluster_labels)

    # Visualize embedding using PCA
    embedding = vis_utils.pca(
        tf.reshape(embedding, [1, height, width, args.embedding_dim]))
    embedding = ((embedding - tf.reduce_min(embedding)) /
                 (tf.reduce_max(embedding) - tf.reduce_min(embedding)))
    embedding = tf.cast(embedding * 255, tf.uint8)
    embedding = tf.squeeze(embedding, axis=0)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    cluster_dir = os.path.join(args.save_dir, 'cluster')
    embedding_dir = os.path.join(args.save_dir, 'embedding')
    patch_dir = os.path.join(args.save_dir, 'test_patches')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)
    if not os.path.isdir(cluster_dir):
        os.makedirs(cluster_dir)
    if not os.path.isdir(embedding_dir):
        os.makedirs(embedding_dir)
    if not os.path.isdir(patch_dir):
        os.makedirs(patch_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    f = html_helper.open_html_for_write(
        os.path.join(args.save_dir, 'index.html'),
        'Visualization for Segment Collaging')
    for step in range(num_steps):
        image_np, semantic_predictions_np, cluster_labels_np, embedding_np, k_predictions_np = sess.run(
            [
                image, semantic_predictions, cluster_labels, embedding,
                k_predictions
            ],
            feed_dict=feed_dict)

        imgname = os.path.basename(image_list[step])
        basename = imgname.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

        clustername = os.path.join(cluster_dir, basename)
        cluster = colormap[cluster_labels_np]
        Image.fromarray(cluster, mode='RGB').save(clustername)

        embeddingname = os.path.join(embedding_dir, basename)
        Image.fromarray(embedding_np, mode='RGB').save(embeddingname)

        image_np = (image_np + IMG_MEAN).astype(np.uint8)
        for i in range(np.max(cluster_labels_np) + 1):
            image_temp = copy.deepcopy(image_np)
            image_temp[cluster_labels_np != i] = 0
            coords = np.where(cluster_labels_np == i)
            crop = image_temp[np.min(coords[0]):np.max(coords[0]),
                              np.min(coords[1]):np.max(coords[1])]
            scipy.misc.imsave(
                patch_dir + '/' + basename + str(i).zfill(3) + '.png', crop)

        html_helper.write_vmf_to_html(
            f, './images/' + imgname, './labels/' + basename,
            './color/' + basename, './cluster/' + basename,
            './embedding/' + basename, './test_patches/' + basename,
            './patches/', k_predictions_np)

        if (step + 1) % 100 == 0:
            print('Processed batches: ', (step + 1), '/', num_steps)

    html_helper.close_html(f)
    coord.request_stop()
    coord.join(threads)
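
Example #4 transfers semantic labels by matching each test prototype against
the stored training prototypes: a matmul of unit-normalized features yields
cosine similarities, tf.nn.top_k picks the neighbors, and
eval_utils.k_nearest_neighbors (assumed here to act as a majority vote)
resolves the label. A NumPy sketch of that pipeline:

import numpy as np

def knn_semantic_labels(test_protos, train_protos, train_labels, k=3):
    sims = test_protos @ train_protos.T      # cosine similarity (unit norm)
    topk = np.argsort(-sims, axis=1)[:, :k]  # indices of the k nearest
    neighbor_labels = train_labels[topk]     # (num_test, k)
    # Majority vote per test prototype.
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])

rng = np.random.RandomState(0)
train = rng.rand(50, 8); train /= np.linalg.norm(train, axis=1, keepdims=True)
test = rng.rand(5, 8);   test /= np.linalg.norm(test, axis=1, keepdims=True)
print(knn_semantic_labels(test, train, rng.randint(0, 21, size=50)))

Example #5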
def main():
  """Create the model and start the Inference process.
  """
  args = get_arguments()
    
  # Parse image processing arguments.
  input_size = parse_commastr(args.input_size)
  strides = parse_commastr(args.strides)
  assert(input_size is not None and strides is not None)
  h, w = input_size
  innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))


  # Create queue coordinator.
  coord = tf.train.Coordinator()

  # Load the data reader.
  with tf.name_scope('create_inputs'):
    reader = VMFImageReader(
        args.data_dir,
        args.data_list,
        None,
        False, # No random scale.
        False, # No random mirror.
        False, # No random crop, center crop instead
        args.ignore_label,
        IMG_MEAN)

    image_batch = tf.expand_dims(reader.image, axis=0)
    label_batch = tf.expand_dims(reader.label, axis=0)
    cluster_label_batch = tf.expand_dims(reader.cluster_label, axis=0)
    loc_feature_batch = tf.expand_dims(reader.loc_feature, axis=0)

  # Create network and output prediction.
  outputs = model(image_batch,
                  args.embedding_dim,
                  False,
                  True)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables() if 'crop_image_batch' not in v.name]
    
  # Output predictions.
  output = outputs[0]
  output = tf.image.resize_bilinear(
      output,
      tf.shape(image_batch)[1:3])
  embedding = common_utils.normalize_embedding(output)

  shape = embedding.get_shape().as_list()
  batch_size = shape[0]

  labels = label_batch
  initial_cluster_labels = cluster_label_batch[0, :, :]
  location_features = tf.reshape(loc_feature_batch[0, :, :], [-1, 2])

  prototype_feature_list = []
  prototype_label_list = []
  for bs in range(batch_size):
    cur_labels = tf.reshape(labels[bs], [-1])
    cur_cluster_labels = tf.reshape(initial_cluster_labels, [-1])
    cur_embedding = tf.reshape(embedding[bs], [-1, args.embedding_dim])

    (prototype_features,
     prototype_labels,
     _) = eval_utils.extract_trained_prototypes(
         cur_embedding, location_features, cur_cluster_labels,
         args.num_clusters * args.num_clusters,
         args.kmeans_iterations, cur_labels,
         1, args.ignore_label,
         'semantic')

    prototype_feature_list.append(prototype_features)
    prototype_label_list.append(prototype_labels)

  prototype_features = tf.concat(prototype_feature_list, axis=0)
  prototype_labels = tf.concat(prototype_label_list, axis=0)

    
  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
  sess.run(tf.local_variables_initializer())
    
  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    load(loader, sess, args.restore_from)
    
  # Start queue threads.
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  # Create directory for saving prototypes.
  save_dir = os.path.join(args.save_dir, 'prototypes')
  if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
    
  # Iterate over testing steps.
  with open(args.data_list, 'r') as listf:
    num_steps = len(listf.read().split('\n'))-1

  for step in range(num_steps):
    (batch_prototype_features_np,
     batch_prototype_labels_np) = sess.run(
      [prototype_features, prototype_labels])

    if step == 0:
      prototype_features_np = batch_prototype_features_np
      prototype_labels_np = batch_prototype_labels_np
    else:
      prototype_features_np = np.concatenate(
          [prototype_features_np, batch_prototype_features_np], axis=0)
      prototype_labels_np = np.concatenate(
          [prototype_labels_np,
           batch_prototype_labels_np], axis=0)

    if (step + 1) % 100 == 0:
      print('Processed batches: ', (step + 1), '/', num_steps)

  print('Total number of prototypes extracted: ',
        len(prototype_labels_np))
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
                     mode='w'), prototype_features_np)
  np.save(
      tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
                     mode='w'), prototype_labels_np)


  coord.request_stop()
  coord.join(threads)
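
Both prototype extractors lean on common_utils.normalize_embedding. A
hypothetical NumPy counterpart (an assumption; the actual helper may handle
epsilons or shapes differently) scales each vector to unit L2 norm:

import numpy as np

def normalize_embedding(x, eps=1e-12):
    # Normalize along the channel axis so dot products become cosines.
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

e = np.random.rand(1, 4, 4, 32).astype(np.float32)
print(np.linalg.norm(normalize_embedding(e), axis=-1).round(5))  # all ~1.0

Example #6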
def main():
  """Create the model and start training.
  """
  # Read CL arguments and snapshot the arguments into text file.
  args = get_arguments()
  utils.general.snapshot_arg(args)
    
  # The segmentation network is stride 8 by default.
  h, w = map(int, args.input_size.split(','))
  input_size = (h, w)
  innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))
    
  # Initialize the random seed.
  tf.set_random_seed(args.random_seed)
    
  # Create queue coordinator.
  coord = tf.train.Coordinator()

  # current step
  step_ph = tf.placeholder(dtype=tf.float32, shape=())

  # Load the data reader.
  with tf.device('/cpu:0'):
    with tf.name_scope('create_inputs'):
      reader = SegSortUnsupImageReader(
          args.data_dir,
          args.data_list,
          input_size,
          args.random_scale,
          args.random_mirror,
          args.random_crop,
          args.ignore_label,
          IMG_MEAN)

      image_batch, _, cluster_label_batch = (
          reader.dequeue(args.batch_size))

  # Shrink labels to the size of the network output.
  cluster_labels = tf.image.resize_nearest_neighbor(
      cluster_label_batch, innet_size)
  
  # images_mgpu = custom_split(image_batch, args.num_gpu)

  # Create network and predictions.
  with tf.device('/gpu:1'):
    outputs = model(image_batch,
                    args.embedding_dim,
                    args.is_training,
                    args.use_global_status)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables()
      if 'block5' not in v.name or not args.not_restore_classifier
  ]

  # Collect embedding from each gpu.
  with tf.device('/gpu:{:d}'.format(args.num_gpu-1)):
    # embedding_list = [output[0] for output in outputs]
    # embedding = tf.concat(embedding_list, axis=0)

    # Add Unsupervised SegSort loss.
    seg_losses = train_utils.add_unsupervised_segsort_loss(
        outputs[0], args.concentration, cluster_labels)

    # Define weight regularization loss.
    w = args.weight_decay
    l2_losses = [w*tf.nn.l2_loss(v) for v in tf.trainable_variables()
                  if 'weights' in v.name]

    # Sum all loss terms.
    mean_seg_loss = seg_losses
    mean_l2_loss = tf.add_n(l2_losses)
    reduced_loss = mean_seg_loss + mean_l2_loss

  # Grab variable names which are used for training.
  all_trainable = tf.trainable_variables()
  fc_trainable = [v for v in all_trainable if 'block5' in v.name] # lr*10
  base_trainable = [v for v in all_trainable if 'block5' not in v.name] # lr*1

  # Computes gradients per iteration.
  grads = tf.gradients(reduced_loss, base_trainable+fc_trainable)
  grads_base = grads[:len(base_trainable)]
  grads_fc = grads[len(base_trainable):]

  # Define optimisation parameters.
  base_lr = tf.constant(args.learning_rate)
  pow_till = 100000  # Decay horizon for the schedule; overrides args.num_steps.
  learning_rate = tf.scalar_mul(
    base_lr,
    tf.pow((1-step_ph/pow_till), args.power))

  opt_base = tf.train.MomentumOptimizer(learning_rate*1.0, args.momentum)
  opt_fc = tf.train.MomentumOptimizer(learning_rate*10.0, args.momentum)

  # Define tensorflow operations which apply gradients to update variables.
  train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable))
  train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
  train_op = tf.group(train_op_base, train_op_fc)

  # Process for visualisation.
  with tf.device('/cpu:0'):
    # Image summary for input image, ground-truth label and prediction.
    output_vis = tf.image.resize_nearest_neighbor(
        outputs[-1], tf.shape(image_batch)[1:3])
    output_vis = tf.argmax(output_vis, axis=3)
    output_vis = tf.expand_dims(output_vis, axis=3)
    output_vis = tf.cast(output_vis, dtype=tf.uint8)
    
    labels_vis = tf.cast(cluster_label_batch, dtype=tf.uint8)
 
    in_summary = tf.py_func(
        utils.general.inv_preprocess,
        [image_batch, IMG_MEAN],
        tf.uint8)
    gt_summary = tf.py_func(
        utils.general.decode_labels,
        [labels_vis, args.num_classes],
        tf.uint8)
    out_summary = tf.py_func(
        utils.general.decode_labels,
        [output_vis, args.num_classes],
        tf.uint8)
    # Concatenate image summaries in a row.
    total_summary = tf.summary.image(
        'images', 
        tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]), 
        max_outputs=args.batch_size)

    # Scalar summary for different loss terms.
    seg_loss_summary = tf.summary.scalar(
        'seg_loss', mean_seg_loss)
    total_summary = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(
        args.snapshot_dir,
        graph=tf.get_default_graph())
    
  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
    
  # Saver for storing checkpoints of the model.
  saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10)
    
  # Load variables if the checkpoint is provided.
  if args.restore_from is not None and len(args.restore_from) > 0:
    loader = tf.train.Saver(var_list=restore_var)
    load(loader, sess, args.restore_from)
    
  # Start queue threads.
  threads = tf.train.start_queue_runners(
      coord=coord, sess=sess)

  # Iterate over training steps.
  pbar = tqdm(range(args.num_steps))
  for step in pbar:
    start_time = time.time()
    feed_dict = {step_ph: step}

    step_loss = 0
    for it in range(args.iter_size):
      # Update summary periodically.
      if it == args.iter_size-1 and step % args.update_tb_every == 0:
        sess_outs = [reduced_loss, total_summary, train_op]
        loss_value, summary, _ = sess.run(sess_outs,
                                          feed_dict=feed_dict)
        summary_writer.add_summary(summary, step)
      else:
        sess_outs = [reduced_loss, train_op]
        loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict)

      step_loss += loss_value

    step_loss /= args.iter_size

    lr = sess.run(learning_rate, feed_dict=feed_dict)

    # Save trained model periodically.
    if step % args.save_pred_every == 0 and step > 0:
      save(saver, sess, args.snapshot_dir, step)

    duration = time.time() - start_time
    desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
    pbar.set_description(desc)

  coord.request_stop()
  coord.join(threads)
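
This trainer (like Example #1) updates the block5 head at 10x the backbone
learning rate by splitting the trainables and sharing one gradient pass
between two MomentumOptimizers. A compact TF1 sketch of the pattern; the
function and argument names are illustrative:

import tensorflow as tf

def two_rate_train_op(loss, learning_rate, momentum=0.9, head_key='block5'):
    trainables = tf.trainable_variables()
    head = [v for v in trainables if head_key in v.name]
    base = [v for v in trainables if head_key not in v.name]
    grads = tf.gradients(loss, base + head)  # one backward pass for both
    opt_base = tf.train.MomentumOptimizer(learning_rate, momentum)
    opt_head = tf.train.MomentumOptimizer(learning_rate * 10.0, momentum)
    return tf.group(
        opt_base.apply_gradients(zip(grads[:len(base)], base)),
        opt_head.apply_gradients(zip(grads[len(base):], head)))

Example #7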
def main():
  """Create the model and start the Inference process.
  """
  args = get_arguments()

  # TODO: 5. Post-process and save.
  # Parse image processing arguments.
  print('get model!')
  input_size = parse_commastr(args.input_size)
  strides = parse_commastr(args.strides)
  assert(input_size is not None and strides is not None)
  h, w = input_size
  #innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))

  # Create input tensor to the Network.
  crop_image_batch = tf.placeholder(
      name='crop_image_batch',
      shape=[8,input_size[0],input_size[1],3],
      dtype=tf.float32)

  # Create network and output prediction.
  outputs = model(crop_image_batch,
                  args.num_classes,
                  False,
                  True)

  # Grab variable names which should be restored from checkpoints.
  restore_var = [
    v for v in tf.global_variables() if 'crop_image_batch' not in v.name]
    
  # Output predictions.
  output = outputs[-1]
  output = tf.image.resize_bilinear(
      output,
      tf.shape(crop_image_batch)[1:3])
  output = tf.nn.softmax(output, axis=3)
    
  # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
    
  sess.run(init)
  sess.run(tf.local_variables_initializer())
    
  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    loadckpt(loader, sess, args.restore_from)
    
  # Start queue threads.
  #threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  for id in range(70):
      print('-' * 30)
      print('preprocessing test data...' + str(id))
      print('-' * 30)

      #1. load a medpy file from mem
      imgs_test, img_test_header = load(args.data + str(id) + '.nii')
      mm = np.zeros((input_size[0], input_size[1], imgs_test.shape[2]))
      mm[:imgs_test.shape[0], :imgs_test.shape[1], :imgs_test.shape[2]] = imgs_test
      imgs_test = mm


      #  load liver mask
      mask, mask_header = load(args.liver_path + str(id) + '-ori.nii')
      mask[mask == 2] = 1
      mask = ndimage.binary_dilation(mask, iterations=1).astype(mask.dtype)
      print('-' * 30)
      print('Predicting masks on test data...' + str(id))
      print('-' * 30)
      index = np.where(mask == 1)
      mini = np.min(index, axis=-1)
      maxi = np.max(index, axis=-1)

      batch = 1
      img_deps = input_size[0]
      img_rows = input_size[1]
      img_cols = 8

      window_cols = img_cols // 4  # Integer stride along the slice axis.
      count = 0
      box_test = np.zeros((batch, img_deps, img_rows, img_cols, 1), dtype="float32")
      x = imgs_test.shape[0]
      y = imgs_test.shape[1]
      z = imgs_test.shape[2]
      right_cols = int(min(z, maxi[2] + 10) - img_cols)
      left_cols = max(0, min(mini[2] - 5, right_cols))
      score = np.zeros((x, y, z, 3), dtype='float32')
      score_num = np.zeros((x, y, z, 3), dtype='int16')
      for cols in range(left_cols, right_cols + window_cols, window_cols):
          # print ('and', z-img_cols,z)
          if cols > z - img_cols:
              patch_test = imgs_test[0:img_deps, 0:img_rows, z - img_cols:z]
              box_test[count, :, :, :, 0] = patch_test
              incol = box_test.shape[3]
              box_testt = tans2d(box_test, incol)
              box_testt = (box_testt + 250) * 255 / 500
              box_testt -= np.array((122.675, 122.669, 122.008), dtype=np.float32)
              # print ('final', img_cols-window_cols, img_cols)
              feed_dict = {crop_image_batch: box_testt}
              patch_test_mask = sess.run(output, feed_dict=feed_dict)

              patch_test_mask = trans3d(patch_test_mask, incol)
              patch_test_mask = patch_test_mask[:, :, :, 1:-1, :]

              for i in range(batch):
                  score[0:img_deps, 0:img_rows, z - img_cols + 1:z - 1, :] += patch_test_mask[i]
                  score_num[0:img_deps, 0:img_rows, z - img_cols + 1:z - 1, :] += 1
          else:
              patch_test = imgs_test[0:img_deps, 0:img_rows, cols:cols + img_cols]
              # print(patch_test.shape)
              box_test[count, :, :, :, 0] = patch_test
              incol = box_test.shape[3]
              box_testt = tans2d(box_test, incol)
              box_testt = (box_testt + 250) * 255 / 500
              box_testt -= np.array((122.675, 122.669, 122.008), dtype=np.float32)
              feed_dict = {crop_image_batch: box_testt}
              patch_test_mask = sess.run(output, feed_dict=feed_dict)
              patch_test_mask = trans3d(patch_test_mask, incol)

              patch_test_mask = patch_test_mask[:, :, :, 1:-1, :]

              for i in range(batch):
                  score[0:img_deps, 0:img_rows, cols + 1:cols + img_cols - 1, :] += patch_test_mask[i]
                  score_num[0:img_deps, 0:img_rows, cols + 1:cols + img_cols - 1, :] += 1
      score = score / (score_num + 1e-4)
      result1 = score[:512, :512, :, 1]
      result2 = score[:512, :512, :, 2]

      result1[result1 >= args.thres_liver] = 1
      result1[result1 < args.thres_liver] = 0
      result2[result2 >= args.thres_tumor] = 1
      result2[result2 < args.thres_tumor] = 0
      result1[result2 == 1] = 1

      print('-' * 30)
      print('Postprocessing on mask ...' + str(id))
      print('-' * 30)

      #  Preserve the largest connected component of the predicted liver.
      Segmask = result2
      box = []
      [liver_res, num] = measure.label(result1, return_num=True)
      region = measure.regionprops(liver_res)
      for i in range(num):
          box.append(region[i].area)
      label_num = box.index(max(box)) + 1
      liver_res[liver_res != label_num] = 0
      liver_res[liver_res == label_num] = 1

      #  Keep the largest connected component of the dilated liver mask.
      mask = ndimage.binary_dilation(mask, iterations=1).astype(mask.dtype)
      box = []
      [liver_labels, num] = measure.label(mask, return_num=True)
      region = measure.regionprops(liver_labels)
      for i in range(num):
          box.append(region[i].area)
      label_num = box.index(max(box)) + 1
      liver_labels[liver_labels != label_num] = 0
      liver_labels[liver_labels == label_num] = 1
      liver_labels = ndimage.binary_fill_holes(liver_labels).astype(int)

      #  Preserve tumor predictions within the largest liver only.
      Segmask = Segmask * liver_labels
      Segmask = ndimage.binary_fill_holes(Segmask).astype(int)
      Segmask = np.array(Segmask, dtype='uint8')
      liver_res = np.array(liver_res, dtype='uint8')
      liver_res = ndimage.binary_fill_holes(liver_res).astype(int)
      liver_res[Segmask == 1] = 2
      liver_res = np.array(liver_res, dtype='uint8')
      save(liver_res, args.save_path + 'test-segmentation-' + str(id) + '.nii', img_test_header)
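
The post-processing above keeps only the largest connected component of the
predicted liver before filling holes. A self-contained sketch of that step on
a toy mask, using the same skimage.measure calls:

import numpy as np
from skimage import measure

def keep_largest_component(binary_mask):
    labeled, num = measure.label(binary_mask, return_num=True)
    if num == 0:
        return binary_mask
    areas = [r.area for r in measure.regionprops(labeled)]
    largest = int(np.argmax(areas)) + 1  # region labels start at 1
    return (labeled == largest).astype(binary_mask.dtype)

mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:3, 1:3] = 1  # small blob, area 4
mask[4:8, 4:8] = 1  # large blob, area 16
print(keep_largest_component(mask).sum())  # 16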
Example #8
def main():
    """Create the model and start the Inference process.
	"""
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    h, w = input_size
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = ImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale.
            False,  # No random mirror.
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image = reader.image
        image_list = reader.image_list

    image_batch = tf.expand_dims(image, axis=0)

    # Create input tensor to the Network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network and output prediction.
    outputs = model(crop_image_batch, args.num_classes, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[-1]
    output = tf.image.resize_bilinear(output,
                                      tf.shape(crop_image_batch)[1:3])

    output = tf.nn.softmax(output, axis=3)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Iterate over testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    for step in range(num_steps):
        img_batch = sess.run(image_batch)
        img_size = img_batch.shape
        padimg_size = list(img_size)  # copy of img_size

        padimg_h, padimg_w = padimg_size[1:3]
        input_h, input_w = input_size

        if input_h > padimg_h:
            padimg_h = input_h
        if input_w > padimg_w:
            padimg_w = input_w

        # Update padded image size.
        padimg_size[1] = padimg_h
        padimg_size[2] = padimg_w
        padimg_batch = np.zeros(padimg_size, dtype=np.float32)
        img_h, img_w = img_size[1:3]
        padimg_batch[:, :img_h, :img_w, :] = img_batch

        # Create padded label array.
        lab_size = list(padimg_size)
        lab_size[-1] = args.num_classes
        lab_batch = np.zeros(lab_size, dtype=np.float32)
        lab_batch.fill(args.ignore_label)

        stride_h, stride_w = strides
        npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
        npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1

        # Create the ending index of each patch.
        patch_indh = np.linspace(input_h, padimg_h, npatches_h, dtype=np.int32)
        patch_indw = np.linspace(input_w, padimg_w, npatches_w, dtype=np.int32)

        for indh in patch_indh:
            for indw in patch_indw:
                sh, eh = indh - input_h, indh  # start&end ind of H
                sw, ew = indw - input_w, indw  # start&end ind of W
                cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                feed_dict = {crop_image_batch: cropimg_batch}

                out = sess.run(output, feed_dict=feed_dict)
                lab_batch[:, sh:eh, sw:ew, :] += out

        lab_batch = lab_batch[0, :img_h, :img_w, :]
        lab_batch = np.argmax(lab_batch, axis=-1)
        lab_batch = lab_batch.astype(np.uint8)

        basename = os.path.basename(image_list[step])
        basename = basename.replace('jpg', 'png')

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(lab_batch, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[lab_batch]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
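
Saving the color image above is one fancy-indexing lookup: a
(num_classes, 3) uint8 colormap indexed by the (H, W) label map yields an
(H, W, 3) RGB array. A toy sketch (the colormap values are assumptions; the
real table is loaded from the .mat file):

import numpy as np
from PIL import Image

colormap = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0]], dtype=np.uint8)
labels = np.random.randint(0, 3, size=(4, 4)).astype(np.uint8)
color = colormap[labels]            # (4, 4, 3) uint8
Image.fromarray(color, mode='RGB')  # same call used to save predictions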
Example #9
def main():
    """Creates the model and start the inference process."""
    args = get_arguments()

    # Parse image processing arguments.
    input_size = parse_commastr(args.input_size)
    strides = parse_commastr(args.strides)
    assert (input_size is not None and strides is not None)
    h, w = input_size
    innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8)))

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            None,
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)
        image = reader.image[:reader.height, :reader.width]
        image_list = reader.image_list
    image_batch = tf.expand_dims(image, axis=0)

    # Create multi-scale augmented data.
    rescale_image_batches = []
    is_flipped = []
    scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1]
    for scale in scales:
        h_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale))
        w_new = tf.to_int32(
            tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale))
        new_shape = tf.stack([h_new, w_new])
        new_image_batch = tf.image.resize_images(image_batch, new_shape)
        rescale_image_batches.append(new_image_batch)
        is_flipped.append(False)

    # Create horizontally flipped augmented data.
    if args.flip_aug:
        for i in range(len(scales)):
            img = rescale_image_batches[i]
            is_flip = is_flipped[i]
            img = tf.squeeze(img, axis=0)
            flip_img = tf.image.flip_left_right(img)
            flip_img = tf.expand_dims(flip_img, axis=0)
            rescale_image_batches.append(flip_img)
            is_flipped.append(True)

    # Create input tensor to the Network.
    crop_image_batch = tf.placeholder(
        name='crop_image_batch',
        shape=[1, input_size[0], input_size[1], 3],
        dtype=tf.float32)

    # Create network.
    outputs = model(crop_image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Output predictions.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features and kmeans
    shape = embedding.get_shape().as_list()
    loc_feature = tf.expand_dims(
        common_utils.generate_location_features([shape[1], shape[2]], 'float'),
        0)
    embedding_with_location = tf.concat([embedding, loc_feature], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Perform Kmeans clustering and extract prototypes.
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)

    # Predict semantic labels.
    similarities = tf.matmul(test_prototypes,
                             prototype_features,
                             transpose_b=True)
    _, k_predictions = tf.nn.top_k(similarities,
                                   k=args.k_in_nearest_neighbors,
                                   sorted=True)
    k_predictions = tf.gather(prototype_labels, k_predictions)
    k_predictions = tf.gather(k_predictions, cluster_labels)
    k_predictions = tf.reshape(
        k_predictions,
        [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors])

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)
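    # The .mat colormap is assumed to store float RGB values in [0, 1],
    # hence the rescaling to 8-bit.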

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Count the number of testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1
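    # The count assumes the list file ends with a trailing newline, so the
    # final empty entry from split('\n') is dropped by the -1.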

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))
    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }
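    # The prototype entries of feed_dict stay fixed for the whole run; the
    # per-patch crop is added to this same dictionary inside the loop below.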

    pbar = tqdm(range(num_steps))
    for step in pbar:
        rescale_img_batches = sess.run(rescale_image_batches)
        # Accumulate class scores across all scales and flips; the final
        # label is the argmax of the summed scores.
        # Index 2 corresponds to scale 1.0 in the scales list above.
        scale_ind = 2 if args.scale_aug else 0
        final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
        final_lab_size[-1] = args.num_classes
        final_lab_batch = np.zeros(final_lab_size)

        # Iterate over multiple scales.
        for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
            img_size = img_batch.shape
            padimg_size = list(img_size)  # copy, so img_size stays intact

            padimg_h, padimg_w = padimg_size[1:3]
            input_h, input_w = input_size

            # Pad the image so it is at least as large as the crop size.
            padimg_h = max(padimg_h, input_h)
            padimg_w = max(padimg_w, input_w)
            padimg_size[1] = padimg_h
            padimg_size[2] = padimg_w
            padimg_batch = np.zeros(padimg_size, dtype=np.float32)
            img_h, img_w = img_size[1:3]
            padimg_batch[:, :img_h, :img_w, :] = img_batch

            stride_h, stride_w = strides
            npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
            npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1
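            # Number of sliding-window patches per axis: one stride per
            # ceil((padded - crop) / stride), plus one so the last patch ends
            # exactly at the padded border (see the linspace indices below).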

            # Create padded prediction array.
            pred_size = list(padimg_size)
            pred_size[-1] = args.num_classes
            predictions_np = np.zeros(pred_size, dtype=np.int32)

            # Create the ending index of each patch.
            patch_indh = np.linspace(input_h,
                                     padimg_h,
                                     npatches_h,
                                     dtype=np.int32)
            patch_indw = np.linspace(input_w,
                                     padimg_w,
                                     npatches_w,
                                     dtype=np.int32)

            for indh in patch_indh:
                for indw in patch_indw:
                    sh, eh = indh - input_h, indh  # start & end ind of H
                    sw, ew = indw - input_w, indw  # start & end ind of W
                    cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
                    feed_dict[crop_image_batch] = cropimg_batch

                    k_predictions_np = sess.run(k_predictions,
                                                feed_dict=feed_dict)
                    # Sum up KNN votes: each of the k neighbors casts one
                    # vote per pixel and class.
                    # This is the speed bottleneck for multi-scale inference;
                    # use single-scale inference for fast results.
                    # TODO: Either compute on GPU or change the implementation
                    # (see the vectorized sketch below this function).
                    for c in range(args.num_classes):
                        predictions_np[:, sh:eh, sw:ew, c] += np.sum(
                            (k_predictions_np == c).astype(np.int32), axis=3)

            predictions_np = predictions_np[0, :img_h, :img_w, :]
            lab_batch = predictions_np.astype(np.float32)
            # Rescale prediction back to original resolution.
            lab_batch = cv2.resize(lab_batch,
                                   (final_lab_size[1], final_lab_size[0]),
                                   interpolation=cv2.INTER_LINEAR)
            if is_flip:
                # Flip the prediction back to the original orientation.
                lab_batch = lab_batch[:, ::-1, :]
            final_lab_batch += lab_batch

        final_lab_ind = np.argmax(final_lab_batch, axis=-1)
        final_lab_ind = final_lab_ind.astype(np.uint8)

        basename = os.path.basename(image_list[step])
        basename = os.path.splitext(basename)[0] + '.png'

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(final_lab_ind, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[final_lab_ind]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)
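

# The per-class voting loop above is the noted speed bottleneck. A minimal
# vectorized NumPy sketch (a hypothetical helper, not part of the original
# codebase), assuming k_predictions_np holds integer class ids of shape
# (N, H, W, K):
def _accumulate_knn_votes(k_predictions_np, num_classes):
    """Count, per pixel, how many of the k neighbors vote for each class.

    Returns an int32 array of shape (N, H, W, num_classes).
    """
    one_hot = np.eye(num_classes, dtype=np.int32)[k_predictions_np]
    # Sum the one-hot votes over the k-neighbors axis.
    return one_hot.sum(axis=3)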


def main():
    """Create the model and start the inference process."""
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    # Load the data reader.
    with tf.name_scope('create_inputs'):
        reader = SegSortImageReader(
            args.data_dir,
            args.data_list,
            parse_commastr(args.input_size),
            False,  # No random scale
            False,  # No random mirror
            False,  # No random crop, center crop instead
            args.ignore_label,
            IMG_MEAN)

        image_list = reader.image_list
        image_batch = tf.expand_dims(reader.image, dim=0)
        label_batch = tf.expand_dims(reader.label, dim=0)
        cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0)
        loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0)
        height = reader.height
        width = reader.width

    # Create network and output prediction.
    outputs = model(image_batch, args.embedding_dim, False, True)

    # Grab variable names which should be restored from checkpoints.
    restore_var = [
        v for v in tf.global_variables() if 'crop_image_batch' not in v.name
    ]

    # Resize the network output and normalize the embedding.
    output = outputs[0]
    output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3])
    embedding = common_utils.normalize_embedding(output)

    # Prototype placeholders.
    prototype_features = tf.placeholder(tf.float32,
                                        shape=[None, args.embedding_dim])
    prototype_labels = tf.placeholder(tf.int32)

    # Combine embedding with location features.
    embedding_with_location = tf.concat([embedding, loc_feature_batch], 3)
    embedding_with_location = common_utils.normalize_embedding(
        embedding_with_location)

    # Perform k-means clustering.
    cluster_labels = common_utils.kmeans(
        embedding_with_location, [args.num_clusters, args.num_clusters],
        args.kmeans_iterations)
    test_prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)
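    # As in the multi-scale script above: cluster the location-augmented
    # embedding with k-means, form test prototypes, then transfer labels
    # from the stored training prototypes via k-nearest neighbors.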

    # Predict semantic labels.
    semantic_predictions, _ = eval_utils.predict_semantic_instance_labels(
        cluster_labels, test_prototypes, prototype_features, prototype_labels,
        None, args.k_in_nearest_neighbors)
    semantic_predictions = tf.cast(semantic_predictions, tf.uint8)
    semantic_predictions = tf.squeeze(semantic_predictions)

    # Set up tf session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)
    sess.run(tf.local_variables_initializer())

    # Load weights.
    loader = tf.train.Saver(var_list=restore_var)
    if args.restore_from is not None:
        load(loader, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Get colormap.
    map_data = scipy.io.loadmat(args.colormap)
    key = os.path.basename(args.colormap).replace('.mat', '')
    colormap = map_data[key]
    colormap *= 255
    colormap = colormap.astype(np.uint8)

    # Create directory for saving predictions.
    pred_dir = os.path.join(args.save_dir, 'gray')
    color_dir = os.path.join(args.save_dir, 'color')
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)
    if not os.path.isdir(color_dir):
        os.makedirs(color_dir)

    # Count the number of testing steps.
    with open(args.data_list, 'r') as listf:
        num_steps = len(listf.read().split('\n')) - 1

    # Load prototype features and labels.
    prototype_features_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_features.npy'))
    prototype_labels_np = np.load(
        os.path.join(args.prototype_dir, 'prototype_labels.npy'))

    feed_dict = {
        prototype_features: prototype_features_np,
        prototype_labels: prototype_labels_np
    }

    for step in tqdm(range(num_steps)):
        semantic_predictions_np, height_np, width_np = sess.run(
            [semantic_predictions, height, width], feed_dict=feed_dict)

        semantic_predictions_np = semantic_predictions_np[:height_np, :width_np]

        basename = os.path.basename(image_list[step])
        basename = os.path.splitext(basename)[0] + '.png'

        predname = os.path.join(pred_dir, basename)
        Image.fromarray(semantic_predictions_np, mode='L').save(predname)

        colorname = os.path.join(color_dir, basename)
        color = colormap[semantic_predictions_np]
        Image.fromarray(color, mode='RGB').save(colorname)

    coord.request_stop()
    coord.join(threads)