Пример #1
0
    def _train_step(self, x, y, clw):
        """
        Run one training iteration made of two successive optimisation phases.

        Phase 1 updates only the two encoders (with ``self._optimizer_k``)
        on ``kernels_lambda * loss_object(Kern, ztz)`` plus the encoders'
        regularisation losses, where ``Kern = 1 - Degree_matrix(...)`` is
        computed from 20-sized patches of the inputs.
        Phase 2 re-runs the model and updates encoders and decoders (with
        ``self._optimizer_all``) on the weighted cycle, cross and
        reconstruction losses plus all regularisation losses.
        When ``self.clipnorm`` is set, each phase's gradients are clipped by
        global norm before being applied.

        Input:
        x - tensor of shape (bs, ps_h, ps_w, c_x)
        y - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)
        """
        # ---- Phase 1: align the code affinity ztz with the input kernel ----
        with tf.GradientTape() as tape:
            # Only ztz is needed here; the other outputs are recomputed in
            # phase 2, after the encoders have been updated.
            *_, ztz = self([x, y], training=True)
            kern = 1.0 - Degree_matrix(image_in_patches(x, 20),
                                       image_in_patches(y, 20))
            kernels_loss = self.kernels_lambda * self.loss_object(kern, ztz)
            # Regularisation losses must be read inside the tape so their
            # dependence on the weights is recorded.
            l2_loss_k = sum(self._enc_x.losses) + sum(self._enc_y.losses)

        targets_k = (self._enc_x.trainable_variables +
                     self._enc_y.trainable_variables)
        # Gradients are taken after leaving the tape context, as TF advises.
        gradients_k = tape.gradient(kernels_loss + l2_loss_k, targets_k)
        if self.clipnorm is not None:
            gradients_k, _ = tf.clip_by_global_norm(gradients_k,
                                                    self.clipnorm)
        self._optimizer_k.apply_gradients(zip(gradients_k, targets_k))

        # ---- Phase 2: translation / cycle / reconstruction objectives ----
        with tf.GradientTape() as tape:
            x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, _ = self(
                [x, y], training=True)
            l2_loss = (sum(self._enc_x.losses) + sum(self._enc_y.losses) +
                       sum(self._dec_x.losses) + sum(self._dec_y.losses))
            cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
            cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
            recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
            cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
            cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
            recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)

            total_loss = (cycle_x_loss + cross_x_loss + recon_x_loss +
                          cycle_y_loss + cross_y_loss + recon_y_loss + l2_loss)

        targets_all = (self._enc_x.trainable_variables +
                       self._enc_y.trainable_variables +
                       self._dec_x.trainable_variables +
                       self._dec_y.trainable_variables)
        gradients_all = tape.gradient(total_loss, targets_all)
        if self.clipnorm is not None:
            gradients_all, _ = tf.clip_by_global_norm(gradients_all,
                                                      self.clipnorm)
        self._optimizer_all.apply_gradients(zip(gradients_all, targets_all))

        # ---- Metrics ----
        self.train_metrics["cycle_x"].update_state(cycle_x_loss)
        self.train_metrics["cross_x"].update_state(cross_x_loss)
        self.train_metrics["recon_x"].update_state(recon_x_loss)
        self.train_metrics["cycle_y"].update_state(cycle_y_loss)
        self.train_metrics["cross_y"].update_state(cross_y_loss)
        self.train_metrics["recon_y"].update_state(recon_y_loss)
        self.train_metrics["krnls"].update_state(kernels_loss)
        self.train_metrics["l2"].update_state(l2_loss)
        self.train_metrics["total"].update_state(total_loss)
Пример #2
0
def fetch_CGAN(name, **kwargs):
    """
    Load dataset ``name`` and prepare it for CGAN training and evaluation.

        Input:
            name - dataset name, should be in DATASETS
            kwargs - config {key: value} pairs.
                     Key should be in DATASET_DEFAULT_CONFIG;
                     only "patch_size" is read here.
        Output (5-tuple):
            x_patches - x image split with image_in_patches(x, patch_size)
            y_patches - y image split with image_in_patches(y, patch_size)
            evaluation_data - tf.data.Dataset with a single (x, y, change_map)
                              element; each tensor has shape (h, w, ?)
                              (use .batch(1) to get (1, h, w, ?))
            channels - tuple (c_x, c_y), number of channels for domains x and y
            tot_patches - size of the first dimension of x_patches
    """
    ps = kwargs.get("patch_size")
    # NOTE(review): the loader's first two outputs are bound as (y, x); kept
    # as in the original - confirm against the DATASETS loaders.
    y_im, x_im, target_cm = DATASETS[name](prepare_data[name])
    tensors = [x_im, y_im, target_cm]
    if not tf.config.list_physical_devices("GPU"):
        # No GPU: keep only a central crop (fraction 0.1 per dimension),
        # presumably to keep CPU-only runs small.
        tensors = [tf.image.central_crop(t, 0.1) for t in tensors]
    chs = [t.shape[-1] for t in tensors]
    tensors = [remove_borders(t, ps) for t in tensors]
    tensors = [tf.expand_dims(t, 0) for t in tensors]
    evaluation_data = tf.data.Dataset.from_tensor_slices(tuple(tensors))
    # Only x and y are returned as patches; patching the change map as well
    # would be discarded work.
    x_patches = image_in_patches(tensors[0], ps)
    y_patches = image_in_patches(tensors[1], ps)
    tot_patches = x_patches.shape[0]
    return x_patches, y_patches, evaluation_data, (chs[0], chs[1]), tot_patches
Пример #3
0
 def gen(self, input, gen_chs, batch_size=200):
     """ Wraps generator call for TensorBoard printing and image save.

     Splits ``input`` into patches with image_in_patches, runs ``self._gen``
     on at most ``batch_size`` patches at a time, and reassembles the result.

     input - image tensor; dims 1 and 2 are read as h and w (assumes
             (1, h, w, c) with h, w multiples of self.ps - TODO confirm)
     gen_chs - number of channels in each self._gen output patch
     batch_size - maximum number of patches per self._gen call
     Returns a tensor reshaped to (1, h, w, -1).
     """
     h, w = input.shape[1], input.shape[2]
     patches = image_in_patches(input, self.ps)
     n_patches = patches.shape[0]
     # Ceiling division: exactly one call per (possibly partial) chunk.
     n_batches = n_patches // batch_size + (n_patches % batch_size != 0)
     # Empty seed keeps tf.concat valid even when there are no chunks.
     outputs = [tf.zeros([0, self.ps, self.ps, gen_chs], dtype=tf.float32)]
     for i in range(n_batches):
         start = i * batch_size
         stop = min(n_patches, start + batch_size)
         outputs.append(self._gen(patches[start:stop]))
     # One concat instead of growing a tensor inside the loop.
     out = tf.concat(outputs, 0)
     out = tf.nn.depth_to_space(
         tf.reshape(out, [1, h // self.ps, w // self.ps, -1]), self.ps)
     return tf.reshape(out, [1, h, w, -1])
Пример #4
0
 def approx(self, input, batch_size=200):
     """ Wraps approximator call for TensorBoard printing and image save.

     Patches ``input`` with image_in_patches, reshapes the result to
     (-1, ps, ps, ch), runs ``self._approx`` on at most ``batch_size``
     patches at a time, and reassembles with depth_to_space.

     input - image tensor whose last dim is the channel count ch
     batch_size - maximum number of patches per self._approx call
     """
     ch = input.shape[-1]
     patched = image_in_patches(input, self.ps)
     # Dims 1 and 2 of the patched tensor give the grid used to reassemble.
     h, w = patched.shape[1], patched.shape[2]
     patches = tf.reshape(patched, [-1, self.ps, self.ps, ch])
     n_patches = patches.shape[0]
     # Ceiling division: exactly one call per (possibly partial) chunk.
     n_batches = n_patches // batch_size + (n_patches % batch_size != 0)
     # Empty seed keeps tf.concat valid even when there are no chunks.
     outputs = [tf.zeros([0, self.ps, self.ps, ch], dtype=tf.float32)]
     for i in range(n_batches):
         start = i * batch_size
         stop = min(n_patches, start + batch_size)
         outputs.append(self._approx(patches[start:stop]))
     out = tf.concat(outputs, 0)
     return tf.nn.depth_to_space(tf.reshape(out, [1, h, w, -1]), self.ps)
Пример #5
0
    def __call__(self, inputs, training=False):
        """
        Run the model on an image pair.

        inputs - pair (x, y) of rank-4 tensors (checked at runtime).
        training - if True, use the private sub-models (_enc_*/_dec_*) and
            return [x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, zx_t_zy];
            otherwise use the named wrappers (enc_*/dec_*) and return
            self._difference_img(x_tilde, y_tilde, x_hat, y_hat).
        """
        x, y = inputs
        # tf.shape (not x.shape) so the assert data is a real tensor even
        # when some dimensions are unknown while tracing a tf.function.
        tf.debugging.Assert(tf.rank(x) == 4, [tf.shape(x)])
        tf.debugging.Assert(tf.rank(y) == 4, [tf.shape(y)])

        if training:
            x_code = self._enc_x(x, training)
            y_code = self._enc_y(y, training)
            # Cross translations: decode each code in the other domain.
            x_hat = self._dec_x(y_code, training)
            y_hat = self._dec_y(x_code, training)
            # Cycle: translate back to the original domain.
            x_dot = self._dec_x(self._enc_y(y_hat, training), training)
            y_dot = self._dec_y(self._enc_x(x_hat, training), training)
            # Reconstructions within the same domain.
            x_tilde = self._dec_x(x_code, training)
            y_tilde = self._dec_y(y_code, training)
            zx_t_zy = ztz(image_in_patches(x_code, 20),
                          image_in_patches(y_code, 20))
            return [x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, zx_t_zy]

        x_code = self.enc_x(x, name="x_code")
        y_code = self.enc_y(y, name="y_code")
        x_tilde = self.dec_x(x_code, name="x_tilde")
        y_tilde = self.dec_y(y_code, name="y_tilde")
        x_hat = self.dec_x(y_code, name="x_hat")
        y_hat = self.dec_y(x_code, name="y_hat")
        return self._difference_img(x_tilde, y_tilde, x_hat, y_hat)
Пример #6
0
def test(DATASET="Texas", CONFIG=None):
    """
    Train a CGAN change detector on DATASET and run its final evaluation.

    1. Fetch x/y patches and the evaluation dataset (datasets.fetch_CGAN).
    2. Build the Generator / Approximator / Discriminator spec: the larger
       filter specs when a GPU is present and CONFIG["debug"] is false,
       minimal ones otherwise.
    3. Train one epoch per loop iteration on (x, y, Pu), where the prior Pu
       starts at 1 for every patch. After epoch index 10, Pu is refreshed
       from the detector's current change map before the next epoch.
    4. Call cd.final_evaluate.

    CONFIG - config dict; defaults to get_config_CGAN(DATASET). It is copied,
             so the caller's dict is not modified.
    Returns (final_kappa, epoch, training_time, timestamp), where
    training_time is the sum of the times returned by every cd.train call.
    """
    # Work on a copy: this function adds keys and forces epochs=1, which
    # must not leak into the caller's (or a shared default) config.
    CONFIG = dict(get_config_CGAN(DATASET) if CONFIG is None else CONFIG)

    bs = CONFIG["batch_size"]
    ps = CONFIG["patch_size"]
    print(f"Loading {DATASET} data")

    x, y, EVALUATE, (C_X, C_Y), tot_patches = datasets.fetch_CGAN(
        DATASET, **CONFIG)
    Pu = tf.ones(x.shape[0], dtype=tf.float32)
    batches = tot_patches // bs + (tot_patches % bs != 0)
    CONFIG.update({"tot_patches": tot_patches, "batches": batches})

    if tf.config.list_physical_devices("GPU") and not CONFIG["debug"]:
        gen_filters = [25, 100, 500, 100, C_X]
        disc_filters = [25, 100, 200, 50, 1]
    else:
        gen_filters = [25, C_X]
        disc_filters = [25, 1]
    TRANSLATION_SPEC = {
        "Generator": {
            "shapes": [ps, C_Y],
            "filter_spec": list(gen_filters),
        },
        "Approximator": {
            "shapes": [ps, C_X],
            "filter_spec": list(gen_filters),
        },
        "Discriminator": {
            "shapes": [ps, C_X],
            "filter_spec": disc_filters,
        },
    }

    print("Change Detector Init")
    cd = CGAN(TRANSLATION_SPEC, **CONFIG)
    print("Training")
    training_time = 0
    for epoch in trange(CONFIG["epochs"]):
        # The loop drives the epochs; each cd.train call runs exactly one.
        CONFIG.update(epochs=1)
        TRAIN = tf.data.Dataset.from_tensor_slices((x, y, Pu))
        TRAIN = TRAIN.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        tr_time, _ = cd.train(TRAIN, evaluation_dataset=EVALUATE, **CONFIG)
        training_time += tr_time
        if epoch > 10:
            # Refresh the prior from the current change map: patch it, then
            # take the rounded mean over the last axis of each row (assumes
            # that axis holds one patch's values - TODO confirm).
            for x_im, y_im, _ in EVALUATE.batch(1):
                Pu = cd._change_map(cd([x_im, y_im]))
            Pu = image_in_patches(Pu, ps)
            Pu = tf.reshape(Pu, [-1, Pu.shape[-1]])
            Pu = tf.round(
                tf.reduce_mean(tf.cast(Pu, dtype=tf.float32), axis=-1))

    cd.final_evaluate(EVALUATE, **CONFIG)
    final_kappa = cd.metrics_history["cohens kappa"][-1]
    timestamp = cd.timestamp
    epoch = cd.epoch.numpy()
    return final_kappa, epoch, training_time, timestamp