def HT_y(y, sf, fft_BT):
    """Zero-insertion upsample ``y`` by ``sf`` and filter it in the Fourier domain.

    Every sample of ``y`` is placed at the top-left corner of an ``sf x sf``
    cell (all other positions are zero); the result is transformed with a
    2-D FFT, multiplied by ``fft_BT`` and transformed back.

    Args:
        y: tensor of shape ``(ch, w, h)`` or ``(bs, ch, w, h)``.
        sf: integer scale factor.
        fft_BT: spectrum stored as (real, imag) pairs in its last dimension,
            shape ``(w * sf, h * sf, 2)``; it is broadcast over the leading
            dimensions of ``y``.

    Returns:
        Real tensor of shape ``(*y.shape[:-2], w * sf, h * sf)``.

    Raises:
        ValueError: if ``y`` is neither 3-D nor 4-D.
    """
    if y.dim() not in (3, 4):
        raise ValueError(
            "HT_y expects a 3-D or 4-D tensor, got shape {}".format(tuple(y.shape)))
    lead_shape = y.shape[:-2]
    w, h = y.shape[-2], y.shape[-1]

    # Treat every (w, h) plane as a single channel, pad it with sf*sf - 1
    # zero channels and let pixel_shuffle scatter channel 0 to offset (0, 0)
    # of each sf x sf cell (the 3-D case used to pad the wrong axis).
    z = y.reshape(-1, 1, w, h)
    z = F.pad(z, [0, 0, 0, 0, 0, sf * sf - 1], "constant", value=0)
    z = F.pixel_shuffle(z, upscale_factor=sf).reshape(*lead_shape, w * sf, h * sf)

    # torch.rfft/irfft were removed in torch 1.8; fall back to torch.fft there.
    if hasattr(torch, "rfft"):
        f = torch.rfft(z, 2, onesided=False)
    else:
        f = torch.view_as_real(torch.fft.fft2(z))

    # Complex product f * fft_BT on (real, imag) pairs; broadcasting replaces
    # the previous repeat() over the leading dimensions.
    f_re, f_im = f[..., 0], f[..., 1]
    b_re, b_im = fft_BT[..., 0], fft_BT[..., 1]
    M = torch.stack((f_re * b_re - f_im * b_im, f_re * b_im + f_im * b_re), dim=-1)

    if hasattr(torch, "irfft"):
        return torch.irfft(M, 2, onesided=False)
    # NOTE(review): takes the real part of the inverse FFT; exact when fft_BT
    # is the spectrum of a real kernel (presumably the case) — confirm.
    return torch.fft.ifft2(torch.view_as_complex(M.contiguous())).real
def img_sum(self, input_patch, input_img, kernel_size, weights_out):
    """Scatter-add patch predictions onto an image grid and average the overlaps.

    For every offset ``(i, j)`` in ``[0, kernel_size)^2`` the strided slice
    ``input_patch[:, :, i::kernel_size, j::kernel_size]`` is expanded with
    ``pixel_shuffle(kernel_size)`` (so its channel count must be divisible by
    ``kernel_size ** 2``) and added into ``output`` at offset ``(i, j)``.
    The sum is then divided by the number of contributions per pixel.

    Args:
        input_patch: 4-D tensor of patch predictions.
        input_img: tensor whose size and device define ``output``.
        kernel_size: patch side length and stride of the strided slices.
        weights_out: contribution counts returned by a previous call, or
            ``None``; reused when its spatial size matches ``input_img``.

    Returns:
        ``(output, weights_output)``: the averaged image and the
        per-pixel contribution counts (pass them back as ``weights_out``
        to skip recomputing them).
    """
    device = input_img.device
    # The counts depend only on geometry, so decide once whether the cached
    # ones can be reused (this used to be re-evaluated inside the loop).
    recompute = (weights_out is None
                 or weights_out.size(2) != input_img.size(2)
                 or weights_out.size(3) != input_img.size(3))
    output = torch.zeros(input_img.size(), device=device)
    if recompute:
        weights = torch.ones(1, kernel_size * kernel_size,
                             input_patch.size(2), input_patch.size(3), device=device)
        weights_output = torch.zeros(1, 1, input_img.size(2), input_img.size(3), device=device)
    else:
        weights_output = weights_out

    for i in range(kernel_size):
        for j in range(kernel_size):
            in_x = F.pixel_shuffle(input_patch[:, :, i::kernel_size, j::kernel_size], kernel_size)
            output[:, :, i:i + in_x.size(2), j:j + in_x.size(3)] += in_x
            if recompute:
                wei_x = F.pixel_shuffle(weights[:, :, i::kernel_size, j::kernel_size], kernel_size)
                weights_output[:, :, i:i + wei_x.size(2), j:j + wei_x.size(3)] += wei_x

    # A cached weights_out may live on a different device than the output.
    weights_output = weights_output.to(output.device)
    output = output / weights_output
    return output, weights_output
def forward(self, in_tensor):
    """Upsample by ``delta`` using predicted, softmax-normalised reassembly kernels.

    ``encoder(down(in_tensor))`` is pixel-shuffled by ``delta`` and
    softmax-normalised over channels, giving ``Kup * Kup`` weights for each
    of the ``delta * delta`` sub-positions of every input pixel. Those
    weights are applied to the zero-padded ``Kup x Kup`` neighbourhood of
    the pixel; the sub-position results are arranged spatially with
    ``pixel_shuffle(delta)`` and passed through ``self.out``.
    """
    n, c, height, width = in_tensor.size()
    k, d = self.Kup, self.delta

    # Reassembly kernels -> (n, h, w, k*k, d*d).
    kernels = F.softmax(F.pixel_shuffle(self.encoder(self.down(in_tensor)), d), dim=1)
    kernels = kernels.unfold(2, d, step=d).unfold(3, d, step=d)
    kernels = kernels.reshape(n, k ** 2, height, width, d ** 2).permute(0, 2, 3, 1, 4)

    # k x k neighbourhoods of every pixel -> (n, h, w, c, k*k).
    half = k // 2
    patches = F.pad(in_tensor, pad=(half, half, half, half), mode='constant', value=0)
    patches = patches.unfold(dimension=2, size=k, step=1).unfold(3, k, step=1)
    patches = patches.reshape(n, c, height, width, -1).permute(0, 2, 3, 1, 4)

    # (n, h, w, c, k*k) @ (n, h, w, k*k, d*d) -> (n, h, w, c, d*d)
    reassembled = torch.matmul(patches, kernels)
    reassembled = reassembled.reshape(n, height, width, -1).permute(0, 3, 1, 2)
    return self.out(F.pixel_shuffle(reassembled, d))
def forward(self, input):
    """Residual sub-pixel upsampling block.

    Main path: ``conv1 -> relu -> pixel_shuffle(2) -> conv2``.
    Skip path: ``upsample -> pixel_shuffle(2)``.
    Returns ``relu(skip + main)``.
    """
    main = F.pixel_shuffle(self.relu(self.conv1(input)), 2)
    skip = F.pixel_shuffle(self.upsample(input), 2)
    return self.relu(skip + self.conv2(main))
def forward(self, z):
    """Decode latent ``z`` into an image.

    ``up_type == 'shuffle'`` (the default when the attribute is absent):
    the first layer maps the ``nz``-dim latent to an 8x8 feature map with
    ``ngf * 2 ** (gen_layer_num - 2)`` channels; every later layer is
    followed by ``pixel_shuffle(2)`` and the last one by tanh.
    Any other ``up_type``: plain ``conv -> BN -> ReLU`` layers ending in tanh.

    NOTE(review): the ``if self.up_type == 'shuffle':`` line had been
    commented out, which turned the remaining ``else:`` into a
    ``for ... else:`` and ran the plain path *after* the shuffle path.
    """
    last = self.gen_layer_num - 1
    if getattr(self, 'up_type', 'shuffle') == 'shuffle':
        for il in range(self.gen_layer_num):
            if il == 0:
                z = self.convs[il](z.view(-1, self.nz))
                z = z.view(-1, self.ngf * (2 ** (self.gen_layer_num - 2 - il)), 8, 8)
                z = functional.relu(self.BNs[il](z))
            elif il == last:
                z = self.convs[il](z).tanh()
                z = functional.pixel_shuffle(z, 2)
            else:
                z = functional.relu(self.BNs[il](self.convs[il](z)))
                z = functional.pixel_shuffle(z, 2)
    else:
        for il in range(self.gen_layer_num):
            if il == last:
                z = self.convs[il](z).tanh()
            else:
                z = functional.relu(self.BNs[il](self.convs[il](z)))
    return z
def forward(self, x):
    """Three-branch gated fusion network.

    After ``fill``, the input feeds a coarse branch (``down4`` then
    ``top1..3``), a medium branch (``down2`` then ``bottom1``) and a main
    branch (``main1``). Each coarser result is pixel-shuffled by 2 and fused
    into the next finer branch by channel concatenation followed by a gate.
    The output is cropped back to the original ``h x w`` (``fill`` is
    defined elsewhere; presumably it pads the input — confirm).
    """
    _, _, h, w = x.size()
    x = fill(x)
    top = self.down4(x)
    bottom = self.down2(x)

    for layer in (self.top1, self.top2, self.top3):
        top = layer(top)
    top = F.pixel_shuffle(top, 2)

    bottom = self.bottom_gate(torch.cat((self.bottom1(bottom), top), 1))
    bottom = F.pixel_shuffle(self.bottom3(self.bottom2(bottom)), 2)

    main = self.main_gate(torch.cat((self.main1(x), bottom), 1))
    main = self.end(self.main3(self.main2(main)))
    return main[:, :, :h, :w]
def forward(self, x, h0, h1, h2, h3, h4):
    """Recurrent decoder: conv1, five (RNN cell -> pixel_shuffle(2)) stages, conv2, tanh/2.

    Each ``rnnK`` is called as ``rnnK(features, hK)``; element 0 of its
    result feeds the next stage and the whole result is returned as the
    updated state.

    Returns:
        ``(image, h0_new, h1_new, h2_new, h3_new, h4_new)``.
    """
    feats = self.conv1(x)
    new_states = []
    cells = (self.rnn0, self.rnn1, self.rnn2, self.rnn3, self.rnn4)
    for cell, state in zip(cells, (h0, h1, h2, h3, h4)):
        updated = cell(feats, state)
        new_states.append(updated)
        feats = F.pixel_shuffle(updated[0], 2)
    image = self.conv2(feats).tanh() / 2
    return (image, *new_states)
def forward(self, x, HR):
    """Predict a residual with a 3-D conv trunk and add it to ``HR``.

    ``x`` must be 5-D (it is unpacked as ``B, C, T, H, W``). The residual is
    produced by two rounds of ``squeeze(2) -> pixel_shuffle(2)``, so it must
    have a singleton axis 2 at those points.
    """
    _b, _c, _t, _h, _w = x.shape  # enforces a 5-D input
    for conv, bn in ((self.conv3d_1, self.bn3d_1),
                     (self.conv3d_2, self.bn3d_2),
                     (self.conv3d_2_1, self.bn3d_2_1)):
        x = F.relu(bn(conv(x)))
    x = self.conv3d_2_2(x)
    for stage in (self.head, self.middle_1, self.middle_2, self.middle_3, self.middle_4):
        x = stage(x)
    x, _ = self.last(x)
    x = self.Fusion_last(self.fusion_head(x))
    x = self.compress(F.relu(self.bn3d_2_2(x)))
    for stage in (self.middle_6, self.middle_7, self.middle_8, self.middle_9,
                  self.middle_10, self.middle_11, self.middle_12):
        x = stage(x)
    x = F.relu(self.conv3d_3(F.relu(self.bn3d_3(x))))

    residual = self.conv3d_r2(F.relu(self.conv3d_r1(x)))
    residual = F.relu(F.pixel_shuffle(residual.squeeze_(2), 2)).unsqueeze(2)
    residual = self.conv3d_r3(F.relu(self.conv3d_r4(residual)))
    return HR + F.pixel_shuffle(residual.squeeze_(2), 2)
def forward(self, xs):
    """Fuse four feature levels into ``cls`` / ``cks`` / ``dist`` predictions.

    Levels 1, 3 and 4 first pass through ``dp1``/``dp3``/``dp4`` (level 2 has
    no such step). Each level then goes through ``arm -> conv -> prelu`` and
    is pixel-shuffled by 2, 4, 8 and 16 respectively before ``fusenet``
    combines the four maps.

    Returns:
        dict with keys ``'cls'``, ``'cks'`` and ``'dist'``.
    """
    x1, x2, x3, x4 = xs
    levels = (
        (self.dp1(x1), self.arm1, self.conv1, self.prelu1),
        (x2, self.arm2, self.conv2, self.prelu2),
        (self.dp3(x3), self.arm3, self.conv3, self.prelu3),
        (self.dp4(x4), self.arm4, self.conv4, self.prelu4),
    )
    feats = [prelu(conv(arm(feat))) for feat, arm, conv, prelu in levels]
    shuffled = [F.pixel_shuffle(t, upscale_factor=2 ** (k + 1)) for k, t in enumerate(feats)]
    fusion = self.fusenet(*shuffled)
    return {
        'cls': self.conv_cls(fusion),
        'cks': self.conv_ck(fusion),
        'dist': self.dist_conv(fusion),
    }
def forward(self, ims, tList):
    """Super-resolve frame 3 and synthesise intermediate frames between frames 3 and 4.

    Args:
        ims: list of input frames, each ``(B, C, H, W)``; ``ims[3]`` and
            ``ims[4]`` are used directly, so at least five frames are needed.
        tList: target time indices (e.g. ``[1/4, 2/4, 3/4]``) at which frames
            are interpolated between ``ims[3]`` and ``ims[4]``.

    Returns:
        Tensor ``(len(tList) + 1, B, C, H * sf, W * sf)`` on the device and
        with the dtype of ``ims[0]`` (it used to be forced onto CUDA):
        index 0 is the super-resolved ``ims[3]``, index ``l + 1`` the frame
        at ``tList[l]``.
    """
    b, c, h, w = ims[0].size()
    outs = ims[0].new_zeros((len(tList) + 1, b, c, h * self.sf, w * self.sf))

    # Per-frame features, then fusion across frames.
    enc_s = [self.encoder(im) for im in ims]
    enc_sf = self.efst(enc_s)

    # Upsample the two anchor frames once; uI3 is reused for output 0.
    uI3 = F.interpolate(ims[3], scale_factor=self.sf, mode='bilinear', align_corners=False)
    uI4 = F.interpolate(ims[4], scale_factor=self.sf, mode='bilinear', align_corners=False)

    # Spatial decoder: residual on top of the bilinear upsample.
    dec_feat, rimg = self.decoder(enc_s[3], enc_sf)
    outs[0] = uI3 + F.pixel_shuffle(rimg, self.sf)

    # Flow estimation between the anchor frames.
    flow34 = self.pwcnet(uI3, uI4)
    flow43 = self.pwcnet(uI4, uI3)
    for l, t in enumerate(tList):
        flowt0 = -t * (1 - t) * flow34 + t * t * flow43
        flowt1 = (1 - t) * (1 - t) * flow34 - t * (1 - t) * flow43
        # Feature interpolation at time t.
        featI = [self.fi(f3, f4, flowt0, flowt1) for f3, f4 in zip(enc_s[3], enc_s[4])]
        # LR intermediate frame from the two warped anchors (`warp` is a
        # helper defined elsewhere in this file).
        dwI = (warp(ims[3], flowt0) + warp(ims[4], flowt1)) / 2.
        # Spatio-temporal decoder residual.
        _, trimg = self.decoder(featI, dec_feat)
        wI = F.interpolate(dwI, scale_factor=self.sf, mode='bilinear', align_corners=False)
        outs[l + 1] = wI + F.pixel_shuffle(trimg, self.sf)
    return outs
def forward(self, x):
    """Predict edge and body maps from multi-scale backbone features.

    Only the first element of ``backbone(x)`` is used. Body goes through
    ``conv1 -> bn1 -> relu``, edge through ``conv2 -> bn2 -> sigmoid``; both
    are pixel-shuffled by ``self.scale_factor``.

    Returns:
        ``(final_edge, final_body)``.
    """
    cat_feats, _ = self.backbone(x)
    feats = self.msfe(cat_feats)
    body = self.relu(self.bn1(self.conv1(feats)))
    edge = self.sigmoid(self.bn2(self.conv2(feats)))
    factor = self.scale_factor
    return (F.pixel_shuffle(edge, upscale_factor=factor),
            F.pixel_shuffle(body, upscale_factor=factor))
def forward(self, input, h_1, h_2, h_3, h_4):
    """Recurrent decoder: conv1, four (RNN cell -> pixel_shuffle(2)) stages, conv2, tanh.

    Each ``rnnK`` is called as ``rnnK(features, h_K)``; element 0 of its
    result feeds the next stage and the whole result is returned as the
    updated state.

    Returns:
        ``(image, h_1, h_2, h_3, h_4)`` with the updated states.
    """
    feats = self.conv1(input)
    new_states = []
    for cell, state in zip((self.rnn1, self.rnn2, self.rnn3, self.rnn4), (h_1, h_2, h_3, h_4)):
        updated = cell(feats, state)
        new_states.append(updated)
        feats = F.pixel_shuffle(updated[0], 2)
    return (self.conv2(feats).tanh(), *new_states)
def forward(self, x):
    """Predict encoder/decoder index maps.

    ``indexnet`` output is squashed with a sigmoid. The decoder map is that
    sigmoid pixel-shuffled by 2; the encoder map additionally applies a
    softmax over channels before shuffling.

    Returns:
        ``(idx_en, idx_de)``.
    """
    gate = torch.sigmoid(self.indexnet(x))
    idx_en = F.pixel_shuffle(F.softmax(gate, dim=1), 2)
    idx_de = F.pixel_shuffle(gate, 2)
    return idx_en, idx_de
def forward(self, input, iter):
    """conv1, then four (``rnnK(features, iter)`` -> pixel_shuffle(2)) stages, then conv2.

    ``iter`` (which shadows the builtin but is part of the call interface)
    is passed unchanged to every RNN stage.
    """
    output = self.conv1(input)
    for cell in (self.rnn1, self.rnn2, self.rnn3, self.rnn4):
        output = F.pixel_shuffle(cell(output, iter), 2)
    return self.conv2(output)
def forward(self, noise):
    """Map ``noise`` to an image.

    Two fully connected layers with LeakyReLU(0.2) produce a
    ``(-1, 128, 7, 7)`` map, which goes through two sub-pixel stages
    (``pixel_shuffle(2)`` followed by a conv), the first with BN +
    LeakyReLU; the output is squashed with tanh.
    """
    out = F.leaky_relu(self.fc_1(noise), 0.2)
    out = F.leaky_relu(self.fc_2(out), 0.2)
    out = F.pixel_shuffle(out.view([-1, 128, 7, 7]), 2)
    out = F.leaky_relu(self.bn_3(self.conv_3(out)), 0.2)
    out = self.conv_4(F.pixel_shuffle(out, 2))
    return out.tanh()
def forward(self, x):
    """Alternate ``pixel_shuffle(2)`` with ``conv1..conv4``, ending with one more shuffle."""
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        x = conv(F.pixel_shuffle(x, 2))
    return F.pixel_shuffle(x, 2)
def forward(self, input, hidden1, hidden2, hidden3, hidden4,
            unet_output1, unet_output2, wdec):
    """Recurrent decoder whose conv weights are offset by per-call ``wdec`` deltas.

    ``wdec`` unpacks into ``(init_conv, rnn1_i, rnn1_h, rnn2_i, rnn2_h,
    rnn3_i, rnn3_h, rnn4_i, rnn4_h, final_conv)``. ``init_conv`` and
    ``final_conv`` are added to ``conv1.weight`` / ``conv2.weight``; the
    ``rnnK_i`` / ``rnnK_h`` values are forwarded to ``rnnK``. When
    ``v_compress`` is set, the inputs of rnn2 / rnn3 / rnn4 are concatenated
    with ``unet_output1[k]`` and ``unet_output2[k]`` (k = 0, 1, 2) — for rnn2
    only if ``fuse_level >= 3`` and for rnn3 only if ``fuse_level >= 2``.

    Returns:
        ``(image, hidden1, hidden2, hidden3, hidden4)``; image is ``tanh(.) / 2``.
    """
    (init_conv, rnn1_i, rnn1_h, rnn2_i, rnn2_h,
     rnn3_i, rnn3_h, rnn4_i, rnn4_h, final_conv) = wdec

    x = F.conv2d(input, init_conv + self.conv1.weight, stride=1, padding=0)
    states = [self.rnn1(x, rnn1_i, rnn1_h, hidden1)]

    # (cell, input weights, hidden weights, previous state, unet level, fuse?)
    stages = (
        (self.rnn2, rnn2_i, rnn2_h, hidden2, 0, self.v_compress and self.fuse_level >= 3),
        (self.rnn3, rnn3_i, rnn3_h, hidden3, 1, self.v_compress and self.fuse_level >= 2),
        (self.rnn4, rnn4_i, rnn4_h, hidden4, 2, self.v_compress),
    )
    for cell, w_in, w_hid, prev, level, fuse in stages:
        x = F.pixel_shuffle(states[-1][0], 2)
        if fuse:
            x = torch.cat([x, unet_output1[level], unet_output2[level]], dim=1)
        states.append(cell(x, w_in, w_hid, prev))

    x = F.pixel_shuffle(states[-1][0], 2)
    x = F.conv2d(x, final_conv + self.conv2.weight, stride=1, padding=0)
    return (x.tanh() / 2, *states)
def forward(self, input, shape):
    """Apply ``conv`` + ``lrule`` and pixel-shuffle every frame of a 5-D clip by 2.

    ``shape`` is unpacked as ``(N, C, L, H, W)``; the frame axis (dim 2) is
    folded into the batch so ``pixel_shuffle`` sees 4-D input, then restored.
    Returns ``(N, C', L, 2H, 2W)``.
    """
    n, _, frames, h, w = shape
    feats = self.lrule(self.conv(input))
    feats = feats.transpose(1, 2).reshape(n * frames, -1, h, w)
    feats = F.pixel_shuffle(feats, 2)
    return feats.reshape(n, frames, -1, 2 * h, 2 * w).transpose(1, 2)
def forward(self, x):
    """U-Net style encoder/decoder over an intensity-scaled stack of ``x``.

    ``x`` is concatenated with copies scaled by 0.8, 1.2 and 1.5 (order:
    0.8x, x, 1.2x, 1.5x) along channels, encoded by ``inc`` and
    ``layer1..3``, decoded by ``up1..3`` with skip concatenations into
    ``layer4..6``, projected by ``output`` and pixel-shuffled by 2.
    """
    inc = self.inc(torch.cat((0.8 * x, x, 1.2 * x, 1.5 * x), dim=1))
    layer1 = self.layer1(inc)
    layer2 = self.layer2(layer1)
    layer3 = self.layer3(layer2)

    dec = layer3
    for up, merge, skip in ((self.up1, self.layer4, layer2),
                            (self.up2, self.layer5, layer1),
                            (self.up3, self.layer6, inc)):
        dec = merge(torch.cat((up(dec), skip), dim=1))
    return F.pixel_shuffle(self.output(dec), 2)
def forward(self, x):
    """Dynamic-upsampling-filter video super-resolution.

    Args:
        x: ``[B, T, C, H, W]`` (T = 7 per the original notes); it is
            permuted to ``[B, C, T, H, W]`` for the Conv3D layers and the
            centre frame ``T // 2`` is the one that gets filtered.

    The trunk produces an image residual (expected ``[B, 3*16, 1, H, W]``)
    and filters reshaped to ``[B, 25, scale**2, H, W]`` and softmax-
    normalised over the 25 taps. ``dynamic_filter`` applies them to the
    centre frame, the residual is added and the sum is pixel-shuffled by
    ``scale``. When ``adapt_official`` is set, the module-level
    ``adapt_official`` helper is applied to the residual first.
    """
    batch, frames, _, height, width = x.size()
    clip = x.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W] for Conv3D
    center = clip[:, :, frames // 2, :, :]

    feats = self.dense_block_2(self.dense_block_1(self.conv3d_1(clip)))
    feats = F.relu(self.conv3d_2(F.relu(self.bn3d_2(feats), inplace=True)), inplace=True)

    residual = self.conv3d_r2(F.relu(self.conv3d_r1(feats), inplace=True))
    taps = self.conv3d_f2(F.relu(self.conv3d_f1(feats), inplace=True))
    taps = F.softmax(taps.view(batch, 25, self.scale ** 2, height, width), dim=1)

    if self.adapt_official:
        adapt_official(residual, scale=self.scale)

    out = self.dynamic_filter(center, taps)
    out += residual.squeeze_(2)
    return F.pixel_shuffle(out, self.scale)
def forward(self, x):
    """Predict a map whose channels are each min-max scaled to ``[0, 1]``.

    The backbone (``prepLayer``, ``layer1..3``) is followed by ``upsample``,
    ``pixel_shuffle(2)`` and a ``conv -> bn -> relu`` x2 + ``conv3`` head.
    Each (sample, channel) slice is then rescaled over its spatial positions
    to ``[0, 1]``. A channel whose values are all equal is mapped to zeros
    instead of the NaN that ``0 / 0`` would produce.
    """
    for stage in (self.prepLayer, self.layer1, self.layer2, self.layer3):
        x = stage(x)
    out = F.pixel_shuffle(self.upsample(x), 2)
    out = F.relu(self.bn1(self.conv1(out)))
    out = F.relu(self.bn2(self.conv2(out)))
    out = self.conv3(out)

    # Min-max scale each (sample, channel) over its flattened spatial dims.
    flat = out.view(out.size(0), out.size(1), -1)
    flat = flat - flat.min(2, keepdim=True)[0]
    span = flat.max(2, keepdim=True)[0]
    # Constant channels are all zeros after the shift; dividing them by 1
    # keeps them at 0 (and keeps gradients finite) instead of 0/0 = NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (flat / span).view(out.size())
def forward(self, x):
    """Four-level U-Net followed by ``conv10`` and ``pixel_shuffle(2)``.

    Encoder: ``conv1..conv4``, each followed by 2x2 max-pooling, then the
    bottleneck ``conv5``. Decoder: ``upK`` on the previous stage,
    concatenated with the matching encoder output (conv4, conv3, conv2,
    conv1), then ``conv6..conv9``.
    """
    skips = []
    feats = x
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
        feats = conv(feats)
        skips.append(feats)
        feats = F.max_pool2d(feats, kernel_size=2)
    feats = self.conv5(feats)

    for up, conv, skip in zip((self.up6, self.up7, self.up8, self.up9),
                              (self.conv6, self.conv7, self.conv8, self.conv9),
                              reversed(skips)):
        feats = conv(torch.cat([up(feats), skip], 1))
    return F.pixel_shuffle(self.conv10(feats), 2)
def forward(self, sample):
    """Generate a pairwise (distance, rotation, direction) map from ``sample``.

    ``preprocess(sample)`` is viewed as ``(-1, 512, 4, 4)``. Each block
    receives the features plus two positional channels (symmetric and
    antisymmetric pairwise position sums) and its output is pixel-shuffled
    by 2. The final square map is then decoded into:

    * distances: softplus, symmetrised, with a zeroed diagonal;
    * rotation: the strictly-lower triangle (row > col) is kept and mirrored
      into the rest, then ``sin`` normalised over channels;
    * direction: the upper triangle incl. diagonal is kept and mirrored,
      then ``sin`` normalised over channels.

    Returns:
        1-tuple holding ``cat((distances, rotation, direction), dim=1)``.
    """
    def _positional(feats):
        # (N, 1, S, S) symmetric and antisymmetric pairwise position channels.
        size = feats.size(-1)
        coords = torch.arange(size, dtype=feats.dtype, device=feats.device) / 100
        coords = coords[None].expand(feats.size(0), size)
        row, col = coords[:, None, :], coords[:, :, None]
        return ((row + col) / 2)[:, None], ((row - col) / 2)[:, None]

    def _unit_sin(angles):
        # sin(angles) divided by its (detached) L2 norm over channels.
        s = angles.sin()
        return s / (s.norm(dim=1, keepdim=True).detach() + 1e-6)

    out = self.preprocess(sample).view(-1, 512, 4, 4)
    for block in self.blocks:
        sym, asym = _positional(out)
        out = func.pixel_shuffle(block(torch.cat((out, sym, asym), dim=1)), 2)

    idx = torch.arange(out.size(-1), device=out.device)
    lower = ((idx[:, None] - idx[None, :]) > 0).float()[None, None]
    upper = 1 - lower

    distances = func.softplus(self.distances(out))
    distances = (distances + distances.transpose(2, 3)) / 2
    rotation = self.rotation(out)
    rotation = _unit_sin(lower * rotation + upper * rotation.transpose(2, 3))
    direction = self.direction(out)
    direction = _unit_sin(lower * direction.transpose(2, 3) + upper * direction)

    distances[:, :, idx, idx] = 0
    return (torch.cat((distances, rotation, direction), dim=1),)
def forward(self, x):
    """Optionally preprocess, apply ``trns``, lift to 1x1 maps and pixel-shuffle by 2.

    ``preprocess`` is applied only when it is truthy. Two singleton axes are
    inserted at position 2, so a ``(N, C)`` input becomes ``(N, C, 1, 1)``
    before the shuffle.
    """
    if self.preprocess:
        x = self.preprocess(x)
    x = self.trns(x)[:, :, None, None]
    return F.pixel_shuffle(x, 2)
def upsample(img, scale, border='reflect'):
    """Bicubic upsampling via a single ``conv2d`` + ``pixel_shuffle`` (PIL's kernel).

    Args:
        img: torch tensor with 2, 3 or 4 dimensions (normalised to 4-D by
            ``_push_shape_4d`` and restored by ``_pop_shape``).
        scale: integer upsampling factor, must be >= 2; when
            ``weights_upsample`` reports a factor of 1 the input is returned
            unchanged.
        border: ``torch.nn.functional.pad`` mode applied before the
            convolution; ``'reflect'`` is recommended.

    Returns:
        The upsampled tensor with the same rank as ``img``.
    """
    kernels, s = weights_upsample(scale)
    if s == 1:
        return img  # bypass
    p1 = 1 + s // 2
    p2 = 3
    img, shape = _push_shape_4d(img)
    img_ex = F.pad(img, [p1, p2, p1, p2], mode=border)
    c = int(img_ex.shape[1])
    # One 5x5 kernel per sub-pixel phase, applied channel-wise (identity
    # across channels). Built in the input's dtype/device: float32-only
    # weights made conv2d fail for float64/float16 inputs.
    eye = torch.eye(c, c, dtype=img_ex.dtype, device=img_ex.device).reshape([c, c, 1, 1])
    filters = [eye * torch.from_numpy(k.astype('float32')).to(img_ex) for k in kernels]
    weights = torch.stack(filters, dim=0).transpose(0, 1).reshape([-1, c, 5, 5])
    img_s = F.pixel_shuffle(F.conv2d(img_ex, weights), s)
    # Drop the border introduced by the asymmetric padding.
    more = s // 2 * s
    crop = slice(more - s // 2, -(s // 2))
    return _pop_shape(img_s[..., crop, crop], shape)
def forward(self, x):
    """Dynamic-upsampling-filter super-resolution of the centre frame.

    Args:
        x (Tensor): Input with shape (b, 7, c, h, w); it is permuted to
            (b, c, 7, h, w) for Conv3D and frame ``num_imgs // 2`` is filtered.

    Returns:
        Tensor: pixel-shuffled (by ``scale``) sum of the dynamically filtered
        centre frame and the predicted residual — documented upstream as
        (b, 1, h * scale, w * scale); the channel count actually depends on
        ``dynamic_filter`` (TODO confirm).
    """
    n, frames, _, h, w = x.size()
    clip = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
    center = clip[:, :, frames // 2, :, :]

    feats = self.dense_block2(self.dense_block1(self.conv3d1(clip)))
    feats = F.relu(self.conv3d2(F.relu(self.bn3d2(feats), inplace=True)), inplace=True)

    res = self.conv3d_r2(F.relu(self.conv3d_r1(feats), inplace=True))
    taps = self.conv3d_f2(F.relu(self.conv3d_f1(feats), inplace=True))
    taps = F.softmax(taps.view(n, 25, self.scale ** 2, h, w), dim=1)

    out = self.dynamic_filter(center, taps)
    out += res.squeeze_(2)
    return F.pixel_shuffle(out, self.scale)
def sample(self, x):
    """Sample a new canvas from the prior by unrolling the decoder ``self.T`` times.

    :param x: tensor whose batch size (dim 0), dtype and device are used for
        the sample; its values are not read.
    :return: canvas of shape ``(batch, x_dim, h, w)`` with
        ``(h, w) = self.x_shape``, accumulated from ``self.T`` write steps
        (each pixel-shuffled by 2).
    """
    h, w = self.x_shape
    n = x.size(0)
    state_shape = (n, self.h_dim, h // 2, w // 2)
    h_dec = x.new_zeros(state_shape)
    c_dec = x.new_zeros(state_shape)
    canvas = x.new_zeros((n, self.x_dim, h, w))
    for _ in range(self.T):
        mu, log_std = torch.split(self.prior(h_dec), self.z_dim, dim=1)
        z = Normal(mu, torch.exp(log_std)).sample()
        h_dec, c_dec = self.decoder(torch.cat([z, self.read_head(canvas)], dim=1), [h_dec, c_dec])
        canvas = canvas + F.pixel_shuffle(self.write_head(h_dec), 2)
    return canvas
def forward(self, input, tmp_FLAG=False):
    """x4 super-resolution with a recurrent conv-LSTM refinement loop.

    A shallow feature extractor runs once. Then, for 3 iterations, the
    features are concatenated with the previous LSTM output ``y`` and a new
    output is produced. That output is upsampled by two
    ``pixel_shuffle(2)`` stages, refined, and added to a 4x interpolation of
    ``input`` (mode ``self.interpolate``).

    Args:
        input: 4-D ``(N, C, H, W)`` tensor.
        tmp_FLAG: when True, also return the outputs of every iteration
            except the last.

    Returns:
        ``output``, or ``(output, intermediate_outputs)`` when ``tmp_FLAG``.
    """
    tmp_list = []
    batch_size, row, col = input.size(0), input.size(2), input.size(3)
    # Recurrent state on the input's device/dtype (it used to be forced to CUDA).
    y = input.new_zeros(batch_size, 32, row, col)
    c = input.new_zeros(batch_size, 32, row, col)

    x = self.relu1(self.conv1(input))
    x = self.resbk1(x)
    x = self.relu2(self.conv2(x))
    x = self.resbk2(x)

    # Loop-invariant: the interpolated input is the same in every iteration.
    upsampled_input = F.interpolate(input, scale_factor=4, mode=self.interpolate)

    # Loop mechanism, with conv-LSTM at first.
    for _ in range(3):
        concat_feature = self.relu_concate1(self.conv_concate1(torch.cat([x, y], dim=1)))
        concat_feature = self.relu_concate2(self.conv_concate2(self.resbk_concate(concat_feature)))
        i = self.conv_i(concat_feature)
        f = self.conv_f(concat_feature)
        g = self.conv_g(concat_feature)
        o = self.conv_o(concat_feature)
        c = f * c + i * g  # cell-state update
        h = o * torch.tanh(c)  # LSTM output
        y = self.mrc(h) + h

        feat = F.pixel_shuffle(self.ps1_conv(y), upscale_factor=2)
        feat = self.resbk_inter_ps(self.relu_inter_ps1(self.conv_inter_ps1(feat)))
        feat = self.relu_inter_ps2(self.conv_inter_ps2(feat))
        y_upsampled = F.pixel_shuffle(self.ps2_conv(feat), upscale_factor=2)

        sr = self.relu_sr1(self.conv_sr1(y_upsampled))
        sr = self.relu_sr2(self.conv_sr2(sr))
        output = self.conv_final(sr) + upsampled_input
        if tmp_FLAG:
            tmp_list.append(output)

    if tmp_FLAG:
        return output, tmp_list[:-1]
    return output
def forward(self, x):
    """Run ``conv_layers`` through ``same_padding_conv`` and pixel-shuffle by 4.

    ReLU follows every layer except the last, which is followed by tanh.
    (A stray debug ``print(x.size())`` that ran on every call was removed.)
    """
    last = len(self.conv_layers) - 1
    for layer_idx, conv in enumerate(self.conv_layers):
        x = same_padding_conv(x, conv)
        x = x.tanh() if layer_idx == last else F.relu(x)
    return F.pixel_shuffle(x, 4)
def forward(self, x):
    """``channel_compressor`` -> ``context_encoder`` -> ``pixel_shuffle(enlarge_rate)`` -> ``kernel_normalizer``."""
    encoded = self.context_encoder(self.channel_compressor(x))
    return self.kernel_normalizer(F.pixel_shuffle(encoded, self.enlarge_rate))
def forward(self, input, hidden1, hidden2, hidden3, hidden4):
    """Recurrent decoder: conv1, four (RNN cell -> pixel_shuffle(2)) stages, conv2, tanh/2.

    Each ``rnnK`` is called as ``rnnK(features, hiddenK)``; element 0 of its
    result feeds the next stage and the whole result is returned as the
    updated state.

    Returns:
        ``(image, hidden1, hidden2, hidden3, hidden4)`` with updated states.
    """
    feats = self.conv1(input)
    new_states = []
    for cell, state in zip((self.rnn1, self.rnn2, self.rnn3, self.rnn4),
                           (hidden1, hidden2, hidden3, hidden4)):
        updated = cell(feats, state)
        new_states.append(updated)
        feats = F.pixel_shuffle(updated[0], 2)
    return (self.conv2(feats).tanh() / 2, *new_states)