def forward(self, input_image):
    """Subsample ``input_image`` with the fixed ``self.mask`` and reconstruct it.

    Args:
        input_image: image batch passed to ``fourierOperations.Full_Map``
            (shape/dtype not visible here — TODO confirm against caller).

    Returns:
        Whatever ``self.proximal`` returns for the subsampled input and mask.
    """
    # The mask is a fixed attribute of this model, not learned per call.
    measured = fourierOperations.Full_Map(input_image, self.mask)
    return self.proximal(measured, self.mask)
def forward(self, input_image, mask):
    """Run the three-stage unrolled reconstruction.

    Stage 0 passes ``input_image`` through ``self.alpha0`` (the result is
    kept as ``anchor``) and then through ``self.prox0``. Stages 1 and 2 each
    re-apply ``fourierOperations.Full_Map`` with ``mask`` to the current
    estimate. They form ``estimate - alpha_k(resampled) + anchor`` and pass
    it through ``prox_k``. This looks like an unrolled proximal-gradient
    scheme with learned step operators — NOTE(review): confirm.

    Args:
        input_image: input to the first stage (presumably already subsampled).
        mask: sampling mask forwarded to ``fourierOperations.Full_Map``.

    Returns:
        The output of ``self.prox2`` after the final stage.
    """
    anchor = self.alpha0(input_image)
    estimate = self.prox0(anchor)
    # Stages 1 and 2 share the same update rule; only their modules differ.
    for alpha, prox in ((self.alpha1, self.prox1), (self.alpha2, self.prox2)):
        resampled = fourierOperations.Full_Map(estimate, mask)
        estimate = prox(estimate - alpha(resampled) + anchor)
    return estimate
def forward(self, input_image):
    """Build a LOUPE sampling mask, subsample ``input_image``, and reconstruct.

    Args:
        input_image: image batch passed to ``fourierOperations.Full_Map``.

    Returns:
        Whatever ``self.proximal`` returns for the subsampled input and mask.
    """
    # LOUPE converts the learned logits into a mask. expandSampleMatrix then
    # reshapes it into the form consumed by Full_Map and the proximal network.
    mask = self.expandSampleMatrix(self.loupe(self.logit_parameter))
    undersampled = fourierOperations.Full_Map(input_image, mask)
    return self.proximal(undersampled, mask)
def forward(self, input_image):
    """Reconstruct ``input_image`` over ``self.no_iter`` adaptive sampling steps.

    Each step runs the following sequence:

    - Map the current LSTM hidden state to logits with ``self.final_fc[i]``.
    - Turn the logits into a mask with ``self.DPS``, expanded by
      ``self.expandSampleMatrix``.
    - Subsample ``input_image`` with ``fourierOperations.Full_Map``.
    - Reconstruct with ``self.proximal[i]``.
    - Feed the reconstruction through ``self.SampleNet[i]`` and
      ``self.lstm[i]`` to produce the next hidden state.

    Args:
        input_image: image batch to sample; assumed to contain
            ``self.batch_size`` items — TODO confirm against caller.

    Returns:
        Tensor of shape ``(self.batch_size, self.no_iter, self.width,
        self.height)`` on ``self.device``. Slot ``i`` holds index 0 along
        dim 1 (presumably the channel) of step ``i``'s reconstruction.
    """
    # Start each forward pass from an empty sampling memory.
    self.DPS.initialize_sample_memory()

    # Optionally force 30 lines (indices -15..14 along dim 1) to be sampled.
    # Negative indices wrap to the end of the axis, so this presumably marks
    # the low frequencies of non-fftshifted k-space — confirm with Full_Map.
    if self.Pineda:
        self.DPS.sample_memory[:, torch.arange(30) - 15] = 1

    # Allocate directly on the target device rather than building on the CPU
    # and copying with .to(), which costs an extra allocation and transfer
    # on every forward pass.
    h_var = torch.zeros(1, self.batch_size, self.lstm_size, device=self.device)
    c_var = torch.zeros(1, self.batch_size, self.lstm_size, device=self.device)
    output_images = torch.zeros(
        self.batch_size, self.no_iter, self.width, self.height, device=self.device
    )

    for i in range(self.no_iter):
        # Create the sampling mask from the current hidden state.
        logits = self.final_fc[i](h_var.reshape(self.batch_size, self.lstm_size))
        mask = self.expandSampleMatrix(self.DPS(logits))

        # Subsample the input and reconstruct it.
        subsampled_input = fourierOperations.Full_Map(input_image, mask)
        current_output_image = self.proximal[i](subsampled_input, mask)
        output_images[:, i, :, :] = current_output_image[:, 0, :, :]

        # Derive the next hidden state from this step's reconstruction.
        input_lstm = self.SampleNet[i](current_output_image).unsqueeze(0)
        _, (h_var, c_var) = self.lstm[i](input_lstm, (h_var, c_var))

    return output_images
def forward(self, input_image):
    """Sample with a DPS mask built from the static logits and reconstruct.

    Args:
        input_image: image batch passed to ``fourierOperations.Full_Map``.

    Returns:
        Whatever ``self.proximal`` returns for the subsampled input and mask.
    """
    sampler = self.DPS
    # Every call starts from a cleared sampling memory.
    sampler.initialize_sample_memory()
    # Optionally pre-mark indices -15..14 along dim 1 as sampled. The explicit
    # equality check is kept deliberately: it mirrors the original semantics.
    if self.Pineda == True:  # noqa: E712
        sampler.sample_memory[:, torch.arange(30) - 15] = 1

    # One copy of the shared logit vector per batch item.
    batch_logits = self.logit_parameter.unsqueeze(0).repeat(self.batch_size, 1)
    mask = self.expandSampleMatrix(sampler(batch_logits))

    undersampled = fourierOperations.Full_Map(input_image, mask)
    return self.proximal(undersampled, mask)