Пример #1
0
    def extract_image(self, folder_path='./ExtractedImages'):
        """Write every record of ``self.tfrecord_file`` to disk, raw and augmented.

        For record ``i`` this saves the decoded image as ``raw<i>.png`` and one
        randomly augmented view as ``aug<i>.png``. The augmentation is:
        - a joint random crop of image and label to ``[vh, vw]``;
        - a random horizontal flip;
        - ``layers.rand_warp``;
        - random darkening, skipped for images whose mean is below 0.2;
        - two passes of ``layers.random_erasing``.

        Args:
            folder_path: Output directory. It is deleted (with its contents)
                and recreated before anything is written.
        """
        # Start from an empty output folder.
        shutil.rmtree(folder_path, ignore_errors=True)
        os.mkdir(folder_path)

        # Input pipeline: all files matching the pattern, parsed by _extract_fn.
        files = tf.data.Dataset.list_files(self.tfrecord_file)
        dataset = tf.data.TFRecordDataset(files).map(_extract_fn)
        next_image_data = dataset.make_one_shot_iterator().get_next()

        # Build the augmentation ops ONCE on the iterator output. Creating them
        # inside the loop would add new nodes to the graph for every record.
        im_l = tf.concat([next_image_data['img'], next_image_data['label']],
                         axis=-1)
        x = tf.image.random_crop(im_l, [vh, vw, 3 + N_CLASSES])
        images = x[np.newaxis, :, :, :3]  # add a batch axis, keep RGB only

        im_warp = tf.image.random_flip_left_right(images)
        im_warp = layers.rand_warp(im_warp, [vh, vw])
        im_w_adj = tf.clip_by_value(
            im_warp + tf.random.uniform([tf.shape(im_warp)[0], 1, 1, 1],
                                        -.8, 0.0),
            0.0, 1.0)
        # Keep images that are already dark (mean < 0.2); use the darkened
        # version for the rest. The result must be assigned to take effect.
        im_warp = tf.where(
            tf.less(tf.reduce_mean(im_warp, axis=[1, 2, 3]), 0.2),
            im_warp, im_w_adj)
        im_warp_v = layers.random_erasing(im_warp)
        aug_image = tf.squeeze(layers.random_erasing(im_warp_v))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            i = 0
            # Keep extracting data until the TFRecord is exhausted.
            while True:
                try:
                    # Fetch both in one run so they come from the same record.
                    raw_img, aug_img = sess.run(
                        [next_image_data['img'], aug_image])
                except tf.errors.OutOfRangeError:
                    break
                # Explicit '.png': imsave infers the format from the extension.
                raw_path = os.path.abspath(
                    os.path.join(folder_path, 'raw{}.png'.format(i)))
                aug_path = os.path.abspath(
                    os.path.join(folder_path, 'aug{}.png'.format(i)))
                mpimg.imsave(raw_path, raw_img)
                mpimg.imsave(aug_path, aug_img)
                i += 1
Пример #2
0
def model_fn(features, labels, mode, hparams):
    """Build the model graph, losses, summaries and eval metrics.

    The image and its label map are randomly cropped together to
    ``[vh, vw]`` and encoded by ``vss``. The total loss combines four terms:
    - a class-weighted segmentation cross-entropy;
    - a Bernoulli reconstruction loss;
    - a Gaussian KL term;
    - a hinge similarity loss. It compares each descriptor with a
      flipped/warped/darkened positive view and with a hard-mined negative
      from ``utils.hard_neg_mine``.

    Outside training, the batch is the concatenation
    ``[features['img']; features['cl_live']; features['cl_mem']]``, presumably
    ``FLAGS.batch_size // 3`` examples each (TODO confirm). An extra "AUC"
    (PR curve) metric then scores whether each ``cl_mem`` descriptor's most
    similar ``cl_live`` descriptor is the one at the same position.

    Args:
        features: dict of tensors. Reads ``'img'`` and ``'label'``, which are
            concatenated on the last axis and cropped to ``3 + N_CLASSES``
            channels: 3 image channels, then ``N_CLASSES`` label channels.
            When not training it also reads ``'cl_live'`` and ``'cl_mem'``.
            ``features['img']`` is overwritten in place with the cropped image.
        labels: Ignored; segmentation targets come from ``features['label']``.
        mode: A ``tf.estimator.ModeKeys`` value; only TRAIN is special-cased.
        hparams: Unused in this function.

    Returns:
        dict with:
        - scalar losses ``'loss'``, ``'segloss'``, ``'recloss'``,
          ``'simloss'`` and ``'kld'``;
        - ``'eval_metric_ops'``;
        - first-example previews: ``'pred'`` (argmax class map), ``'label'``
          (argmax of the target), and ``'rec'`` / ``'im'`` (uint8);
        - ``'predictions'`` = ``{'pred': seg, 'rec': rec}``.
    """
    del labels

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # Size of one third of the eval batch (img / cl_live / cl_mem).
    third = FLAGS.batch_size // 3

    # Crop image and label maps together so they stay spatially aligned.
    im_l = tf.concat([features['img'], features['label']], axis=-1)
    x = tf.image.random_crop(im_l, [tf.shape(im_l)[0], vh, vw, 3 + N_CLASSES])
    features['img'] = x[:, :, :, :3]
    labels = x[:, :, :, 3:]
    if is_training:
        images = features['img']
    else:
        images = tf.concat(
            [features['img'], features['cl_live'], features['cl_mem']], 0)

    # Positive view for the similarity loss: flip + random warp, then darken
    # (random offset in [-0.8, 0)) every image whose mean is >= 0.2.
    im_warp = tf.image.random_flip_left_right(images)
    im_warp = layers.rand_warp(im_warp, [vh, vw])
    im_w_adj = tf.clip_by_value(
        im_warp + tf.random.uniform([tf.shape(im_warp)[0], 1, 1, 1], -.8, 0.0),
        0.0, 1.0)
    # The selection must be assigned, otherwise im_w_adj is dead code.
    im_warp = tf.where(
        tf.less(tf.reduce_mean(im_warp, axis=[1, 2, 3]), 0.2), im_warp,
        im_w_adj)

    mu, log_sig_sq, rec, seg, z, c_centers, descr = vss(images, is_training)
    descr_p = vss(im_warp, is_training, True, True)
    descr_n = utils.hard_neg_mine(descr)

    # Hinge loss: the positive must beat the hard negative by `margin`.
    margin = 0.5
    lp = tf.reduce_sum(descr_p * descr, -1)
    ln = tf.reduce_sum(descr_n * descr, -1)
    simloss = tf.reduce_mean(tf.maximum(tf.zeros_like(ln), ln + margin - lp))

    if is_training:
        _seg = tf.nn.softmax(seg, axis=-1)
    else:
        # Only the first third (features['img']) has segmentation labels.
        _seg = tf.nn.softmax(seg[:third], axis=-1)

    # Class weights rescaled so the smallest is 1; overridable via feed_dict.
    weights = tf.placeholder_with_default(_weights, _weights.shape)
    weights = weights / tf.reduce_min(weights)
    _seg = tf.clip_by_value(_seg, 1e-6, 1.0)  # keep tf.log finite
    segloss = tf.reduce_mean(
        -tf.reduce_sum(labels * weights * tf.log(_seg), axis=-1))

    # Per-pixel binary cross-entropy between images and reconstruction.
    recloss = tf.reduce_mean(-tf.reduce_sum(
        images * tf.log(tf.clip_by_value(rec, 1e-10, 1.0)) +
        (1.0 - images) * tf.log(tf.clip_by_value(1.0 - rec, 1e-10, 1.0)),
        axis=[1, 2, 3]))

    sh = mu.get_shape().as_list()
    nwh = sh[1] * sh[2] * sh[3]
    mu_flat = tf.reshape(mu, [-1, nwh])
    log_var_flat = tf.reshape(log_sig_sq, [-1, nwh])
    # KL(N(mu, diag(exp(s))) || N(0, I)):
    # .5 (tr(sigma2) + mu^T mu - k - log det(sigma2))
    kld = tf.reduce_mean(
        -0.5 * (tf.reduce_sum(1.0 + log_var_flat - tf.square(mu_flat) -
                              tf.exp(log_var_flat), axis=-1)))

    kld = tf.check_numerics(kld, '\n\n\n\nkld is inf or nan!\n\n\n')
    recloss = tf.check_numerics(recloss,
                                '\n\n\n\nrecloss is inf or nan!\n\n\n')
    segloss = tf.check_numerics(segloss,
                                '\n\n\n\nsegloss is inf or nan!\n\n\n')

    loss = segloss + \
            0.0001 * kld + \
            0.0001 * recloss + \
            simloss

    prob = _seg[0, :, :, :]
    pred = tf.argmax(prob, axis=-1)

    mask = tf.argmax(labels[0], axis=-1)
    if not is_training:
        # Use the same `third` boundaries as the concat above; (2*bs)//3 would
        # drift from 2*(bs//3) when the batch size is not a multiple of 3.
        dlive = descr[third:2 * third]
        dmem = descr[2 * third:]

        # Compare each combination of live to mem.
        n_live = tf.shape(dlive)[0]
        tlive = tf.tile(dlive, [n_live, 1])  # [l0, l1, l2..., l0, l1, l2...]
        tmem = tf.reshape(
            tf.tile(tf.expand_dims(dmem, 1), [1, n_live, 1]),
            [-1, dlive.get_shape().as_list()[1]])  # [m0, m0..., m1, m1...]

        # Dot product (cosine similarity if descr is L2-normalized -- TODO
        # confirm), mapped from [-1, 1] to [0, 1].
        sim = tf.reduce_sum(tlive * tmem, axis=-1)
        sim = (1.0 + sim) / 2.0

        # sim_sq[a, b] = similarity(mem a, live b).
        sim_sq = tf.reshape(sim, [third, third])

        # Nearest live descriptor for each mem descriptor, and its score.
        ids = tf.argmax(sim_sq, axis=-1)
        sim_nn = tf.reduce_max(sim_sq, axis=-1)
        # Correct match lies on the diagonal: 1 if mem a matched live a.
        lab = tf.cast(
            tf.equal(ids, tf.range(third, dtype=tf.int64)), tf.int64)

    def touint8(img):
        return tf.cast(img * 255.0, tf.uint8)

    _im = touint8(images[0])
    _rec = touint8(rec[0])

    with tf.variable_scope("stats"):
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("segloss", segloss)
        tf.summary.scalar("kld", kld)
        tf.summary.scalar("recloss", recloss)
        tf.summary.scalar("simloss", simloss)
        tf.summary.histogram("z", z)
        tf.summary.histogram("mu", mu)
        tf.summary.histogram("sig", tf.exp(log_sig_sq))
        tf.summary.histogram("clust_centers", c_centers)

    eval_ops = {
        "Test Error": tf.metrics.mean(loss),
        "Seg Error": tf.metrics.mean(segloss),
        "Rec Error": tf.metrics.mean(recloss),
        "KLD Error": tf.metrics.mean(kld),
        "Sim Error": tf.metrics.mean(simloss),
    }

    if not is_training:
        # Closer to 1 is better
        eval_ops["AUC"] = tf.metrics.auc(lab, sim_nn, curve='PR')
    to_return = {
        "loss": loss,
        "segloss": segloss,
        "recloss": recloss,
        "simloss": simloss,
        "kld": kld,
        "eval_metric_ops": eval_ops,
        'pred': pred,
        'rec': _rec,
        'label': mask,
        'im': _im
    }

    predictions = {'pred': seg, 'rec': rec}

    to_return['predictions'] = predictions

    utils.display_trainable_parameters()

    return to_return