def _train_step(self, x, y, clw):
    """Perform one two-phase optimisation step on a batch of patches.

    Input:
        x - tensor of shape (bs, ps_h, ps_w, c_x)
        y - tensor of shape (bs, ps_h, ps_w, c_y)
        clw - cross_loss_weight, tensor of shape (bs, ps_h, ps_w, 1)

    Phase 1 updates only the two encoders with ``self._optimizer_k``. Its
    objective is the kernel loss plus the encoders' ``losses``. The kernel
    loss compares the model's last training output with
    ``1 - Degree_matrix`` of the inputs' size-20 patches.

    Phase 2 repeats the forward pass and updates encoders and decoders with
    ``self._optimizer_all``. Its objective is the weighted cycle, cross and
    reconstruction losses plus the ``losses`` of all four sub-networks.

    In both phases gradients are clipped by global norm whenever
    ``self.clipnorm`` is not None. ``self.train_metrics`` is updated last.
    """

    def _clip_and_apply(tape, objective, variables, optimizer):
        # Differentiate, optionally clip by global norm, then apply.
        grads = tape.gradient(objective, variables)
        if self.clipnorm is not None:
            grads, _ = tf.clip_by_global_norm(grads, self.clipnorm)
        optimizer.apply_gradients(zip(grads, variables))

    # Phase 1: kernel loss, encoders only.
    with tf.GradientTape() as kernel_tape:
        (_, _, _, _, _, _, zx_t_zy) = self([x, y], training=True)
        kernel_target = 1.0 - Degree_matrix(
            image_in_patches(x, 20), image_in_patches(y, 20))
        kernels_loss = self.kernels_lambda * self.loss_object(
            kernel_target, zx_t_zy)
        encoder_reg = sum(self._enc_x.losses) + sum(self._enc_y.losses)
        # The sum must be recorded on the tape to be differentiable.
        kernel_objective = kernels_loss + encoder_reg
        # Variables are read after the forward pass, in case the
        # sub-networks only create them on first call.
        encoder_vars = (self._enc_x.trainable_variables
                        + self._enc_y.trainable_variables)
    _clip_and_apply(kernel_tape, kernel_objective, encoder_vars,
                    self._optimizer_k)

    # Phase 2: translation losses, all sub-networks.
    with tf.GradientTape() as main_tape:
        x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, _ = self(
            [x, y], training=True)
        l2_loss = (sum(self._enc_x.losses) + sum(self._enc_y.losses)
                   + sum(self._dec_x.losses) + sum(self._dec_y.losses))
        cycle_x_loss = self.cycle_lambda * self.loss_object(x, x_dot)
        cross_x_loss = self.cross_lambda * self.loss_object(y, y_hat, clw)
        recon_x_loss = self.recon_lambda * self.loss_object(x, x_tilde)
        cycle_y_loss = self.cycle_lambda * self.loss_object(y, y_dot)
        cross_y_loss = self.cross_lambda * self.loss_object(x, x_hat, clw)
        recon_y_loss = self.recon_lambda * self.loss_object(y, y_tilde)
        total_loss = (cycle_x_loss + cross_x_loss + recon_x_loss
                      + cycle_y_loss + cross_y_loss + recon_y_loss + l2_loss)
        all_vars = (self._enc_x.trainable_variables
                    + self._enc_y.trainable_variables
                    + self._dec_x.trainable_variables
                    + self._dec_y.trainable_variables)
    _clip_and_apply(main_tape, total_loss, all_vars, self._optimizer_all)

    for metric_name, value in (
            ("cycle_x", cycle_x_loss),
            ("cross_x", cross_x_loss),
            ("recon_x", recon_x_loss),
            ("cycle_y", cycle_y_loss),
            ("cross_y", cross_y_loss),
            ("recon_y", recon_y_loss),
            ("krnls", kernels_loss),
            ("l2", l2_loss),
            ("total", total_loss),
    ):
        self.train_metrics[metric_name].update_state(value)
def fetch_CGAN(name, **kwargs):
    """Load dataset ``name`` and cut it into patches for CGAN training.

    Input:
        name - dataset name, should be in DATASETS
        kwargs - config {key: value} pairs. Key should be in
            DATASET_DEFAULT_CONFIG; only "patch_size" is read here.
    Output:
        x_patches - x image (after remove_borders) cut by image_in_patches
        y_patches - y image (after remove_borders) cut by image_in_patches
        evaluation_data - tf.data.Dataset with a single (x, y, change_map)
            element, each of shape (h, w, ?); batch it (e.g. ``.batch(1)``)
            to get (1, h, w, ?)
        channels - tuple (c_x, c_y), number of channels for domains x and y
        tot_patches - size of the first dimension of x_patches
    Raises:
        ValueError - if "patch_size" is missing from kwargs.
    """
    ps = kwargs.get("patch_size")
    if ps is None:
        raise ValueError("fetch_CGAN requires a 'patch_size' keyword argument")
    # NOTE(review): the loader's first output is bound to y and the second to
    # x -- presumably intentional for this model; confirm against DATASETS.
    y_im, x_im, target_cm = DATASETS[name](prepare_data[name])
    images = [x_im, y_im, target_cm]
    if not tf.config.list_physical_devices("GPU"):
        # Without a GPU keep only the central 10% of each spatial dimension
        # so CPU runs stay small.
        images = [tf.image.central_crop(tensor, 0.1) for tensor in images]
    chs = [tensor.shape[-1] for tensor in images]
    images = [tf.expand_dims(remove_borders(tensor, ps), 0) for tensor in images]
    evaluation_data = tf.data.Dataset.from_tensor_slices(tuple(images))
    # The change map is only used for evaluation, so only x and y are patched.
    x_patches = image_in_patches(images[0], ps)
    y_patches = image_in_patches(images[1], ps)
    tot_patches = x_patches.shape[0]
    return x_patches, y_patches, evaluation_data, (chs[0], chs[1]), tot_patches
def gen(self, input, gen_chs, batch_size=200):
    """Apply the generator to a whole image, patch by patch.

    Wraps the ``self._gen`` call for TensorBoard printing and image save.
    The image is cut into ``self.ps``-sized patches with image_in_patches.
    The patches are fed to ``self._gen`` in chunks of at most ``batch_size``,
    and the outputs are stitched back together with depth_to_space.

    Input:
        input - image tensor; dims 1 and 2 are read as height and width
            (presumably shape (1, h, w, c) with h, w multiples of self.ps)
        gen_chs - number of channels produced by self._gen; only used to
            shape the empty result when there are no patches
        batch_size - maximum number of patches per self._gen call
    Output:
        tensor reshaped to (1, h, w, -1)
    """
    h, w = input.shape[1], input.shape[2]
    patches = image_in_patches(input, self.ps)
    n_patches = patches.shape[0]
    # Stepping by batch_size gives exactly ceil(n_patches / batch_size)
    # chunks, none of them empty; the last slice is simply shorter.
    outputs = [
        self._gen(patches[start:start + batch_size])
        for start in range(0, n_patches, batch_size)
    ]
    # Concatenate once instead of growing a tensor inside the loop.
    if outputs:
        generated = tf.concat(outputs, 0)
    else:
        generated = tf.zeros([0, self.ps, self.ps, gen_chs], dtype=tf.float32)
    grid = tf.reshape(generated, [1, h // self.ps, w // self.ps, -1])
    image = tf.nn.depth_to_space(grid, self.ps)
    return tf.reshape(image, [1, h, w, -1])
def approx(self, input, batch_size=200):
    """Apply the approximator to a whole image, patch by patch.

    Wraps the ``self._approx`` call for TensorBoard printing and image save.
    The image is cut with image_in_patches and reshaped to
    (-1, ps, ps, ch) patches. The patches are fed to ``self._approx`` in
    chunks of at most ``batch_size``, and the outputs are stitched back with
    depth_to_space.

    Input:
        input - image tensor whose last dimension is the channel count ch
        batch_size - maximum number of patches per self._approx call
    Output:
        depth_to_space of the approximated patches, block size self.ps
    """
    ch = input.shape[-1]
    patches = image_in_patches(input, self.ps)
    # NOTE(review): the patch grid is read from dims 1-2 of image_in_patches'
    # output, whereas gen() treats dim 0 of that output as the patch count.
    # The two methods assume different output layouts; confirm which holds.
    grid_h, grid_w = patches.shape[1], patches.shape[2]
    patches = tf.reshape(patches, [-1, self.ps, self.ps, ch])
    n_patches = patches.shape[0]
    # Exactly ceil(n_patches / batch_size) non-empty chunks.
    outputs = [
        self._approx(patches[start:start + batch_size])
        for start in range(0, n_patches, batch_size)
    ]
    # Concatenate once instead of growing a tensor inside the loop.
    if outputs:
        approximated = tf.concat(outputs, 0)
    else:
        approximated = tf.zeros([0, self.ps, self.ps, ch], dtype=tf.float32)
    return tf.nn.depth_to_space(
        tf.reshape(approximated, [1, grid_h, grid_w, -1]), self.ps)
def __call__(self, inputs, training=False):
    """Run the model on a pair of image batches.

    Input:
        inputs - pair (x, y); both are asserted to be rank-4 tensors
        training - selects which outputs are produced (see below)
    Output:
        training=True: list [x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde,
            zx_t_zy], where
            - x_hat = _dec_x(_enc_y(y)) and y_hat = _dec_y(_enc_x(x));
            - x_dot and y_dot translate y_hat and x_hat back again;
            - x_tilde = _dec_x(_enc_x(x)) and y_tilde = _dec_y(_enc_y(y));
            - zx_t_zy = ztz of the size-20 patches of the two codes.
        training=False: self._difference_img(x_tilde, y_tilde, x_hat, y_hat),
            computed through the named enc_*/dec_* wrappers.
    """
    x, y = inputs
    # Report the dynamic shape. Assert converts its data to tensors when
    # building a graph, and a static TensorShape with unknown dimensions
    # (e.g. a None batch size while tracing) cannot be converted.
    tf.debugging.Assert(tf.rank(x) == 4, [tf.shape(x)])
    tf.debugging.Assert(tf.rank(y) == 4, [tf.shape(y)])
    if training:
        x_code, y_code = self._enc_x(x, training), self._enc_y(y, training)
        x_hat, y_hat = self._dec_x(y_code, training), self._dec_y(
            x_code, training)
        # Cycle: translate each translation back to its source domain.
        x_dot, y_dot = (
            self._dec_x(self._enc_y(y_hat, training), training),
            self._dec_y(self._enc_x(x_hat, training), training),
        )
        x_tilde, y_tilde = (
            self._dec_x(x_code, training),
            self._dec_y(y_code, training),
        )
        zx_t_zy = ztz(image_in_patches(x_code, 20),
                      image_in_patches(y_code, 20))
        retval = [x_hat, y_hat, x_dot, y_dot, x_tilde, y_tilde, zx_t_zy]
    else:
        x_code, y_code = self.enc_x(x, name="x_code"), self.enc_y(
            y, name="y_code")
        x_tilde, y_tilde = (
            self.dec_x(x_code, name="x_tilde"),
            self.dec_y(y_code, name="y_tilde"),
        )
        x_hat, y_hat = (
            self.dec_x(y_code, name="x_hat"),
            self.dec_y(x_code, name="y_hat"),
        )
        retval = self._difference_img(x_tilde, y_tilde, x_hat, y_hat)
    return retval
def test(DATASET="Texas", CONFIG=None):
    """Train and evaluate a CGAN change detector on one dataset.

    1. Fetch training patches (x, y), the evaluation dataset
       (x, y, change_map), the channel counts and the patch count with
       datasets.fetch_CGAN.
    2. Build TRANSLATION_SPEC: the larger filter specs when a GPU is present
       and CONFIG["debug"] is falsy, tiny ones otherwise.
    3. Train one epoch per loop iteration on (x, y, Pu). Pu starts as all
       ones. After epoch 10 it is recomputed from cd's current change map as
       the rounded mean over the last axis of
       image_in_patches(change_map, patch_size).
    4. Run cd.final_evaluate on the evaluation dataset.

    Input:
        DATASET - dataset name
        CONFIG - config dict; defaults to get_config_CGAN(DATASET). A shallow
            copy is used, so the caller's dict is not modified.
    Output:
        final_kappa - last "cohens kappa" entry of cd.metrics_history
        epoch - value of cd.epoch after training
        training_time - sum of the first value returned by every cd.train call
        timestamp - cd.timestamp
    """
    if CONFIG is None:
        CONFIG = get_config_CGAN(DATASET)
    # Work on a copy: "epochs" is overwritten and keys are added below.
    CONFIG = dict(CONFIG)
    bs = CONFIG["batch_size"]
    ps = CONFIG["patch_size"]
    print(f"Loading {DATASET} data")
    x, y, EVALUATE, (C_X, C_Y), tot_patches = datasets.fetch_CGAN(
        DATASET, **CONFIG)
    Pu = tf.ones(x.shape[0], dtype=tf.float32)
    # Ceiling division: count a final partial batch.
    batches = tot_patches // bs + (tot_patches % bs != 0)
    CONFIG.update({"tot_patches": tot_patches, "batches": batches})
    if tf.config.list_physical_devices("GPU") and not CONFIG["debug"]:
        TRANSLATION_SPEC = {
            "Generator": {
                "shapes": [ps, C_Y],
                "filter_spec": [25, 100, 500, 100, C_X],
            },
            "Approximator": {
                "shapes": [ps, C_X],
                "filter_spec": [25, 100, 500, 100, C_X],
            },
            "Discriminator": {
                "shapes": [ps, C_X],
                "filter_spec": [25, 100, 200, 50, 1],
            },
        }
    else:
        TRANSLATION_SPEC = {
            "Generator": {
                "shapes": [ps, C_Y],
                "filter_spec": [25, C_X],
            },
            "Approximator": {
                "shapes": [ps, C_X],
                "filter_spec": [25, C_X],
            },
            "Discriminator": {
                "shapes": [ps, C_X],
                "filter_spec": [25, 1]
            },
        }
    print("Change Detector Init")
    cd = CGAN(TRANSLATION_SPEC, **CONFIG)
    print("Training")
    training_time = 0
    for epoch in trange(CONFIG["epochs"]):
        # cd.train runs a single epoch per call so Pu can be refreshed.
        CONFIG.update(epochs=1)
        TRAIN = tf.data.Dataset.from_tensor_slices((x, y, Pu))
        TRAIN = TRAIN.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        tr_time, _ = cd.train(TRAIN, evaluation_dataset=EVALUATE, **CONFIG)
        training_time += tr_time
        if epoch > 10:
            for x_im, y_im, _ in EVALUATE.batch(1):
                Pu = cd._change_map(cd([x_im, y_im]))
                Pu = image_in_patches(Pu, ps)
                Pu = tf.reshape(Pu, [-1, Pu.shape[-1]])
                Pu = tf.round(
                    tf.reduce_mean(tf.cast(Pu, dtype=tf.float32), axis=-1))
    cd.final_evaluate(EVALUATE, **CONFIG)
    final_kappa = cd.metrics_history["cohens kappa"][-1]
    timestamp = cd.timestamp
    epoch = cd.epoch.numpy()
    return final_kappa, epoch, training_time, timestamp