Beispiel #1
0
  def testShapeIsCorrectAfterOp(self):
    """Checks that rgb_to_hsv keeps the input shape, statically and when run."""
    input_shape = [2, 20, 30, 3]
    expected_shape = [2, 20, 30, 3]

    for dtype in self.TYPES:
      # Random pixel values cast to the dtype under test.
      pixels = np.random.randint(0, high=255, size=[2, 20, 30, 3])
      rgb = constant_op.constant(pixels.astype(dtype), shape=input_shape)
      hsv = gen_image_ops.rgb_to_hsv(rgb)

      # Static (graph-time) shape of the op's output.
      with self.cached_session():
        static_shape = list(hsv.get_shape())
        self.assertEqual(expected_shape, static_shape)

      # Shape of the value actually produced by running the op.
      hsv_value = self.evaluate(hsv)
      self.assertEqual(expected_shape, list(hsv_value.shape))
Beispiel #2
0
def adjust_hue(image, delta, name=None):
  """Rotate the hue of an RGB image by `delta`.

  The input is converted to float32, its hue is shifted in HSV space, and the
  result is converted back to the dtype of `image`. If several adjustments are
  chained it is advisable to minimize the number of redundant conversions.

  Two implementations are available. By default the image is converted to HSV
  in Python, `delta` is added to the hue channel (wrapping around at 1), and
  the result is converted back to RGB. Setting the environment variable
  `TF_ADJUST_HUE_FUSED` to `true`, `t` or `1` (case-insensitive) selects the
  fused `adjust_hue` op instead.

  `delta` must be in the interval `[-1, 1]`.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    delta: float.  How much to add to the hue channel.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.name_scope(name, 'adjust_hue', [image]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # Keep the caller's dtype so the result can be converted back to it.
    input_dtype = image.dtype
    float_image = convert_image_dtype(image, dtypes.float32)

    # TODO(zhengxq): we will switch to the fused version after we add a GPU
    # kernel for that.
    use_fused = os.environ.get('TF_ADJUST_HUE_FUSED', '').lower() in (
        'true', 't', '1')

    if use_fused:
      shifted_rgb = gen_image_ops.adjust_hue(float_image, delta)
    else:
      hsv = gen_image_ops.rgb_to_hsv(float_image)

      # One single-channel slice per HSV component. The begin/size vectors
      # have three entries, i.e. this path indexes `hsv` as a rank-3 tensor.
      h, s, v = [
          array_ops.slice(hsv, [0, 0, channel], [-1, -1, 1])
          for channel in range(3)
      ]

      # Adding 1 before the modulo keeps its operand non-negative for any
      # delta >= -1 (given a non-negative hue), so the hue wraps around
      # instead of going negative.
      h = math_ops.mod(h + (delta + 1.), 1.)

      shifted_rgb = gen_image_ops.hsv_to_rgb(
          array_ops.concat_v2([h, s, v], 2))

    return convert_image_dtype(shifted_rgb, input_dtype)
Beispiel #3
0
def adjust_saturation(image, saturation_factor, name=None):
  """Scale the saturation of an RGB image.

  The input is converted to float32, its saturation is scaled in HSV space, and
  the result is converted back to the dtype of `image`. If several adjustments
  are chained it is advisable to minimize the number of redundant conversions.

  By default the image is converted to HSV in Python, the saturation (S)
  channel is multiplied by `saturation_factor` and clipped to `[0, 1]`, and the
  result is converted back to RGB. Setting the environment variable
  `TF_ADJUST_SATURATION_FUSED` to `true`, `t` or `1` (case-insensitive)
  selects the fused `adjust_saturation` op instead.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    saturation_factor: float. Factor to multiply the saturation by.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.name_scope(name, 'adjust_saturation', [image]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # Keep the caller's dtype so the result can be converted back to it.
    source_dtype = image.dtype
    float_image = convert_image_dtype(image, dtypes.float32)

    # TODO(zhengxq): we will switch to the fused version after we add a GPU
    # kernel for that.
    env_flag = os.environ.get('TF_ADJUST_SATURATION_FUSED', '')

    if env_flag.lower() in ('true', 't', '1'):
      scaled_rgb = gen_image_ops.adjust_saturation(float_image,
                                                   saturation_factor)
    else:
      hsv = gen_image_ops.rgb_to_hsv(float_image)

      # Single-channel slices; the three-entry begin/size vectors index `hsv`
      # as a rank-3 tensor.
      h = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
      s = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
      v = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])

      # Scale, then clamp so the saturation stays within [0, 1].
      s = clip_ops.clip_by_value(s * saturation_factor, 0.0, 1.0)

      scaled_rgb = gen_image_ops.hsv_to_rgb(array_ops.concat([h, s, v], 2))

    return convert_image_dtype(scaled_rgb, source_dtype)
Beispiel #4
0
def adjust_saturation(image, saturation_factor, name=None):
  """Multiply the saturation of an RGB image by `saturation_factor`.

  This is a convenience method: the image is converted to float32, adjusted,
  and converted back to its original data type. If several adjustments are
  chained it is advisable to minimize the number of redundant conversions.

  When the environment variable `TF_ADJUST_SATURATION_FUSED` is `true`, `t`
  or `1` (case-insensitive), the fused `adjust_saturation` op does the work.
  Otherwise the image is converted to HSV, the saturation (S) channel is
  multiplied by `saturation_factor` and clipped to `[0, 1]`, and the image is
  converted back to RGB.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    saturation_factor: float. Factor to multiply the saturation by.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.name_scope(name, 'adjust_saturation', [image]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # The output is converted back to this dtype at the end.
    dtype_in = image.dtype
    image_f32 = convert_image_dtype(image, dtypes.float32)

    # TODO(zhengxq): we will switch to the fused version after we add a GPU
    # kernel for that.
    fused_requested = os.environ.get(
        'TF_ADJUST_SATURATION_FUSED', '').lower() in ('true', 't', '1')

    if fused_requested:
      fused_result = gen_image_ops.adjust_saturation(image_f32,
                                                     saturation_factor)
      return convert_image_dtype(fused_result, dtype_in)

    hsv = gen_image_ops.rgb_to_hsv(image_f32)

    # Split into H, S, V with rank-3 slice indices (one channel each).
    planes = [
        array_ops.slice(hsv, [0, 0, idx], [-1, -1, 1]) for idx in range(3)
    ]

    # Only the saturation plane changes: scale it and clamp to [0, 1].
    planes[1] = clip_ops.clip_by_value(planes[1] * saturation_factor, 0.0, 1.0)

    rgb_out = gen_image_ops.hsv_to_rgb(array_ops.concat(planes, 2))
    return convert_image_dtype(rgb_out, dtype_in)
Beispiel #5
0
def adjust_saturation(image, saturation_factor, name=None):
    """Adjust the saturation of RGB images using rank-4 slicing.

    The input is converted to float32 and then to HSV; the saturation channel
    is multiplied by `saturation_factor` and clipped to `[0, 1]`; the result
    is converted back to RGB and then to the original dtype.

    The channel slices use four-entry begin/size vectors, so `image` must be
    a rank-4 tensor with the colour channels last (presumably
    `[batch, height, width, 3]`).

    Args:
      image: RGB images (tensor or array-like). Size of the last dimension
        must be 3.
      saturation_factor: float. Factor to multiply the saturation by.
      name: A name for this operation (optional).

    Returns:
      Adjusted images, same shape and dtype as `image`.
    """
    with ops.op_scope([image], name, 'adjust_saturation') as name:
        # Convert first so that lists / NumPy arrays are accepted and
        # `orig_dtype` is always a TensorFlow DType, which the final
        # convert_image_dtype call needs.
        image = tf.convert_to_tensor(image, name='image')
        # Remember original dtype so we can convert back if needed
        orig_dtype = image.dtype
        flt_image = tf.image.convert_image_dtype(image, tf.float32)

        hsv = gen_image_ops.rgb_to_hsv(flt_image)

        # One single-channel slice per HSV component.
        hue = tf.slice(hsv, [0, 0, 0, 0], [-1, -1, -1, 1])
        saturation = tf.slice(hsv, [0, 0, 0, 1], [-1, -1, -1, 1])
        value = tf.slice(hsv, [0, 0, 0, 2], [-1, -1, -1, 1])

        # Scale the saturation and clamp it back into [0, 1].
        saturation *= saturation_factor
        saturation = clip_ops.clip_by_value(saturation, 0.0, 1.0)

        # Note the (axis, values) argument order of this tf.concat signature.
        hsv_altered = tf.concat(3, [hue, saturation, value])
        rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

        return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #6
0
def adjust_saturation(image, saturation_factor, name=None):
    """Scale the saturation of a rank-4 batch of RGB images.

    Steps: convert to float32, convert RGB to HSV, multiply the S channel by
    `saturation_factor`, clip it to `[0, 1]`, convert HSV back to RGB, and
    convert back to the input dtype.

    Because the slices below use four-entry begin/size vectors, `image` must
    be rank 4 with channels in the last dimension (presumably
    `[batch, height, width, 3]`).

    Args:
      image: RGB images (tensor or array-like). Size of the last dimension
        must be 3.
      saturation_factor: float. Factor to multiply the saturation by.
      name: A name for this operation (optional).

    Returns:
      Adjusted images, same shape and dtype as `image`.
    """
    with ops.op_scope([image], name, 'adjust_saturation') as name:
        # Normalise the input to a Tensor up front: a plain list has no
        # `.dtype`, and a NumPy dtype is not usable as the target of the
        # final convert_image_dtype call.
        image = tf.convert_to_tensor(image, name='image')
        # Remember original dtype so we can convert back if needed
        orig_dtype = image.dtype
        flt_image = tf.image.convert_image_dtype(image, tf.float32)

        hsv = gen_image_ops.rgb_to_hsv(flt_image)

        # Split HSV into its three single-channel planes.
        hue = tf.slice(hsv, [0, 0, 0, 0], [-1, -1, -1, 1])
        saturation = tf.slice(hsv, [0, 0, 0, 1], [-1, -1, -1, 1])
        value = tf.slice(hsv, [0, 0, 0, 2], [-1, -1, -1, 1])

        # Multiply, then clamp so saturation stays in [0, 1].
        saturation *= saturation_factor
        saturation = clip_ops.clip_by_value(saturation, 0.0, 1.0)

        # This tf.concat signature takes the axis first, then the values.
        hsv_altered = tf.concat(3, [hue, saturation, value])
        rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

        return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #7
0
def adjust_hue(image, delta, name=None):
    """Rotate the hue of RGB images using rank-4 slicing.

    The input is converted to float32 and then to HSV; `delta` is added to
    the hue channel with wrap-around at 1; the result is converted back to
    RGB and then to the original dtype.

    The channel slices use four-entry begin/size vectors, so `image` must be
    a rank-4 tensor with the colour channels last (presumably
    `[batch, height, width, 3]`).

    Args:
      image: RGB images (tensor or array-like). Size of the last dimension
        must be 3.
      delta: float. How much to add to the hue channel; the `+ 1` offset
        below keeps the modulo operand non-negative for `delta >= -1`.
      name: A name for this operation (optional).

    Returns:
      Adjusted images, same shape and dtype as `image`.
    """
    with ops.op_scope([image], name, 'adjust_hue') as name:
        # Convert first so that lists / NumPy arrays are accepted and
        # `orig_dtype` is always a TensorFlow DType, which the final
        # convert_image_dtype call needs.
        image = tf.convert_to_tensor(image, name='image')
        # Remember original dtype so we can convert back if needed
        orig_dtype = image.dtype
        flt_image = tf.image.convert_image_dtype(image, tf.float32)

        hsv = gen_image_ops.rgb_to_hsv(flt_image)

        # One single-channel slice per HSV component.
        hue = tf.slice(hsv, [0, 0, 0, 0], [-1, -1, -1, 1])
        saturation = tf.slice(hsv, [0, 0, 0, 1], [-1, -1, -1, 1])
        value = tf.slice(hsv, [0, 0, 0, 2], [-1, -1, -1, 1])

        # Hue is treated as periodic with period 1. Adding 1 before the
        # modulo keeps its operand non-negative when delta >= -1 (given a
        # non-negative hue), so the hue wraps around instead of going
        # negative.
        hue = math_ops.mod(hue + (delta + 1.), 1.)

        # Note the (axis, values) argument order of this tf.concat signature.
        hsv_altered = tf.concat(3, [hue, saturation, value])
        rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

        return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #8
0
def adjust_hue(image, delta, name=None):
    """Shift the hue of a rank-4 batch of RGB images by `delta`.

    Steps: convert to float32, convert RGB to HSV, add `delta` to the H
    channel modulo 1, convert HSV back to RGB, and convert back to the input
    dtype.

    Because the slices below use four-entry begin/size vectors, `image` must
    be rank 4 with channels in the last dimension (presumably
    `[batch, height, width, 3]`).

    Args:
      image: RGB images (tensor or array-like). Size of the last dimension
        must be 3.
      delta: float. How much to add to the hue channel; the `+ 1` offset
        below keeps the modulo operand non-negative for `delta >= -1`.
      name: A name for this operation (optional).

    Returns:
      Adjusted images, same shape and dtype as `image`.
    """
    with ops.op_scope([image], name, 'adjust_hue') as name:
        # Normalise the input to a Tensor up front: a plain list has no
        # `.dtype`, and a NumPy dtype is not usable as the target of the
        # final convert_image_dtype call.
        image = tf.convert_to_tensor(image, name='image')
        # Remember original dtype so we can convert back if needed
        orig_dtype = image.dtype
        flt_image = tf.image.convert_image_dtype(image, tf.float32)

        hsv = gen_image_ops.rgb_to_hsv(flt_image)

        # Split HSV into its three single-channel planes.
        hue = tf.slice(hsv, [0, 0, 0, 0], [-1, -1, -1, 1])
        saturation = tf.slice(hsv, [0, 0, 0, 1], [-1, -1, -1, 1])
        value = tf.slice(hsv, [0, 0, 0, 2], [-1, -1, -1, 1])

        # The hue wraps with period 1. The extra +1 makes the modulo operand
        # non-negative for delta >= -1 (given a non-negative hue).
        hue = math_ops.mod(hue + (delta + 1.), 1.)

        # This tf.concat signature takes the axis first, then the values.
        hsv_altered = tf.concat(3, [hue, saturation, value])
        rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

        return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #9
0
def adjust_hue(image, delta, name=None):
    """Adjust hue of an RGB image.

    This is a convenience method that converts an RGB image to float
    representation, converts it to HSV, adds an offset to the hue channel,
    converts back to RGB and then back to the original data type. If several
    adjustments are chained it is advisable to minimize the number of
    redundant conversions.

    `image` is an RGB image.  The image hue is adjusted by converting the
    image to HSV and rotating the hue channel (H) by `delta`.  The image is
    then converted back to RGB.

    `delta` must be in the interval `[-1, 1]`.

    The channel slices use three-entry begin/size vectors, so this
    implementation indexes the image as a rank-3 tensor.

    Args:
      image: RGB image (tensor or array-like). Size of the last dimension
        must be 3.
      delta: float.  How much to add to the hue channel.
      name: A name for this operation (optional).

    Returns:
      Adjusted image, same shape and DType as `image`.
    """
    with ops.op_scope([image], name, 'adjust_hue') as name:
        # Convert first so that lists / NumPy arrays are accepted and
        # `orig_dtype` is always a TensorFlow DType, which the final
        # convert_image_dtype call needs.
        image = ops.convert_to_tensor(image, name='image')
        # Remember original dtype so we can convert back if needed
        orig_dtype = image.dtype
        flt_image = convert_image_dtype(image, dtypes.float32)

        hsv = gen_image_ops.rgb_to_hsv(flt_image)

        # One single-channel slice per HSV component.
        hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
        saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
        value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])

        # Hue is treated as periodic with period 1. Adding 1 before the
        # modulo keeps its operand non-negative for any delta in [-1, 1]
        # (given a non-negative hue), so the hue wraps around instead of
        # going negative.
        hue = math_ops.mod(hue + (delta + 1.), 1.)

        # Note the (axis, values) argument order of this concat signature.
        hsv_altered = array_ops.concat(2, [hue, saturation, value])
        rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

        return convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #10
0
def adjust_hue(image, delta, name=None):
  """Adjust hue of an RGB image.

  This is a convenience method that converts an RGB image to float
  representation, converts it to HSV, adds an offset to the hue channel,
  converts back to RGB and then back to the original data type. If several
  adjustments are chained it is advisable to minimize the number of redundant
  conversions.

  `image` is an RGB image.  The image hue is adjusted by converting the
  image to HSV and rotating the hue channel (H) by
  `delta`.  The image is then converted back to RGB.

  `delta` must be in the interval `[-1, 1]`.

  The channel slices use three-entry begin/size vectors, so this
  implementation indexes the image as a rank-3 tensor.

  Args:
    image: RGB image (tensor or array-like). Size of the last dimension must
      be 3.
    delta: float.  How much to add to the hue channel.
    name: A name for this operation (optional).

  Returns:
    Adjusted image, same shape and DType as `image`.
  """
  with ops.op_scope([image], name, 'adjust_hue') as name:
    # Normalise the input to a Tensor up front: a plain list has no `.dtype`,
    # and a NumPy dtype is not usable as the target of the final
    # convert_image_dtype call.
    image = ops.convert_to_tensor(image, name='image')
    # Remember original dtype so we can convert back if needed
    orig_dtype = image.dtype
    flt_image = convert_image_dtype(image, dtypes.float32)

    hsv = gen_image_ops.rgb_to_hsv(flt_image)

    # Split HSV into its three single-channel planes.
    hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
    saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
    value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])

    # The hue wraps with period 1. The extra +1 makes the modulo operand
    # non-negative for any delta in [-1, 1] (given a non-negative hue).
    hue = math_ops.mod(hue + (delta + 1.), 1.)

    # This concat signature takes the axis first, then the values.
    hsv_altered = array_ops.concat(2, [hue, saturation, value])
    rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

    return convert_image_dtype(rgb_altered, orig_dtype)
Beispiel #11
0
def adjust_saturation(image, saturation_factor, name=None):
  """Scale the saturation of an RGB image.

  The image is converted to float32 and then to HSV. The saturation (S)
  channel is multiplied by `saturation_factor` and clipped to `[0, 1]`, after
  which the image is converted back to RGB and to its original data type. If
  several adjustments are chained it is advisable to minimize the number of
  redundant conversions.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    saturation_factor: float. Factor to multiply the saturation by.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.op_scope([image], name, 'adjust_saturation') as name:
    image = ops.convert_to_tensor(image, name='image')
    # Captured before the float conversion so the output dtype can match it.
    input_dtype = image.dtype
    hsv = gen_image_ops.rgb_to_hsv(convert_image_dtype(image, dtypes.float32))

    # One single-channel slice per HSV component. The three-entry begin/size
    # vectors index `hsv` as a rank-3 tensor.
    channels = [
        array_ops.slice(hsv, [0, 0, idx], [-1, -1, 1]) for idx in range(3)
    ]

    # Only saturation (index 1) changes: scale it, then clamp to [0, 1].
    channels[1] = clip_ops.clip_by_value(channels[1] * saturation_factor, 0.0,
                                         1.0)

    rgb = gen_image_ops.hsv_to_rgb(array_ops.concat(2, channels))
    return convert_image_dtype(rgb, input_dtype)
Beispiel #12
0
def adjust_saturation(image, saturation_factor, name=None):
    """Multiply the saturation of an RGB image by `saturation_factor`.

    This is a convenience method: the image is converted to float32 and then
    to HSV, the saturation (S) channel is multiplied by `saturation_factor`
    and clipped to `[0, 1]`, and the result is converted back to RGB and to
    the original data type. If several adjustments are chained it is
    advisable to minimize the number of redundant conversions.

    Args:
      image: RGB image or images. Size of the last dimension must be 3.
      saturation_factor: float. Factor to multiply the saturation by.
      name: A name for this operation (optional).

    Returns:
      Adjusted image(s), same shape and DType as `image`.
    """
    with ops.op_scope([image], name, 'adjust_saturation') as name:
        image = ops.convert_to_tensor(image, name='image')
        # The result is converted back to this dtype at the end.
        dtype_in = image.dtype
        image_f32 = convert_image_dtype(image, dtypes.float32)

        hsv = gen_image_ops.rgb_to_hsv(image_f32)

        # Three-entry begin/size vectors: `hsv` is indexed as a rank-3
        # tensor, one channel per slice.
        h_plane = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
        s_plane = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
        v_plane = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])

        # Scale the saturation and clamp it back into [0, 1].
        scaled = s_plane * saturation_factor
        s_plane = clip_ops.clip_by_value(scaled, 0.0, 1.0)

        recombined = array_ops.concat(2, [h_plane, s_plane, v_plane])
        return convert_image_dtype(gen_image_ops.hsv_to_rgb(recombined),
                                   dtype_in)
 def f(x):
     """Apply the generated rgb_to_hsv op to `x` and return its output."""
     hsv = gen_image_ops.rgb_to_hsv(x)
     return hsv
def main(margin, batch_size, output_size, learning_rate, whichGPU,
         is_finetuning, pretrained_net):
    # Builds a TensorFlow graph that augments ("doctors") an image batch --
    # random rotation, random crop, hue/saturation/brightness jitter,
    # people-mask insertion and Gaussian noise -- feeds it through
    # ResNet-v2-50, and trains the network with a margin-based loss on
    # pairwise (1 - dot product) distances of L2-normalised features.
    #
    # Args (numeric arguments are coerced below, so strings are accepted):
    #   margin: margin added to (positive distance - other distance) in the
    #     loss.
    #   batch_size: images per batch; must be divisible by 10.
    #   output_size: passed to resnet_v2_50 as num_classes.
    #   learning_rate: learning rate of the Adam optimizer.
    #   whichGPU: value assigned to gpu_options.visible_device_list.
    #   is_finetuning: string; when 'true' (case-insensitive) variables whose
    #     name starts with 'resnet_v2_50/logits' are not restored.
    #   pretrained_net: checkpoint path given to
    #     slim.assign_from_checkpoint_fn.
    #
    # Side effects: installs a SIGINT handler, appends to a log file under
    # log_dir, writes checkpoints under ckpt_dir, and may call sys.exit(0).
    # NOTE: the print statements use Python 2 syntax.
    def handler(signum, frame):
        # SIGINT handler: save a checkpoint, then exit. It reads saver, sess,
        # step, ckpt_dir and param_str from the enclosing scope; those are
        # assigned further down in main.
        # NOTE(review): an interrupt that arrives before they are assigned
        # would presumably raise a NameError here -- confirm.
        # The pretrained_net assigned below is local to this handler.
        print 'Saving checkpoint before closing'
        pretrained_net = os.path.join(ckpt_dir, 'checkpoint-' + param_str)
        saver.save(sess, pretrained_net, global_step=step)
        print 'Checkpoint-', pretrained_net + '-' + str(step), ' saved!'
        sys.exit(0)

    signal.signal(signal.SIGINT, handler)

    # Hard-coded input/output locations.
    ckpt_dir = './output/sameChain/tcam/ckpts'
    log_dir = './output/sameChain/tcam/logs'
    train_filename = './input/train_by_chain.txt'
    mean_file = './input/meanIm.npy'

    # Training schedule and geometry constants.
    img_size = [256, 256]
    crop_size = [224, 224]
    num_iters = 200000
    summary_iters = 100
    save_iters = 5000
    featLayer = 'resnet_v2_50/logits'

    is_training = True

    # Coerce the arguments to their working types.
    margin = float(margin)
    batch_size = int(batch_size)
    output_size = int(output_size)
    learning_rate = float(learning_rate)
    whichGPU = str(whichGPU)

    if batch_size % 10 != 0:
        print 'Batch size must be divisible by 10!'
        sys.exit(0)

    # NOTE(review): relies on integer division (Python 2 `/` on ints, assuming
    # no `from __future__ import division`); the value is later used as a
    # tile count and a tensor dimension.
    num_pos_examples = batch_size / 10

    # Create data "batcher"
    train_data = SameClassSet(train_filename,
                              mean_file,
                              img_size,
                              crop_size,
                              batch_size,
                              num_pos_examples,
                              isTraining=is_training)
    # Run identifier built from the timestamp and hyper-parameters; '.' is
    # replaced by 'pt' so the string is safe to use in file names.
    datestr = datetime.now().strftime("%Y_%m_%d_%H%M")
    param_str = datestr + '_tcam_with_doctoring_lr' + str(
        learning_rate).replace('.', 'pt') + '_outputSz' + str(
            output_size) + '_margin' + str(margin).replace('.', 'pt')
    logfile_path = os.path.join(log_dir, param_str + '_train.txt')
    # Opened in append mode; closed at the very end of main (not on the
    # sys.exit paths).
    train_log_file = open(logfile_path, 'a')
    print '------------'
    print ''
    print 'Going to train with the following parameters:'
    print 'Margin: ', margin
    train_log_file.write('Margin: ' + str(margin) + '\n')
    print 'Output size: ', output_size
    train_log_file.write('Output size: ' + str(output_size) + '\n')
    print 'Learning rate: ', learning_rate
    train_log_file.write('Learning rate: ' + str(learning_rate) + '\n')
    print 'Logging to: ', logfile_path
    train_log_file.write('Param_str: ' + param_str + '\n')
    train_log_file.write('----------------\n')
    print ''
    print '------------'

    # Queuing op loads data into input tensor
    # Both placeholders use crop_size[0] for height and width (crop_size is
    # square here).
    image_batch = tf.placeholder(
        tf.float32, shape=[batch_size, crop_size[0], crop_size[0], 3])
    people_mask_batch = tf.placeholder(
        tf.float32, shape=[batch_size, crop_size[0], crop_size[0], 1])

    # doctor image params
    # Fraction of the batch each augmentation is applied to.
    # NOTE: percent_text is not used anywhere in this function.
    percent_crop = .5
    percent_people = .5
    percent_rotate = .2
    percent_filters = .4
    percent_text = .1

    # # richard's argument: since the data is randomly loaded, we don't need to change the indices that we perform operations on every time; i am on board with this, but had already implemented the random crops, so will leave that for now
    # # apply random rotations
    # The rotated indices and angles are drawn once with NumPy while the graph
    # is built, so they are fixed for the whole run. Angles lie in
    # [-0.65, 0.64] (presumably radians -- confirm against
    # tf.contrib.image.rotate); unselected images get angle 0.
    num_rotate = int(batch_size * percent_rotate)
    rotate_inds = np.random.choice(np.arange(0, batch_size),
                                   num_rotate,
                                   replace=False)
    rotate_vals = np.random.randint(-65, 65,
                                    num_rotate).astype('float32') / float(100)
    rotate_angles = np.zeros((batch_size))
    rotate_angles[rotate_inds] = rotate_vals
    rotated_batch = tf.contrib.image.rotate(image_batch,
                                            rotate_angles,
                                            interpolation='BILINEAR')

    # do random crops
    num_to_crop = int(batch_size * percent_crop)
    num_to_not_crop = batch_size - num_to_crop

    # Unlike the rotation indices, this shuffle is a TF op: the first
    # num_to_crop shuffled indices are cropped, the rest keep the full frame.
    shuffled_inds = tf.random_shuffle(np.arange(0, batch_size, dtype='int32'))
    # shuffled_inds = np.arange(0,batch_size,dtype='int32')
    # np.random.shuffle(shuffled_inds)
    crop_inds = tf.slice(shuffled_inds, [0], [num_to_crop])
    uncropped_inds = tf.slice(shuffled_inds, [num_to_crop], [num_to_not_crop])

    # crop_ratio = float(3)/float(5)
    # crop_yx = tf.random_uniform([num_to_crop,2], 0,1-crop_ratio, dtype=tf.float32, seed=0)
    # crop_sz = tf.add(crop_yx,np.tile([crop_ratio,crop_ratio],[num_to_crop, 1]))
    # crop_boxes = tf.concat([crop_yx,crop_sz],axis=1)

    # randomly select a crop between 3/5 of the image and the entire image
    crop_ratio = tf.random_uniform([num_to_crop, 1],
                                   float(3) / float(5),
                                   1,
                                   dtype=tf.float32,
                                   seed=0)
    # randomly select a starting location between 0 and the max valid x position
    # NOTE(review): requested shape is [1, 2] while the upper bound
    # (1 - crop_ratio) has shape [num_to_crop, 1]; presumably the result
    # broadcasts to [num_to_crop, 2] from a single random draw -- confirm this
    # is intended.
    crop_yx = tf.random_uniform([1, 2],
                                0.,
                                1. - crop_ratio,
                                dtype=tf.float32,
                                seed=0)
    # Far corner = near corner + side length; boxes are the concatenation
    # [crop_yx, crop_sz] along axis 1.
    crop_sz = tf.add(crop_yx, tf.concat([crop_ratio, crop_ratio], axis=1))
    crop_boxes = tf.concat([crop_yx, crop_sz], axis=1)

    # Full-frame box [0, 0, 1, 1] for every image that is not cropped.
    uncropped_boxes = np.tile([0, 0, 1, 1], [num_to_not_crop, 1])

    all_inds = tf.concat([crop_inds, uncropped_inds], axis=0)
    all_boxes = tf.concat([crop_boxes, uncropped_boxes], axis=0)

    # top_k on the negated indices yields the positions that sort
    # shuffled_inds ascending; gathering with them puts the crop outputs back
    # in original batch order.
    sorted_inds = tf.nn.top_k(-shuffled_inds, sorted=True,
                              k=batch_size).indices
    cropped_batch = tf.gather(
        tf.image.crop_and_resize(rotated_batch, all_boxes, all_inds,
                                 crop_size), sorted_inds)

    # apply different filters
    flt_image = convert_image_dtype(cropped_batch, dtypes.float32)

    num_to_filter = int(batch_size * percent_filters)

    # Per-image 0/1 mask (fixed at graph-construction time) selecting which
    # images receive colour jitter.
    # NOTE: inv_filter_mask is computed but not used later in this function.
    filter_inds = np.random.choice(np.arange(0, batch_size),
                                   num_to_filter,
                                   replace=False)
    filter_mask = np.zeros(batch_size)
    filter_mask[filter_inds] = 1
    filter_mask = filter_mask.astype('float32')
    inv_filter_mask = np.ones(batch_size)
    inv_filter_mask[filter_inds] = 0
    inv_filter_mask = inv_filter_mask.astype('float32')

    #
    # Split the batch into single-channel H, S and V tensors.
    # NOTE(review): rgb_to_hsv presumably expects values in [0, 1]; whether
    # the fed batch is in that range cannot be seen from here -- confirm.
    hsv = gen_image_ops.rgb_to_hsv(flt_image)
    hue = array_ops.slice(hsv, [0, 0, 0, 0], [batch_size, -1, -1, 1])
    saturation = array_ops.slice(hsv, [0, 0, 0, 1], [batch_size, -1, -1, 1])
    value = array_ops.slice(hsv, [0, 0, 0, 2], [batch_size, -1, -1, 1])

    # hue
    # One random offset per image (zeroed for unselected images), expanded to
    # [batch_size, crop_size[0], crop_size[1], 1] via reshape/tile/transpose/
    # expand_dims so it can be added to the hue plane.
    delta_vals = random_ops.random_uniform([batch_size], -.15, .15)
    hue_deltas = tf.multiply(filter_mask, delta_vals)
    hue_deltas2 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(hue_deltas, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    # hue = math_ops.mod(hue + (hue_deltas2 + 1.), 1.)
    # The shifted hue is clipped to [0, 1] rather than wrapped (the modulo
    # variant is commented out above).
    hue_mod = tf.add(hue, hue_deltas2)
    hue = clip_ops.clip_by_value(hue_mod, 0.0, 1.0)

    # saturation
    # Despite the name, the "factor" is an additive offset in [-.05, .05]; the
    # result is clipped to [0, 1].
    saturation_factor = random_ops.random_uniform([batch_size], -.05, .05)
    saturation_factor2 = tf.multiply(filter_mask, saturation_factor)
    saturation_factor3 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(saturation_factor2, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    saturation_mod = tf.add(saturation, saturation_factor3)
    saturation = clip_ops.clip_by_value(saturation_mod, 0.0, 1.0)

    hsv_altered = array_ops.concat([hue, saturation, value], 3)
    rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

    # brightness
    # Additive per-image offset in [-.25, .25], applied to the RGB result.
    brightness_factor = random_ops.random_uniform([batch_size], -.25, .25)
    brightness_factor2 = tf.multiply(filter_mask, brightness_factor)
    brightness_factor3 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(brightness_factor2, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    adjusted = math_ops.add(rgb_altered,
                            math_ops.cast(brightness_factor3, dtypes.float32))

    # NOTE(review): clipped to [0, 255] while hue/saturation above were
    # clipped to [0, 1] -- confirm the intended value range of the images.
    filtered_batch = clip_ops.clip_by_value(adjusted, 0.0, 255.0)

    # insert people masks
    num_people_masks = int(batch_size * percent_people)
    mask_inds = np.random.choice(np.arange(0, batch_size),
                                 num_people_masks,
                                 replace=False)

    # start_masks is 1 for the selected images and 0 elsewhere;
    # inv_start_masks is its complement.
    start_masks = np.zeros([batch_size, crop_size[0], crop_size[0], 1],
                           dtype='float32')
    start_masks[mask_inds, :, :, :] = 1

    inv_start_masks = np.ones([batch_size, crop_size[0], crop_size[0], 1],
                              dtype='float32')
    inv_start_masks[mask_inds, :, :, :] = 0

    # Selected images get their fed people mask; every other image gets an
    # all-ones mask, which leaves it unchanged by the multiply below.
    masked_masks = tf.add(
        inv_start_masks,
        tf.cast(tf.multiply(people_mask_batch, start_masks), dtype=tf.float32))
    # NOTE: masked_masks2 (the 3-channel tiling) is not used; the multiply
    # below uses the 1-channel masked_masks against the 3-channel batch.
    masked_masks2 = tf.cast(tf.tile(masked_masks, [1, 1, 1, 3]),
                            dtype=tf.float32)
    masked_batch = tf.multiply(masked_masks, filtered_batch)

    # Gaussian noise with a single channel, added to the 3-channel batch.
    noise = tf.random_normal(shape=[batch_size, crop_size[0], crop_size[0], 1],
                             mean=0.0,
                             stddev=0.0025,
                             dtype=tf.float32)
    final_batch = tf.add(masked_batch, noise)

    print("Preparing network...")
    # The network is always built with is_training=True here, independent of
    # the is_training local defined above.
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        _, layers = resnet_v2.resnet_v2_50(final_batch,
                                           num_classes=output_size,
                                           is_training=True)

    # Collect the model variables to restore from the checkpoint.
    # Operator precedence makes the test below
    # (finetuning AND logits variable) OR ('momentum' in the name).
    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        if is_finetuning.lower() == 'true' and var.op.name.startswith(
                'resnet_v2_50/logits') or 'momentum' in var.op.name.lower():
            excluded = True
        if not excluded:
            variables_to_restore.append(var)

    # L2-normalise the logits along axis 3 and squeeze; D is then the pairwise
    # matrix of (1 - dot product) between all feature vectors in the batch.
    feat = tf.squeeze(tf.nn.l2_normalize(layers[featLayer], 3))
    expanded_a = tf.expand_dims(feat, 1)
    expanded_b = tf.expand_dims(feat, 0)
    #D = tf.reduce_sum(tf.squared_difference(expanded_a, expanded_b), 2)
    D = 1 - tf.reduce_sum(tf.multiply(expanded_a, expanded_b), 2)

    # if not train_data.isOverfitting:
    #     D_max = tf.reduce_max(D)
    #     D_mean, D_var = tf.nn.moments(D, axes=[0,1])
    #     lowest_nonzero_distance = tf.reduce_max(-D)
    #     bottom_thresh = 1.2*lowest_nonzero_distance
    #     top_thresh = (D_max + D_mean)/2.0
    #     bool_mask = tf.logical_and(D>=bottom_thresh,D<=top_thresh)
    #     D = tf.multiply(D,tf.cast(bool_mask,tf.float32))

    # The batch is treated as consecutive groups of num_pos_examples images.
    # posImInds[i, j] is the index of the j-th member of image i's group;
    # anchorInds[i, j] is simply i.
    posIdx = np.floor(np.arange(0, batch_size) /
                      num_pos_examples).astype('int')
    posIdx10 = num_pos_examples * posIdx
    posImInds = np.tile(posIdx10, (num_pos_examples, 1)).transpose() + np.tile(
        np.arange(0, num_pos_examples), (batch_size, 1))
    anchorInds = np.tile(np.arange(0, batch_size),
                         (num_pos_examples, 1)).transpose()

    posImInds_flat = posImInds.ravel()
    anchorInds_flat = anchorInds.ravel()

    # (positive, anchor) index pairs used to look up distances in D.
    # NOTE(review): zip returns a list under Python 2; under Python 3 it would
    # be an iterator -- confirm before porting.
    posPairInds = zip(posImInds_flat, anchorInds_flat)
    posDists = tf.reshape(tf.gather_nd(D, posPairInds),
                          (batch_size, num_pos_examples))

    # Tile positive distances and all distances to a common
    # [batch_size, batch_size, num_pos_examples] shape.
    shiftPosDists = tf.reshape(posDists, (1, batch_size, num_pos_examples))
    posDistsRep = tf.tile(shiftPosDists, (batch_size, 1, 1))

    allDists = tf.tile(tf.expand_dims(D, 2), (1, 1, num_pos_examples))

    ra, rb, rc = np.meshgrid(np.arange(0, batch_size),
                             np.arange(0, batch_size),
                             np.arange(0, num_pos_examples))

    # Entries flagged by either test are zeroed out of the loss:
    # bad_negatives marks index pairs that fall in the same group,
    # bad_positives marks entries where rb and rc agree modulo the group size.
    bad_negatives = np.floor((ra) / num_pos_examples) == np.floor(
        (rb) / num_pos_examples)
    bad_positives = np.mod(rb,
                           num_pos_examples) == np.mod(rc, num_pos_examples)

    mask = ((1 - bad_negatives) * (1 - bad_positives)).astype('float32')

    # loss = tf.reduce_sum(tf.maximum(0.,tf.multiply(mask,margin + posDistsRep - allDists)))/batch_size
    # Hinge on mask * (margin + positive distance - other distance), averaged
    # over every entry (masked entries contribute 0 to the mean).
    loss = tf.reduce_mean(
        tf.maximum(0., tf.multiply(mask, margin + posDistsRep - allDists)))

    # slightly counterintuitive to not define "init_op" first, but tf vars aren't known until added to graph
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    # train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    # optimizer = tf.train.AdamOptimizer(learning_rate)
    # train_op = slim.learning.create_train_op(loss, optimizer)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    # NOTE: summary_op is only referenced by the commented-out summary code in
    # the training loop below.
    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver(max_to_keep=2000)

    # tf will consume any GPU it finds on the system. Following lines restrict it to specific gpus
    c = tf.ConfigProto()
    c.gpu_options.visible_device_list = whichGPU

    print("Starting session...")
    sess = tf.Session(config=c)
    sess.run(init_op)

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # Restore the selected variables after running the initializer, so the
    # checkpoint values overwrite the freshly initialised ones.
    restore_fn = slim.assign_from_checkpoint_fn(pretrained_net,
                                                variables_to_restore)
    restore_fn(sess)

    print("Start training...")
    # NOTE: ctr is never read or updated after this assignment.
    ctr = 0
    for step in range(num_iters):
        start_time = time.time()
        # labels and ims are returned by the batcher but not used here.
        batch, labels, ims = train_data.getBatch()
        people_masks = train_data.getPeopleMasks()
        _, loss_val = sess.run([train_op, loss],
                               feed_dict={
                                   image_batch: batch,
                                   people_mask_batch: people_masks
                               })
        end_time = time.time()
        duration = end_time - start_time
        out_str = 'Step %d: loss = %.6f -- (%.3f sec)' % (step, loss_val,
                                                          duration)
        # print(out_str)
        # Report the loss every summary_iters steps, to stdout and the log
        # file.
        if step % summary_iters == 0:
            print(out_str)
            train_log_file.write(out_str + '\n')
        # Update the events file.
        # summary_str = sess.run(summary_op)
        # writer.add_summary(summary_str, step)
        # writer.flush()
        #
        # Save a checkpoint
        # Both saves rebind pretrained_net to the path they write to.
        if (step + 1) % save_iters == 0:
            print('Saving checkpoint at iteration: %d' % (step))
            pretrained_net = os.path.join(ckpt_dir, 'checkpoint-' + param_str)
            saver.save(sess, pretrained_net, global_step=step)
            print 'checkpoint-', pretrained_net + '-' + str(step), ' saved!'
        if (step + 1) == num_iters:
            print('Saving final')
            pretrained_net = os.path.join(ckpt_dir, 'final-' + param_str)
            saver.save(sess, pretrained_net, global_step=step)
            print 'final-', pretrained_net + '-' + str(step), ' saved!'

    sess.close()
    train_log_file.close()
def main(batch_size, output_size, learning_rate, whichGPU, is_finetuning,
         is_overfitting, pretrained_net):
    def handler(signum, frame):
        print 'Saving checkpoint before closing'
        pretrained_net = os.path.join(ckpt_dir, 'checkpoint-' + param_str)
        saver.save(sess, pretrained_net, global_step=step)
        print 'Checkpoint-', pretrained_net + '-' + str(step), ' saved!'
        sys.exit(0)

    signal.signal(signal.SIGINT, handler)

    ckpt_dir = './output/npairs/doctoring/ckpts'
    log_dir = './output/npairs/doctoring/logs'
    train_filename = './input/train_by_hotel.txt'
    mean_file = './input/meanIm.npy'

    img_size = [256, 256]
    crop_size = [224, 224]
    num_iters = 200000
    summary_iters = 25
    save_iters = 5000
    featLayer = 'resnet_v2_50/logits'

    is_training = True

    batch_size = int(batch_size)
    output_size = int(output_size)
    learning_rate = float(learning_rate)
    whichGPU = str(whichGPU)

    if batch_size % 10 != 0:
        print 'Batch size must be divisible by 10!'
        sys.exit(0)

    # Create data "batcher"
    train_data = Npairs(train_filename,
                        mean_file,
                        img_size,
                        crop_size,
                        batch_size,
                        isTraining=is_training)

    numHotels = len(train_data.hotels.keys())
    numIms = np.sum(
        [len(train_data.hotels[h]['ims']) for h in train_data.hotels.keys()])

    datestr = datetime.now().strftime("%Y_%m_%d_%H%M")
    param_str = datestr + '_lr' + str(learning_rate).replace(
        '.', 'pt') + '_outputSz' + str(output_size)
    logfile_path = os.path.join(log_dir, param_str + '_npairs_train.txt')
    train_log_file = open(logfile_path, 'a')
    print '------------'
    print ''
    print 'Going to train with the following parameters:'
    print 'Num hotels:', numHotels
    train_log_file.write('Num hotels: ' + str(numHotels) + '\n')
    print 'Num ims:', numIms
    train_log_file.write('Num ims: ' + str(numIms) + '\n')
    print 'Output size: ', output_size
    train_log_file.write('Output size: ' + str(output_size) + '\n')
    print 'Learning rate: ', learning_rate
    train_log_file.write('Learning rate: ' + str(learning_rate) + '\n')
    print 'Logging to: ', logfile_path
    train_log_file.write('Param_str: ' + param_str + '\n')
    train_log_file.write('----------------\n')
    print ''
    print '------------'

    # Queuing op loads data into input tensor
    repMeanIm = np.tile(np.expand_dims(train_data.meanImage, 0),
                        [batch_size, 1, 1, 1])
    # this is dumb, but in the non-doctored case we subtract off the mean in the batch generation. here we want to do it after the data augmentation
    image_batch_mean_subtracted = tf.placeholder(
        tf.float32, shape=[batch_size, crop_size[0], crop_size[0], 3])
    image_batch = tf.add(image_batch_mean_subtracted, repMeanIm)
    label_batch = tf.placeholder(tf.int32, shape=[batch_size])
    people_mask_batch = tf.placeholder(
        tf.float32, shape=[batch_size, crop_size[0], crop_size[0], 1])

    # doctor image params
    percent_crop = .5
    percent_people = .5
    percent_rotate = .2
    percent_filters = .4
    percent_text = .1

    # # richard's argument: since the data is randomly loaded, we don't need to change the indices that we perform operations on every time; i am on board with this, but had already implemented the random crops, so will leave that for now
    # # apply random rotations
    num_rotate = int(batch_size * percent_rotate)
    rotate_inds = np.random.choice(np.arange(0, batch_size),
                                   num_rotate,
                                   replace=False)
    rotate_vals = np.random.randint(-65, 65,
                                    num_rotate).astype('float32') / float(100)
    rotate_angles = np.zeros((batch_size))
    rotate_angles[rotate_inds] = rotate_vals
    rotated_batch = tf.contrib.image.rotate(image_batch,
                                            rotate_angles,
                                            interpolation='BILINEAR')

    # do random crops
    num_to_crop = int(batch_size * percent_crop)
    num_to_not_crop = batch_size - num_to_crop

    shuffled_inds = tf.random_shuffle(np.arange(0, batch_size, dtype='int32'))
    # shuffled_inds = np.arange(0,batch_size,dtype='int32')
    # np.random.shuffle(shuffled_inds)
    crop_inds = tf.slice(shuffled_inds, [0], [num_to_crop])
    uncropped_inds = tf.slice(shuffled_inds, [num_to_crop], [num_to_not_crop])

    # crop_ratio = float(3)/float(5)
    # crop_yx = tf.random_uniform([num_to_crop,2], 0,1-crop_ratio, dtype=tf.float32, seed=0)
    # crop_sz = tf.add(crop_yx,np.tile([crop_ratio,crop_ratio],[num_to_crop, 1]))
    # crop_boxes = tf.concat([crop_yx,crop_sz],axis=1)

    # randomly select a crop between 3/5 of the image and the entire image
    crop_ratio = tf.random_uniform([num_to_crop, 1],
                                   float(3) / float(5),
                                   1,
                                   dtype=tf.float32,
                                   seed=0)
    # randomly select a starting location between 0 and the max valid x position
    crop_yx = tf.random_uniform([1, 2],
                                0.,
                                1. - crop_ratio,
                                dtype=tf.float32,
                                seed=0)
    crop_sz = tf.add(crop_yx, tf.concat([crop_ratio, crop_ratio], axis=1))
    crop_boxes = tf.concat([crop_yx, crop_sz], axis=1)

    uncropped_boxes = np.tile([0, 0, 1, 1], [num_to_not_crop, 1])

    all_inds = tf.concat([crop_inds, uncropped_inds], axis=0)
    all_boxes = tf.concat([crop_boxes, uncropped_boxes], axis=0)

    sorted_inds = tf.nn.top_k(-shuffled_inds, sorted=True,
                              k=batch_size).indices
    cropped_batch = tf.gather(
        tf.image.crop_and_resize(rotated_batch, all_boxes, all_inds,
                                 crop_size), sorted_inds)

    # apply different filters
    flt_image = convert_image_dtype(cropped_batch, dtypes.float32)

    num_to_filter = int(batch_size * percent_filters)

    filter_inds = np.random.choice(np.arange(0, batch_size),
                                   num_to_filter,
                                   replace=False)
    filter_mask = np.zeros(batch_size)
    filter_mask[filter_inds] = 1
    filter_mask = filter_mask.astype('float32')
    inv_filter_mask = np.ones(batch_size)
    inv_filter_mask[filter_inds] = 0
    inv_filter_mask = inv_filter_mask.astype('float32')

    #
    hsv = gen_image_ops.rgb_to_hsv(flt_image)
    hue = array_ops.slice(hsv, [0, 0, 0, 0], [batch_size, -1, -1, 1])
    saturation = array_ops.slice(hsv, [0, 0, 0, 1], [batch_size, -1, -1, 1])
    value = array_ops.slice(hsv, [0, 0, 0, 2], [batch_size, -1, -1, 1])

    # hue
    delta_vals = random_ops.random_uniform([batch_size], -.15, .15)
    hue_deltas = tf.multiply(filter_mask, delta_vals)
    hue_deltas2 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(hue_deltas, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    # hue = math_ops.mod(hue + (hue_deltas2 + 1.), 1.)
    hue_mod = tf.add(hue, hue_deltas2)
    hue = clip_ops.clip_by_value(hue_mod, 0.0, 1.0)

    # saturation
    saturation_factor = random_ops.random_uniform([batch_size], -.05, .05)
    saturation_factor2 = tf.multiply(filter_mask, saturation_factor)
    saturation_factor3 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(saturation_factor2, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    saturation_mod = tf.add(saturation, saturation_factor3)
    saturation = clip_ops.clip_by_value(saturation_mod, 0.0, 1.0)

    hsv_altered = array_ops.concat([hue, saturation, value], 3)
    rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

    # brightness
    brightness_factor = random_ops.random_uniform([batch_size], -.25, .25)
    brightness_factor2 = tf.multiply(filter_mask, brightness_factor)
    brightness_factor3 = tf.expand_dims(
        tf.transpose(
            tf.tile(tf.reshape(brightness_factor2, [1, 1, batch_size]),
                    (crop_size[0], crop_size[1], 1)), (2, 0, 1)), 3)
    adjusted = math_ops.add(rgb_altered,
                            math_ops.cast(brightness_factor3, dtypes.float32))

    filtered_batch = clip_ops.clip_by_value(adjusted, 0.0, 255.0)

    # insert people masks
    num_people_masks = int(batch_size * percent_people)
    mask_inds = np.random.choice(np.arange(0, batch_size),
                                 num_people_masks,
                                 replace=False)

    start_masks = np.zeros([batch_size, crop_size[0], crop_size[0], 1],
                           dtype='float32')
    start_masks[mask_inds, :, :, :] = 1

    inv_start_masks = np.ones([batch_size, crop_size[0], crop_size[0], 1],
                              dtype='float32')
    inv_start_masks[mask_inds, :, :, :] = 0

    masked_masks = tf.add(
        inv_start_masks,
        tf.cast(tf.multiply(people_mask_batch, start_masks), dtype=tf.float32))
    masked_masks2 = tf.cast(tf.tile(masked_masks, [1, 1, 1, 3]),
                            dtype=tf.float32)
    masked_batch = tf.multiply(masked_masks, filtered_batch)

    noise = tf.random_normal(shape=[batch_size, crop_size[0], crop_size[0], 1],
                             mean=0.0,
                             stddev=0.0025,
                             dtype=tf.float32)
    final_batch = tf.add(tf.subtract(masked_batch, repMeanIm), noise)

    print("Preparing network...")
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
        _, layers = resnet_v2.resnet_v2_50(final_batch,
                                           num_classes=output_size,
                                           is_training=True)

    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        if is_finetuning.lower() == 'true' and var.op.name.startswith(
                'resnet_v2_50/logits') or 'momentum' in var.op.name.lower():
            excluded = True
        if not excluded:
            variables_to_restore.append(var)

    # numpy stuff for figuring out which elements are from the same class and which aren't
    anchor_inds = np.arange(0, batch_size, 2)
    pos_inds = np.arange(1, batch_size, 2)

    labels = tf.gather(label_batch, anchor_inds)

    all_feats = tf.squeeze(layers[featLayer])
    anchor_feats = tf.gather(all_feats, anchor_inds)
    pos_feats = tf.gather(all_feats, pos_inds)

    loss = npairs_loss(labels, anchor_feats, pos_feats)

    # slightly counterintuitive to not define "init_op" first, but tf vars aren't known until added to graph
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        train_op = slim.learning.create_train_op(loss, optimizer)

    summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver(max_to_keep=2000)

    # tf will consume any GPU it finds on the system. Following lines restrict it to specific gpus
    c = tf.ConfigProto()
    c.gpu_options.visible_device_list = whichGPU

    print("Starting session...")
    sess = tf.Session(config=c)
    sess.run(init_op)

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    if pretrained_net.lower() != 'none':
        restore_fn = slim.assign_from_checkpoint_fn(pretrained_net,
                                                    variables_to_restore)
        restore_fn(sess)

    print("Start training...")
    ctr = 0
    for step in range(num_iters):
        start_time = time.time()
        batch, hotels, ims = train_data.getBatch()
        people_masks = train_data.getPeopleMasks()
        batch_time = time.time() - start_time
        start_time = time.time()
        _, fb, loss_val = sess.run(
            [train_op, masked_batch, loss],
            feed_dict={
                image_batch_mean_subtracted: batch,
                label_batch: hotels,
                people_mask_batch: people_masks
            })
        end_time = time.time()
        duration = end_time - start_time
        out_str = 'Step %d: loss = %.6f (batch creation: %.3f | training: %.3f sec)' % (
            step, loss_val, batch_time, duration)
        # print(out_str)
        if step == 0:
            np.save(
                os.path.join(log_dir,
                             'checkpoint-' + param_str + '_example_batch.npy'),
                fb)
        if step % summary_iters == 0 or is_overfitting.lower() == 'true':
            print(out_str)
            train_log_file.write(out_str + '\n')
        # Update the events file.
        # summary_str = sess.run(summary_op)
        # writer.add_summary(summary_str, step)
        # writer.flush()
        #
        # Save a checkpoint
        if (step + 1) % save_iters == 0:
            print('Saving checkpoint at iteration: %d' % (step))
            pretrained_net = os.path.join(ckpt_dir, 'checkpoint-' + param_str)
            saver.save(sess, pretrained_net, global_step=step)
            print 'checkpoint-', pretrained_net + '-' + str(step), ' saved!'
        if (step + 1) == num_iters:
            print('Saving final')
            pretrained_net = os.path.join(ckpt_dir, 'final-' + param_str)
            saver.save(sess, pretrained_net, global_step=step)
            print 'final-', pretrained_net + '-' + str(step), ' saved!'

    sess.close()
    train_log_file.close()