def forward(self, x):
    """Apply the feature-reweighting residual block to ``x``.

    The ``conv1`` features are averaged over their last two axes (W, then H).
    The pooled vector goes through ``self.linear`` to give one weight per
    channel. The features are then concatenated (axis 1) with the same
    features multiplied by those weights. ``conv2`` fuses the concatenation,
    the input is added back, and the sum goes through ``self.lrelu``.

    Shape notes follow the original comments and assume ``x`` is
    [B, C, H, W] with conv1/conv2 preserving C channels -- TODO confirm
    against the layer definitions.
    """
    feat = self.conv1(x)  # [B, C, H, W]

    # Average out the two trailing spatial axes one at a time: [B,C,H] -> [B,C].
    weights = feat
    for _ in range(2):
        weights = F.mean(weights, axis=-1, keepdims=False)
    weights = self.linear(weights)

    # Re-append two trailing singleton axes so the weights broadcast: [B, C, 1, 1].
    for _ in range(2):
        weights = F.add_axis(weights, axis=-1)

    fused = F.concat((feat, F.multiply(feat, weights)), axis=1)  # [B, 2C, H, W]
    # Drop intermediate references as soon as they are no longer needed.
    del feat, weights

    refined = self.conv2(fused)  # [B, C, H, W]
    return self.lrelu(x + refined)
def test_multiply():
    """Compare ``F.multiply`` with float32 ``np.multiply`` references.

    Covers four operand mixes: scalar*scalar, tensor*scalar, scalar*tensor
    and tensor*tensor.
    """
    # Scalar * scalar: the result is handed to assertTensorClose as-is,
    # without a ``.numpy()`` conversion.
    assertTensorClose(
        F.multiply(-3.0, -4.0),
        np.multiply(np.float32(-3.0), np.float32(-4.0)),
    )

    def _check(lhs, rhs, ref_lhs, ref_rhs):
        # Multiply with MegEngine, convert to NumPy, and compare with the
        # NumPy reference product.
        assertTensorClose(
            F.multiply(lhs, rhs).numpy(),
            np.multiply(ref_lhs, ref_rhs),
        )

    # Tensor * Python scalar.
    _check(
        tensor([3.0, 4.0]),
        4.0,
        np.array([3.0, 4.0], dtype=np.float32),
        4.0,
    )
    # Python scalar * tensor.
    _check(
        4.0,
        tensor([3.0, 4.0]),
        4.0,
        np.array([3.0, 4.0], dtype=np.float32),
    )
    # Tensor * tensor of the same shape.
    _check(
        tensor([3.0, 4.0]),
        tensor([3.0, 4.0]),
        np.array([3.0, 4.0], dtype=np.float32),
        np.array([3.0, 4.0], dtype=np.float32),
    )
def forward(self, now_LR, pre_h_SD):
    """Scale ``pre_h_SD`` by a sigmoid of a per-pixel filter response.

    ``self.conv(now_LR)`` predicts ``K*K`` filter taps for every pixel of
    every sample. For each sample, the taps are reshaped into a
    ``F.local_conv2d`` weight of shape ``(C, H, W, 1, K, K, 1)``: C groups
    with one input and one output channel each, and the same spatial filter
    broadcast to every channel. That weight is applied to the sample's
    ``pre_h_SD``. The per-sample results are concatenated, passed through a
    sigmoid, and multiplied element-wise with ``pre_h_SD``.

    Args:
        now_LR: B,3,H,W (per the original note).
        pre_h_SD: B,C,H,W (original note: C == 48; C is read from the shape).

    Returns:
        Tensor shaped like ``pre_h_SD``: ``pre_h_SD * sigmoid(similarity)``.

    Note:
        Assumes ``self.K`` is odd, so that "same" padding of ``K // 2`` keeps
        the local-conv output at H x W, the size its weight is built for.
    """
    batch, C, H, W = pre_h_SD.shape
    # "Same" padding derived from the kernel size. This was previously
    # hard-coded to 1, which only matches K == 3. Any other K yields an output
    # size different from the (H, W) baked into the weight shape below.
    # (The original author also noted "some bug with padding" here.)
    pad = self.K // 2
    kernels = self.conv(now_LR)  # [B, K*K, H, W]
    batchwise_ans = []
    for idx in range(batch):
        kernel = kernels[idx]  # [K*K, H, W]
        kernel = F.dimshuffle(kernel, (1, 2, 0))  # [H, W, K*K]
        # Split the K*K taps row-major into a K x K window. The singleton axes
        # follow the grouped local-conv weight layout
        # (groups, OH, OW, IC/groups, KH, KW, OC/groups).
        kernel = F.reshape(kernel, (H, W, 1, self.K, self.K, 1))
        # Share one spatial filter across all C channel groups.
        kernel = F.broadcast_to(kernel, (C, H, W, 1, self.K, self.K, 1))
        batchwise_ans.append(
            F.local_conv2d(
                F.add_axis(pre_h_SD[idx], 0),  # [1, C, H, W]
                kernel,
                [1, 1],  # stride
                [pad, pad],  # padding
                [1, 1],  # dilation
            )
        )  # [1, C, H, W]
    similarity_matrix = F.concat(batchwise_ans, axis=0)  # [B, C, H, W]
    del batchwise_ans  # release the per-sample results early
    similarity_matrix = F.sigmoid(similarity_matrix)
    return F.multiply(pre_h_SD, similarity_matrix)
def forward(self, now_LR, pre_h_SD):
    """Scale ``pre_h_SD`` by a sigmoid of a per-pixel filter response.

    ``self.conv(now_LR)`` predicts ``K*K`` weight maps, one per filter tap.
    ``pre_h_SD`` is padded with ``add_H_W_Padding(margin=K // 2)``, a helper
    defined elsewhere (presumably padding H and W on both sides -- verify).
    For each tap (i, j), the tap's weight map is multiplied element-wise with
    the padded tensor's H x W window starting at (i, j), and the products are
    summed. The sum goes through a sigmoid and scales the
    ``[margin:H+margin, margin:W+margin]`` window of the padded tensor.

    Args:
        now_LR: B,3,H,W (per the original note).
        pre_h_SD: B,64,H,W (per the original note; C is read from the shape).

    Returns:
        Tensor of shape [B, C, H, W].
    """
    batch, C, H, W = pre_h_SD.shape
    k = self.K
    margin = k // 2
    taps = self.conv(now_LR)  # [B, k*k, H, W]

    # Accumulator has the unpadded shape, so it is created before padding.
    acc = F.zeros_like(pre_h_SD)  # [B, C, H, W]
    padded = add_H_W_Padding(pre_h_SD, margin=margin)

    for tap in range(k * k):
        # Row-major decomposition: tap == i * k + j, same visiting order as
        # nested loops over i then j.
        i, j = divmod(tap, k)
        weight = F.add_axis(taps[:, tap, :, :], axis=1)  # [B, 1, H, W]
        weight = F.broadcast_to(weight, [batch, C, H, W])
        # Element-wise product with the (i, j)-offset window, accumulated.
        acc = acc + weight * padded[:, :, i:(H + i), j:(W + j)]

    acc = F.sigmoid(acc)
    center = padded[:, :, margin:(H + margin), margin:(W + margin)]
    return F.multiply(center, acc)