def forward(self, input_image):
    """Undersample ``input_image`` with the fixed ``self.mask`` and reconstruct it.

    Args:
        input_image: image batch handed to ``fourierOperations.Full_Map``.

    Returns:
        Whatever ``self.proximal`` returns for the undersampled input and
        the same mask.
    """
    # The same stored mask drives both the sampling and the reconstruction.
    mask = self.mask
    return self.proximal(fourierOperations.Full_Map(input_image, mask), mask)
 def forward(self, input_image, mask):
     """Run the three-stage unrolled reconstruction on ``input_image``.

     Stage 0 passes the input through ``self.alpha0`` and then
     ``self.prox0``. Stages 1 and 2 each re-apply
     ``fourierOperations.Full_Map`` with ``mask`` to the previous image and
     form ``image_prev - alpha_k(resampled) + state_0``, which is fed to
     ``prox_k``. The stage-0 state is re-added at every stage.

     Args:
         input_image: input batch for ``self.alpha0``.
         mask: sampling mask passed to ``fourierOperations.Full_Map``.

     Returns:
         The image produced by ``self.prox2``.
     """
     state_0 = self.alpha0(input_image)
     image = self.prox0(state_0)

     # Stages 1 and 2 share one update rule, each with its own modules.
     for alpha, prox in ((self.alpha1, self.prox1), (self.alpha2, self.prox2)):
         resampled = fourierOperations.Full_Map(image, mask)
         image = prox(image - alpha(resampled) + state_0)

     return image
 def forward(self, input_image):
     """Build a sampling mask with LOUPE, undersample, and reconstruct.

     ``self.loupe`` maps ``self.logit_parameter`` to a mask, which
     ``self.expandSampleMatrix`` expands. The expanded mask is used both by
     ``fourierOperations.Full_Map`` and by ``self.proximal``.

     Args:
         input_image: image batch to undersample.

     Returns:
         The output of ``self.proximal``.
     """
     sampling_mask = self.expandSampleMatrix(self.loupe(self.logit_parameter))
     undersampled = fourierOperations.Full_Map(input_image, sampling_mask)
     return self.proximal(undersampled, sampling_mask)
 def forward(self, input_image):
     """Sample actively over ``self.no_iter`` steps, reconstructing after each step.

     At every iteration:
     - ``self.final_fc[i]`` projects the LSTM hidden state to logits.
     - ``self.DPS`` turns the logits into a mask, which
       ``self.expandSampleMatrix`` expands.
     - ``fourierOperations.Full_Map`` undersamples the input with that mask.
     - ``self.proximal[i]`` reconstructs the image.
     - The reconstruction is stored, then encoded by ``self.SampleNet[i]`` and
       fed to ``self.lstm[i]`` to produce the next hidden and cell state.

     Args:
         input_image: image batch to undersample. NOTE(review): the LSTM
             state and output buffers are sized with ``self.batch_size``, so
             the batch presumably must have exactly that many items — confirm.

     Returns:
         Tensor of shape ``(self.batch_size, self.no_iter, self.width,
         self.height)`` on ``self.device``. Slice ``[:, i]`` holds channel 0
         of iteration ``i``'s reconstruction.
     """
     # Reset the DPS sampling memory before a new acquisition sequence.
     self.DPS.initialize_sample_memory()

     # Optionally force 30 lines to be marked as already sampled: indices
     # -15..14, i.e. the first 15 and (via negative indexing) the last 15
     # along that axis. Presumably these are the lines around DC in
     # non-fftshifted k-space — confirm.
     if self.Pineda == True:
         lines = torch.arange(30)-15
         self.DPS.sample_memory[:,lines] = 1

     # Allocate directly on the target device. Creating on the CPU and then
     # calling .to(device) costs an extra host allocation and host-to-device
     # copy per buffer on every forward pass.
     state_shape = (1, self.batch_size, self.lstm_size)
     h_var = torch.zeros(state_shape, device=self.device)
     c_var = torch.zeros(state_shape, device=self.device)
     output_images = torch.zeros(
         self.batch_size, self.no_iter, self.width, self.height,
         device=self.device,
     )

     for i in range(self.no_iter):
         # Project the current hidden state to logits and draw a sampling mask.
         logits = self.final_fc[i](h_var.reshape(self.batch_size, self.lstm_size))
         mask = self.expandSampleMatrix(self.DPS(logits))

         # Undersample the input and reconstruct with this iteration's network.
         subsampled_input = fourierOperations.Full_Map(input_image, mask)
         current_output_image = self.proximal[i](subsampled_input, mask)

         # Keep channel 0 of this iteration's reconstruction.
         output_images[:, i, :, :] = current_output_image[:, 0, :, :]

         # Encode the reconstruction and advance the LSTM state for the next step.
         input_lstm = self.SampleNet[i](current_output_image).unsqueeze(0)
         _, (h_var, c_var) = self.lstm[i](input_lstm, (h_var, c_var))

     return output_images
 def forward(self, input_image):
     """Undersample with a DPS mask drawn from shared learned logits, then reconstruct.

     ``self.logit_parameter`` gets a batch axis and is tiled
     ``self.batch_size`` times. ``self.DPS`` turns it into a mask, which
     ``self.expandSampleMatrix`` expands before undersampling and
     reconstruction.

     Args:
         input_image: image batch to undersample.

     Returns:
         The output of ``self.proximal``.
     """
     # Reset the DPS sampling memory before drawing a new mask.
     self.DPS.initialize_sample_memory()

     if self.Pineda == True:
         # Mark 30 lines (indices -15..14; the negative ones wrap to the end
         # of the axis) as already sampled.
         self.DPS.sample_memory[:, torch.arange(30) - 15] = 1

     # Same logit vector for every batch item: add a batch axis and tile it.
     batch_logits = self.logit_parameter.unsqueeze(0).repeat(self.batch_size, 1)
     sampling_mask = self.expandSampleMatrix(self.DPS(batch_logits))

     undersampled = fourierOperations.Full_Map(input_image, sampling_mask)
     return self.proximal(undersampled, sampling_mask)