class KAST(tf.keras.Model):
    """Key/value Attention with Spatio-Temporal memory (KAST).

    Propagates a low-resolution value map (e.g. color) through a video
    sequence: ResNet features ("keys") of consecutive frames are matched
    with a local correlation window, and the resulting soft-attention
    weights copy values forward frame by frame.  An external ``Memory``
    module and an RKN scorer support alternative code paths
    (``call_Patch_Memory``, ``call_RKN``, ``call_Score``).

    Inputs to the forward paths are pairs
    ``[(bs, T, H, W, C) raw frames, (bs, T, h, w, cv) values]``
    where ``(h, w)`` is the feature resolution (``H//4, W//4`` in practice).
    """

    def __init__(self, coef_memory=0.1, dropout_seq=0.9):
        super(KAST, self).__init__()
        # Side length of the local matching window (odd, so it has a center).
        self.kernel = 13
        self.dropout_seq = dropout_seq
        self.transformation = Transformation(trainable=False)
        self.resnet = ResNet()
        self.rkn = RKNModel()
        self.memory = Memory(unit=200, kernel=self.kernel)
        # Dense local correlation between two feature maps (stride 1).
        self.corr_cost = tfa.layers.CorrelationCost(
            kernel_size=1,
            max_displacement=self.kernel // 2,
            stride_1=1,
            stride_2=1,
            pad=self.kernel // 2,
            data_format="channels_last")
        # Dilated variant (stride 2) covering twice the displacement range.
        self.corr_cost_stride = tfa.layers.CorrelationCost(
            kernel_size=1,
            max_displacement=(self.kernel // 2) * 2,
            stride_1=1,
            stride_2=2,
            pad=(self.kernel // 2) * 2,
            data_format="channels_last")
        self.coef_memory = coef_memory
        self.description = 'KAST'
        # Cross-call state, cleared by `reset_mem` (which `call` invokes
        # unless keep=True is passed).
        self.mem_write = True
        self.mem0 = None
        self.mem5 = None
        self.k0 = None
        self.v0 = None
        self.last_v = None

    def call(self, inputs, **kwargs):
        """Propagate values through the whole sequence, frame by frame.

        Args:
            inputs: ``[(bs, T, H, W, C) raw frames, (bs, T, h, w, cv) values]``.
            **kwargs: forwarded to the Transformation layer.  The key
                ``'keep'`` (bool, default False) preserves cross-call state
                (``k0``/``v0``/``last_v``/memory) instead of resetting it.

        Returns:
            Tuple of (predicted values ``(bs, T-1, h, w, cv)``,
            ground-truth values ``(bs, T, h, w, cv)`` — frame 0 included,
            dropped-out input frames as produced by Transformation).
        """
        keep = kwargs['keep'] if 'keep' in kwargs else False
        bs = inputs[1].shape[0]
        self.memory.batch_shape = bs
        seq_size = inputs[1].shape[1]
        self.transformation.batch_shape = bs
        self.transformation.seq_size = seq_size
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        cv = inputs[1].shape[4]
        self.memory.hw_shape = h * w
        output_v = []
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('Transformation'):
            # seq_mask marks which frames survived the sequence dropout.
            i_drop, seq_mask = self.transformation(i_raw, **kwargs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_drop, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        ck = k.shape[4]
        # Cache the first frame's key/value for long-range matching across
        # calls; only refreshed after reset_mem.
        if self.k0 is None:
            self.k0 = k[:, 0]
        if self.v0 is None:
            self.v0 = v[:, 0]
        if self.last_v is None:
            self.last_v = v[:, 0]
        self.memory.get_init_state(bs, cv)
        self.memory.call_init((tf.reshape(k[:, 0], [bs, h * w, ck]),
                               tf.reshape(self.last_v, [bs, h * w, cv])))
        all_m_kv = []
        all_previous_v = [self.last_v]
        ground_truth = [tf.reshape(v[:, 0], [-1, 1, h, w, cv])]
        for i in range(1, seq_size):
            # Write only the first few frames into the external memory, and
            # only on the first call after a reset (mem_write flag).
            if i < 7 and self.mem_write:
                with tf.name_scope('Memory'):
                    m_kv = self.memory.call(
                        (tf.reshape(k[:, i - 1], [bs, h * w, ck]),
                         tf.reshape(all_previous_v[i - 1], [bs, h * w, cv])))
                    all_m_kv.append(m_kv)
            # Local correlation of current vs. previous frame keys; the
            # *256.0 rescales CorrelationCost's channel-mean normalization.
            corr_prev_one = self.corr_cost([k[:, i], k[:, i - 1]]) * 256.0
            corr_prev = tf.reshape(corr_prev_one,
                                   [bs, h * w, self.kernel**2])
            patch_v1 = tf.image.extract_patches(
                images=tf.reshape(all_previous_v[i - 1], [-1, h, w, cv]),
                sizes=[1, self.kernel, self.kernel, 1],
                strides=[1, 1, 1, 1],
                rates=[1, 1, 1, 1],
                padding="SAME")
            patch_v = tf.reshape(patch_v1,
                                 [bs, h * w, self.kernel**2, cv])
            # NOTE(review): a disabled variant that additionally attended
            # over the memory bank and frames i-3 / frame 0 previously lived
            # here as dead code inside a string literal; removed.
            corr_sim = tf.expand_dims(tf.nn.softmax(corr_prev, -1), -2)
            output_v_i = corr_sim @ patch_v  # (bs, hw, 1, cv)
            # Teacher forcing: use ground truth where the frame was kept by
            # the dropout mask, otherwise feed the prediction forward.
            previous_v = tf.where(
                tf.reshape(seq_mask[:, i], [bs, 1, 1, 1]),
                v[:, i],
                tf.reshape(output_v_i, [-1, h, w, cv]))
            all_previous_v.append(previous_v)
            output_v_i = tf.reshape(output_v_i, [-1, 1, h, w, cv])
            output_v.append(output_v_i)
            ground_truth.append(tf.reshape(v[:, i], [-1, 1, h, w, cv]))
        self.mem_write = False
        self.last_v = output_v_i
        if not keep:
            self.reset_mem()
        return tf.concat(output_v, 1), tf.concat(ground_truth, 1), i_drop

    def call_Patch_Memory(self, inputs, **kwargs):
        """Forward path that reconstructs each next frame from the external
        patch memory (keys/values scored by the RKN).

        Args:
            inputs: ``[(bs, T, H, W, C) raw frames, (bs, T, h, w, cv) values]``.

        Returns:
            (predictions ``(bs, T-1, h, w, cv)``,
             ground truth ``(bs, T-1, h, w, cv)``, dropped-out frames).
        """
        bs = inputs[1].shape[0]
        seq_size = inputs[1].shape[1]
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        cv = inputs[1].shape[4]
        output_v = []
        ground_truth = []
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('Transformation'):
            i_drop, seq_mask = self.transformation(i_raw, **kwargs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_drop, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        ck = k.shape[4]
        with tf.name_scope('Rkn'):
            score = self.rkn((k, tf.reshape(seq_mask, [bs, seq_size, 1])))
        previous_v = v[:, 0]
        self.memory.get_init_state(bs, cv)
        self.memory.call_init(
            (tf.reshape(k[:, 0], [bs, h * w, ck]),
             tf.reshape(previous_v, [bs, h * w, cv]),
             tf.reshape(score[:, 0], [bs, h * w])), bs)
        for i in range(seq_size - 1):
            with tf.name_scope('Memory'):
                m_kv = self.memory.call(
                    (tf.reshape(k[:, i], [bs, h * w, ck]),
                     tf.reshape(previous_v, [bs, h * w, cv]),
                     tf.reshape(score[:, i], [bs, h * w])))
            # (bs, m, kernel**2 * k), (bs, m, kernel**2 * v)
            m_k, m_v = tf.nest.flatten(m_kv)
            with tf.name_scope('Similarity_matrix'):
                output_v_i = self._get_output_patch(
                    m_k, tf.reshape(k[:, i + 1], [-1, h * w, ck]), m_v)
            # NOTE(review): mask index i+1 is paired with ground truth
            # v[:, i]; if intentional it feeds the *previous* frame's truth
            # on kept frames — confirm, v[:, i + 1] may have been intended.
            previous_v = tf.where(
                tf.reshape(seq_mask[:, i + 1], [bs, 1, 1, 1]),
                v[:, i],
                tf.reshape(output_v_i, [-1, h, w, cv]))
            output_v_i = tf.reshape(output_v_i, [-1, 1, h, w, cv])
            output_v.append(output_v_i)
            ground_truth.append(tf.reshape(v[:, i + 1], [-1, 1, h, w, cv]))
        return tf.concat(output_v, 1), tf.concat(ground_truth, 1), i_drop

    def call_ResNet(self, inputs, **kwargs):
        """Two-frame sanity path: reconstruct frame 1's values from frame 0
        via a *global* (all-pairs) affinity over ResNet keys.

        Returns:
            (reconstruction ``(bs, h, w, cv)``, ground truth ``v[:, 1]``).
        """
        bs = inputs[1].shape[0]
        seq_size = inputs[1].shape[1]
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        cv = inputs[1].shape[4]
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('Transformation'):
            # NOTE(review): unlike `call`, the result is not unpacked into
            # (frames, mask) here — confirm Transformation's return arity
            # for this kwargs combination.
            i_drop = self.transformation(i_raw, **kwargs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_drop, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        ck = k.shape[4]
        with tf.name_scope('Similarity_K'):
            similarity_k = self._get_affinity_matrix(
                tf.reshape(k[:, 0], [-1, h * w, ck]),
                tf.reshape(k[:, 1], [-1, h * w, ck]))  # (bs, h*w, h*w)
        reconstruction_k = tf.reshape(
            similarity_k @ tf.reshape(v[:, 0], [-1, h * w, cv]),
            [-1, h, w, cv])
        ground_truth = v[:, 1]
        return reconstruction_k, ground_truth

    def call_ResNet_Local(self, inputs, **kwargs):
        """Two-frame sanity path with a *local* correlation window.

        Returns:
            (reconstruction ``(bs, h, w, cv)``, ground truth ``v[:, 1]``,
             reference values ``v[:, 0]``).
        """
        bs = inputs[1].shape[0]
        seq_size = inputs[1].shape[1]
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        cv = inputs[1].shape[4]
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('Transformation'):
            i_drop, _ = self.transformation(i_raw, **kwargs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_drop, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        # FIX: reuse the layer built in __init__ (identical parameters)
        # instead of constructing a new CorrelationCost on every call.
        similarity_k = self.corr_cost([k[:, 1], k[:, 0]])  # (bs, hw, patch)
        similarity_k = tf.reshape(similarity_k,
                                  [bs, h * w, self.kernel * self.kernel])
        similarity_k = tf.nn.softmax(similarity_k, axis=-1)
        similarity_k = tf.reshape(similarity_k,
                                  [bs, h * w, 1, self.kernel * self.kernel])
        v_patch = tf.image.extract_patches(
            tf.reshape(v[:, 0], [bs, h, w, cv]),
            sizes=[1, self.kernel, self.kernel, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding="SAME")
        reconstruction_v = similarity_k @ tf.reshape(
            v_patch, [bs, h * w, self.kernel * self.kernel, cv])
        reconstruction_v = tf.reshape(reconstruction_v, [bs, h, w, cv])
        ground_truth = v[:, 1]
        return reconstruction_v, ground_truth, v[:, 0]

    def call_RKN(self, inputs, **kwargs):
        """Run the RKN over raw-frame keys with a random 20%-keep mask.

        Returns:
            (RKN-smoothed keys, raw ResNet keys ``(bs, T, h, w, 256)``).
        """
        bs = inputs[1].shape[0]
        seq_size = inputs[1].shape[1]
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_raw, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        mask = np.random.binomial(1, 0.2, [bs, seq_size])
        mask[:, 0] = 1  # the first frame is always observed
        mask = tf.cast(mask, tf.bool)
        mask = tf.reshape(mask, [bs, seq_size, 1])
        with tf.name_scope('Rkn'):
            rkn_k = self.rkn((k, mask))
        return rkn_k, k

    def call_Score(self, inputs, **kwargs):
        """Score keys with the RKN and push them through the memory.

        Returns:
            (memory RKN score of frame 0, memory usage at step 4).
        """
        bs = inputs[1].shape[0]
        seq_size = inputs[1].shape[1]
        H = inputs[0].shape[2]
        W = inputs[0].shape[3]
        C = inputs[0].shape[4]
        h = inputs[1].shape[2]
        w = inputs[1].shape[3]
        cv = inputs[1].shape[4]
        i_raw, v = tf.nest.flatten(inputs)
        with tf.name_scope('ResNet'):
            k = tf.reshape(self.resnet(tf.reshape(i_raw, [-1, H, W, C])),
                           [bs, seq_size, h, w, 256])  # (bs, T, h, w, 256)
        mask = np.random.binomial(1, 0.9, [bs, seq_size])
        mask[:, 0] = 1  # the first frame is always observed
        mask = tf.cast(mask, tf.bool)
        mask = tf.reshape(mask, [bs, seq_size, 1])
        with tf.name_scope('Rkn'):
            rkn_score = self.rkn((k, mask))
        # Keep only frame 0's score; zero the rest.
        mask_score = tf.concat(
            [tf.ones([bs, 1, h * w, 1]),
             tf.zeros([bs, seq_size - 1, h * w, 1])], 1)
        k = tf.reshape(k, [bs, seq_size, h * w, 256])
        v = tf.reshape(v, [bs, seq_size, h * w, cv])
        rkn_score = tf.reshape(rkn_score,
                               [bs, seq_size, h * w, 1]) * mask_score
        with tf.name_scope("Memory"):
            mem = self.memory((k, v, rkn_score))
        m_k, m_v, m_u, m_rkn_score = tf.nest.flatten(mem)
        m_u = tf.expand_dims(m_u, -1)
        return m_rkn_score[:, 0], m_u[:, 4]

    def _get_affinity_matrix(self, ref, tar):
        """Softmax affinity of each target key over all reference keys.

        Args:
            ref: (bs, n_ref, ck) reference keys.
            tar: (bs, n_tar, ck) target keys.

        Returns:
            (bs, n_tar, n_ref) row-stochastic similarity matrix.
        """
        ref_transpose = tf.transpose(ref, [0, 2, 1])  # (bs, ck, n_ref)
        inner_product = tar @ ref_transpose
        return tf.nn.softmax(inner_product, -1)

    def _get_output_patch(self, m_k, k_next, m_v):
        """Attend each next-frame key over its best-matching memory patch.

        Args:
            m_k: (bs, m, kernel**2, ck) memory key patches.
            k_next: (bs, h*w, ck) next-frame keys.
            m_v: (bs, m, kernel**2, cv) memory value patches.

        Returns:
            (bs, h*w, 1, cv) reconstructed values.
        """
        # BUG FIX: the center of a flattened kernel x kernel patch is index
        # (kernel**2) // 2 (84 for 13x13, indices 0..168); the original used
        # (kernel**2) // 2 + 1, one pixel off-center.
        m_k_patch_center = m_k[:, :, (self.kernel**2) // 2, :]
        ref_transpose = tf.transpose(m_k_patch_center, [0, 2, 1])  # (bs, ck, m)
        inner_product = k_next @ ref_transpose  # (bs, h*w, m)
        max_patch = tf.argmax(inner_product, -1)  # best memory slot per query
        # Gather the winning patch's keys/values for every query position.
        m_k_one_patch = tf.gather(m_k, max_patch, batch_dims=1,
                                  axis=1)  # (bs, hw, kernel**2, ck)
        m_v_one_patch = tf.gather(m_v, max_patch, batch_dims=1, axis=1)
        # (bs, hw, 1, ck) @ (bs, hw, ck, kernel**2) = (bs, hw, 1, kernel**2)
        inner_product = tf.expand_dims(k_next, -2) @ tf.transpose(
            m_k_one_patch, [0, 1, 3, 2])
        similarity = tf.nn.softmax(inner_product, -1)
        # (bs, hw, 1, kernel**2) @ (bs, hw, kernel**2, cv) = (bs, hw, 1, cv)
        return similarity @ m_v_one_patch

    def set_coef_memory(self, coef_memory):
        """Set `coef_memory`, clamped to [0, 1].

        BUG FIX: the original only wrote the attribute for out-of-range
        inputs and silently ignored valid values in [0, 1].
        """
        self.coef_memory = min(max(coef_memory, 0), 1)

    def set_dropout_seq(self, dropout_seq):
        """Set `dropout_seq`, clamped to [0, 1].

        BUG FIX: the original only wrote the attribute for out-of-range
        inputs.  The raw (unclamped) argument is still returned for
        backward compatibility with existing callers.
        """
        self.dropout_seq = min(max(dropout_seq, 0), 1)
        return dropout_seq

    def log_normal_pdf(self, sample, mean, logvar):
        """Element-wise log-density of a diagonal Gaussian N(mean, e^logvar)."""
        log2pi = tf.math.log(2. * np.pi)
        return -.5 * ((sample - mean)**2. * tf.exp(-logvar)
                      + logvar + log2pi)

    def compute_loss(self, inputs):
        """Smooth-L1 (Huber, delta=1) loss between upsampled predictions
        and the full-resolution frames (frame 0 excluded).
        """
        seq_size = inputs.shape[1]
        H = inputs.shape[2]
        W = inputs.shape[3]
        cv = inputs.shape[4]
        h = H // 4
        w = W // 4
        v = tf.reshape(inputs, [-1, H, W, cv])
        v_input = tf.image.resize(v, [h, w])
        v_input = tf.reshape(v_input, [-1, seq_size, h, w, cv])
        output_v, v_j, _ = self.call((inputs, v_input), training=True)
        output_v = tf.reshape(output_v, [-1, h, w, cv])
        output_v = tf.image.resize(output_v, [H, W])
        output_v = tf.reshape(output_v, [-1, seq_size - 1, H, W, cv])
        v = tf.reshape(v, [-1, seq_size, H, W, cv])[:, 1:]
        # Renamed from `abs`, which shadowed the builtin.
        diff = tf.math.abs(output_v - v)
        return tf.reduce_mean(
            tf.where(diff < 1., 0.5 * diff * diff, diff - 0.5))

    def compute_accuracy(self, inputs):
        """Mean squared error of upsampled predictions (evaluation mode)."""
        seq_size = inputs.shape[1]
        H = inputs.shape[2]
        W = inputs.shape[3]
        cv = inputs.shape[4]
        h = H // 4
        w = W // 4
        v_input = tf.reshape(inputs, [-1, H, W, cv])
        v_input = tf.image.resize(v_input, [h, w])
        v_input = tf.reshape(v_input, [-1, seq_size, h, w, cv])
        output_v, v_j, _ = self.call((inputs, v_input), training=False)
        output_v = tf.reshape(output_v, [-1, h, w, cv])
        output_v = tf.image.resize(output_v, [H, W])
        output_v = tf.reshape(output_v, [-1, seq_size - 1, H, W, cv])
        v = tf.reshape(inputs, [-1, seq_size, H, W, cv])[:, 1:]
        return tf.reduce_mean(tf.square(output_v - v))

    def compute_apply_gradients(self, x, optimizer):
        """One optimization step; only the ResNet weights are trained."""
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x)
        gradients = tape.gradient(loss, self.resnet.trainable_variables)
        optimizer.apply_gradients(
            zip(gradients, self.resnet.trainable_variables))
        return loss

    def reconstruct_ResNet(self, inputs, training=True):
        """Two-frame local reconstruction helper used for evaluation."""
        seq_size = inputs.shape[1]
        H = inputs.shape[2]
        W = inputs.shape[3]
        cv = inputs.shape[4]
        h = H // 4
        w = W // 4
        v = tf.reshape(inputs, [-1, H, W, cv])
        v = tf.image.resize(v, [h, w])
        v = tf.reshape(v, [-1, seq_size, h, w, cv])
        return self.call_ResNet_Local((inputs, v), training=training)

    def reconstruct(self, inputs, v_inputs=None, training=True, keep=False):
        """Full-sequence reconstruction at input resolution.

        Args:
            inputs: (bs, T, H, W, C) raw frames.
            v_inputs: optional (bs, T, H, W, cv) value maps; when None the
                frames themselves are used as values (self-colorization).
            training: forwarded to `call`.
            keep: preserve cross-call memory state (see `call`).

        Returns:
            (upsampled predictions ``(bs, T-1, H, W, cv)``,
             ground-truth values ``(bs, T, H, W, cv)``,
             dropped-out input frames ``(bs, T, H, W, C)``).
        """
        seq_size = inputs.shape[1]
        H = inputs.shape[2]
        W = inputs.shape[3]
        c_inp = inputs.shape[4]
        h = H // 4
        w = W // 4
        if v_inputs is None:
            cv = c_inp
            v = tf.reshape(inputs, [-1, H, W, cv])
            v_input = tf.image.resize(v, [h, w])
        else:
            cv = v_inputs.shape[4]
            v = tf.reshape(v_inputs, [-1, H, W, cv])
            # Nearest-neighbor keeps discrete labels (e.g. masks) intact.
            v_input = tf.image.resize(v, [h, w], 'nearest')
        v_input = tf.reshape(v_input, [-1, seq_size, h, w, cv])
        output_v, v_j, drop_out = self.call((inputs, v_input),
                                            training=training, keep=keep)
        drop_out = tf.reshape(drop_out, [-1, seq_size, H, W, c_inp])
        output_v = tf.reshape(output_v, [-1, h, w, cv])
        output_v = tf.image.resize(output_v, [H, W])
        output_v = tf.reshape(output_v, [-1, seq_size - 1, H, W, cv])
        v = tf.reshape(v, [-1, seq_size, H, W, cv])
        return output_v, v, drop_out

    def reset_mem(self):
        """Clear all cross-call state so the next `call` starts fresh."""
        self.mem_write = True
        self.mem0 = None
        self.mem5 = None
        self.k0 = None
        self.v0 = None
        self.last_v = None