Exemple #1
0
    def _get_training_queue(self, batch_size, num_threads=4):
        """Build a queue-based input pipeline of (image, PSF-convolved image) batches.

        Reads PNG files from a fixed directory, extracts the central
        ``dim x dim`` glimpse of each image, convolves it (FFT-based, with
        padding) against the PSF loaded from ``self.psf_file``, and batches
        the resulting pairs.

        Args:
            batch_size: number of image pairs per batch.
            num_threads: number of threads used by ``tf.train.batch``.

        Returns:
            Tuple ``(image_batch, convolved_img_batch)``, each of shape
            ``[batch_size, dim, dim, 1]``.
        """
        dim = self.dim

        file_list = tf.matching_files(
            '/media/data/onn/quickdraw16_192/im_*.png')
        filename_queue = tf.train.string_input_producer(file_list)

        image_reader = tf.WholeFileReader()

        _, image_file = image_reader.read(filename_queue)
        image = tf.image.decode_png(image_file, channels=1, dtype=tf.uint8)
        image = tf.cast(image, tf.float32)  # Shape [height, width, 1]
        image = tf.expand_dims(image, 0)
        image /= 255.

        # Get the ratio of the patch size to the smallest side of the image
        img_height_width = tf.cast(tf.shape(image)[1:3], tf.float32)
        size_ratio = dim / tf.reduce_min(img_height_width)

        # Extract a glimpse from the image.  With minval == maxval == 0 the
        # "random" offset is deterministically zero, i.e. always the central
        # glimpse; the commented-out line below shows the earlier
        # random-offset variant that was disabled.
        #offset_center = tf.random_uniform([1,2], minval=0.0 + size_ratio/2, maxval=1.0-size_ratio/2, dtype=tf.float32)
        offset_center = tf.random_uniform([1, 2],
                                          minval=0,
                                          maxval=0,
                                          dtype=tf.float32)
        offset_center = offset_center * img_height_width

        image = tf.image.extract_glimpse(image,
                                         size=[dim, dim],
                                         offsets=offset_center,
                                         centered=True,
                                         normalized=False)
        image = tf.squeeze(image, 0)

        convolved_image = tf.expand_dims(image, 0)
        # Load the PSF from disk and normalize it to unit sum.
        psf = tf.convert_to_tensor(np.load(self.psf_file), tf.float32)
        psf /= tf.reduce_sum(psf)
        optics.attach_img(
            'gt_psf', tf.expand_dims(tf.expand_dims(tf.squeeze(psf), 0), -1))

        # Reshape PSF to [height, width, in_channels=1, out_channels=1].
        psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
        # psf = tf.transpose(psf, [1,2,0,3])

        # Pad by half the patch size to avoid FFT wrap-around artifacts.
        pad = int(dim / 2)
        convolved_image = tf.abs(
            optics.fft_conv2d(fftpad(convolved_image, pad),
                              fftpad_psf(psf, pad),
                              adjoint=True))
        convolved_image = fftunpad(convolved_image, pad)
        convolved_image = tf.squeeze(convolved_image, axis=0)
        convolved_image /= tf.reduce_sum(convolved_image)

        image_batch, convolved_img_batch = tf.train.batch(
            [image, convolved_image],
            shapes=[[dim, dim, 1], [dim, dim, 1]],
            batch_size=batch_size,
            num_threads=num_threads,  # fix: honor the argument (was hard-coded to 4)
            capacity=4 * batch_size)

        return image_batch, convolved_img_batch
 def _get_data_loss(self, model_output, ground_truth):
     """Return the mean-absolute-error data loss.

     The ground truth is normalized to unit sum before comparison, and
     both tensors are attached to the summary as images for inspection.
     """
     prediction = tf.cast(model_output, tf.float32)
     target = tf.cast(ground_truth, tf.float32)
     target = target / tf.reduce_sum(target)

     with tf.name_scope('data_loss'):
         optics.attach_img('model_output', prediction)
         optics.attach_img('ground_truth', target)

     return tf.reduce_mean(tf.abs(prediction - target))
Exemple #3
0
    def _optical_conv_layer(self, input_field, hm_reg_scale, activation=None, coherent=False,
                            name='optical_conv'):
        """Simulate one optical convolution layer with a learnable phase mask.

        Builds a height-map phase element, derives its PSF via an inverse
        Fourier transform of the (aperture-limited) transfer function, and
        applies it to ``input_field`` either coherently (multiplication in
        the Fourier domain) or incoherently (FFT convolution with the PSF).

        Args:
            input_field: input tensor; cast to complex128 below.
                # assumes shape [1, wave_resolution[0], wave_resolution[1], 1]
                # to match the mask/OTF tensors — TODO confirm with callers.
            hm_reg_scale: regularizer scale; currently unused because the
                height-map regularizer line below is commented out.
            activation: optional elementwise nonlinearity applied to the output.
            coherent: if True, filter the field in the Fourier domain;
                otherwise convolve intensities with the PSF.
            name: variable scope name.

        Returns:
            The output image tensor (complex field transform if coherent,
            absolute value of the convolution otherwise).
        """
        with tf.variable_scope(name):
            sensordims = self.wave_resolution
            input_field = tf.cast(input_field, tf.complex128)
            
            # Zero-centered fft2 of input field
            field = optics.fftshift2d_tf(optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))

            # Build a phase mask, zero-centered.  The near-constant tiny
            # initializer keeps the initial phase profile almost flat.
            height_map_initializer=tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)
            # height_map_initializer=None
            pm = optics.height_map_element([1,self.wave_resolution[0],self.wave_resolution[1],1],
                                             wave_lengths=self.wavelength,
                                             height_map_initializer=height_map_initializer,
                                             name='phase_mask_height',
                                             refractive_index=self.n)
            # height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
            
            # Get ATF and PSF: start from a unit field, clip to the numerical
            # aperture, apply the phase mask, then inverse-FFT to the PSF.
            otf = tf.ones([1,self.wave_resolution[0],self.wave_resolution[1],1])
            otf = optics.circular_aperture(otf, max_val = self.r_NA)
            otf = pm(otf)
            psf = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
            psf = optics.Sensor(input_is_intensities=False, resolution=sensordims)(psf)
            psf /= tf.reduce_sum(psf) # conservation of energy
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('psf', psf)
            
            # Get the output image
            if coherent:
                field = optics.circular_aperture(field, max_val = self.r_NA)
                field = pm(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                output_img = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            else:
                # Reshape PSF to [h, w, 1, 1] for convolution.
                psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                # NOTE(review): unlike the padded path in _get_training_queue,
                # no fftpad is applied here — wrap-around may occur; confirm
                # this is intentional.
                output_img = tf.abs(optics.fft_conv2d(input_field, psf))
    
            # Apply nonlinear activation
            if activation is not None:
                output_img = activation(output_img)
                
            return output_img
Exemple #4
0
def train(params,
          summary_every=100,
          print_every=250,
          save_every=1000,
          verbose=True):
    """Train a single tiled-convolution classifier on 28x28 quickdraw images.

    Builds the graph (tiled conv layer -> 4x4 spatial split -> per-cell
    mean/max scores), trains with softmax cross-entropy plus an optional
    non-negativity regularizer, periodically summarizes/saves, and finally
    runs ``test_simulation``.

    Args:
        params: configuration object accessed by attribute
            (``num_classes``, ``isNonNeg``, ``padamt``, ``doNonnegReg``,
            ``tiling_factor``, ``tile_size``, ``kernel_size``, ``doMean``,
            ``opt_type``, ``learning_rate``, ``learning_rate_ad``,
            ``log_dir``, ``dropout``, ``num_iters``).
        summary_every: iterations between TensorBoard summary writes.
        print_every: iterations between console progress prints.
        save_every: iterations between checkpoint saves / validation runs.
        verbose: whether to print progress.

    Side effects: creates an InteractiveSession, writes summaries and
    checkpoints under ``params.log_dir``.  Relies on module-level helpers
    ``get_feed`` and ``test_simulation`` defined elsewhere in the file.
    """
    # Unpack params
    classes = params.num_classes

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring enforces non-negativity when the isNonNeg switch is on.
        return tf.square(input_tensor) if params.isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    with tf.name_scope('input'):
        # the quickdraw image size is 28 * 28
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.int64, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        # in the image dimension give four borders padding size 64
        paddings = tf.constant([[
            0,
            0,
        ], [params.padamt, params.padamt], [params.padamt, params.padamt],
                                [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

        # if not isNonNeg and not doNonnegReg:
        #     x_image -= tf.reduce_mean(x_image)

    # nonneg regularizer: scale ramps from 0 toward 10000 over 8000 steps.
    global_step = tf.Variable(0, trainable=False)
    if params.doNonnegReg:
        # TODO: need to design start and end learning rate to get a valid decaying learning rate
        reg_scale = tf.train.polynomial_decay(0.,
                                              global_step,
                                              decay_steps=8000,
                                              end_learning_rate=10000.)
        psf_reg = optics_alt.nonneg_regularizer(reg_scale)
    else:
        psf_reg = None

    # build model
    # single tiled convolutional layer
    h_conv1 = optics_alt.tiled_conv_layer(x_image,
                                          params.tiling_factor,
                                          params.tile_size,
                                          params.kernel_size,
                                          name='h_conv1',
                                          nonneg=params.isNonNeg,
                                          regularizer=psf_reg)
    optics.attach_img("h_conv1", h_conv1)
    # each split is of size (None, 39, 156, 1)
    split_1d = tf.split(h_conv1, num_or_size_splits=4, axis=1)

    # calculating output scores (16, None, 39, 39, 1):
    # split the feature map into a 4x4 grid; each cell scores one class.
    h_conv_split = tf.concat([
        tf.split(split_1d[0], num_or_size_splits=4, axis=2),
        tf.split(split_1d[1], num_or_size_splits=4, axis=2),
        tf.split(split_1d[2], num_or_size_splits=4, axis=2),
        tf.split(split_1d[3], num_or_size_splits=4, axis=2)
    ], 0)
    if params.doMean:
        y_out = tf.transpose(tf.reduce_mean(h_conv_split, axis=[2, 3, 4]))
    else:
        y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))

    tf.summary.image('output', tf.reshape(y_out, [-1, 4, 4, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_data_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                                  logits=y_out)
        data_loss = tf.reduce_mean(total_data_loss)
        reg_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = tf.add(data_loss, reg_loss)
        tf.summary.scalar('data_loss', data_loss)
        tf.summary.scalar('reg_loss', reg_loss)
        tf.summary.scalar('total_loss', total_loss)

    # Optimizer choice is driven by params.opt_type; momentum is the fallback.
    if params.opt_type == 'ADAM':
        train_step = tf.train.AdamOptimizer(params.learning_rate).minimize(
            total_loss, global_step)
    elif params.opt_type == 'Adadelta':
        train_step = tf.train.AdadeltaOptimizer(params.learning_rate_ad,
                                                rho=.9).minimize(
                                                    total_loss, global_step)
    else:
        train_step = tf.train.MomentumOptimizer(params.learning_rate,
                                                momentum=0.5,
                                                use_nesterov=True).minimize(
                                                    total_loss, global_step)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(params.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(params.log_dir + '/test')

    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(params.log_dir, 'model.ckpt')

    for i in range(params.num_iters):
        # NOTE(review): get_feed is not defined in this function; presumably a
        # module-level helper accepting num_classes — verify its signature.
        x_train, y_train = get_feed(train=True, num_classes=classes)
        _, loss, reg_loss_graph, train_accuracy, train_summary = sess.run(
            [train_step, total_loss, reg_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: params.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)

            # validation set
            x_valid, y_valid = get_feed(train=False, batch_size=10)
            test_summary, test_accuracy = sess.run([merged, accuracy],
                                                   feed_dict={
                                                       x: x_valid,
                                                       y_: y_valid,
                                                       keep_prob: 1.0
                                                   })
            test_writer.add_summary(test_summary, i)
            if verbose:
                print('step %d: validation acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d:\t loss %g,\t reg_loss %g,\t train acc %g' %
                      (i, loss, reg_loss_graph, train_accuracy))

    # Final evaluation via an external helper (defined elsewhere in the file).
    test_simulation(x, y_, keep_prob, accuracy, train_accuracy, num_iter=10)
    #sess.close()

    train_writer.close()
    test_writer.close()
def train(params,
          summary_every=100,
          print_every=250,
          save_every=1000,
          verbose=True):
    """Train the tiled-convolution classifier (dict-params / FLAGS variant).

    Same architecture as the preceding ``train``: single tiled conv layer,
    4x4 spatial split into per-class scores, softmax cross-entropy with an
    optional non-negativity regularizer.  This variant reads configuration
    from a dict (``params.get``) and from a module-level ``FLAGS`` object,
    loads quickdraw .npy splits directly, and evaluates on 32 batches of 50
    test samples at the end.

    NOTE(review): this definition shadows the earlier ``train`` in this file
    (scrape artifact); also several unpacked params below (e.g. wavelength,
    numIters, activation, doMultichannelConv, doOpticalConv, doAmplitudeMask,
    doZernike, z_modes, convdim1, cdim1, dim) are never used in this function.

    Args:
        params: dict-like configuration; see the .get defaults below.
        summary_every: iterations between TensorBoard summary writes.
        print_every: iterations between console progress prints.
        save_every: iterations between checkpoint saves.
        verbose: whether to print progress.

    Side effects: creates an InteractiveSession, writes summaries and
    checkpoints under ``FLAGS.log_dir``, reads dataset files from disk.
    """
    # Unpack params
    wavelength = params.get('wavelength', 532e-9)
    isNonNeg = params.get('isNonNeg', False)
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)
    opt_type = params.get('opt_type', 'ADAM')

    # switches
    doMultichannelConv = params.get('doMultichannelConv', False)
    doMean = params.get('doMean', False)
    doOpticalConv = params.get('doOpticalConv', True)
    doAmplitudeMask = params.get('doAmplitudeMask', False)
    doZernike = params.get('doZernike', False)
    doNonnegReg = params.get('doNonnegReg', False)

    z_modes = params.get('z_modes', 1024)
    convdim1 = params.get('convdim1', 100)

    classes = 16
    cdim1 = params.get('cdim1', classes)

    padamt = params.get('padamt', 0)
    dim = params.get('dim', 60)

    tiling_factor = params.get('tiling_factor', 5)
    tile_size = params.get('tile_size', 56)
    kernel_size = params.get('kernel_size', 7)

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring enforces non-negativity when the isNonNeg switch is on.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.int64, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        paddings = tf.constant([[
            0,
            0,
        ], [padamt, padamt], [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

        # if not isNonNeg and not doNonnegReg:
        #     x_image -= tf.reduce_mean(x_image)

    # nonneg regularizer: scale ramps from 0 toward 10000 over 8000 steps.
    global_step = tf.Variable(0, trainable=False)
    if doNonnegReg:
        reg_scale = tf.train.polynomial_decay(0.,
                                              global_step,
                                              decay_steps=8000,
                                              end_learning_rate=10000.)
        psf_reg = optics_alt.nonneg_regularizer(reg_scale)
    else:
        psf_reg = None

    # build model
    # single tiled convolutional layer
    h_conv1 = optics_alt.tiled_conv_layer(x_image,
                                          tiling_factor,
                                          tile_size,
                                          kernel_size,
                                          name='h_conv1',
                                          nonneg=isNonNeg,
                                          regularizer=psf_reg)
    optics.attach_img("h_conv1", h_conv1)

    split_1d = tf.split(h_conv1, num_or_size_splits=4, axis=1)

    # calculating output scores: 4x4 spatial grid, one cell per class.
    h_conv_split = tf.concat([
        tf.split(split_1d[0], num_or_size_splits=4, axis=2),
        tf.split(split_1d[1], num_or_size_splits=4, axis=2),
        tf.split(split_1d[2], num_or_size_splits=4, axis=2),
        tf.split(split_1d[3], num_or_size_splits=4, axis=2)
    ], 0)
    if doMean:
        y_out = tf.transpose(tf.reduce_mean(h_conv_split, axis=[2, 3, 4]))
    else:
        y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))

    tf.summary.image('output', tf.reshape(y_out, [-1, 4, 4, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_data_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                                  logits=y_out)
        data_loss = tf.reduce_mean(total_data_loss)
        reg_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = tf.add(data_loss, reg_loss)
        tf.summary.scalar('data_loss', data_loss)
        tf.summary.scalar('reg_loss', reg_loss)
        tf.summary.scalar('total_loss', total_loss)

    # Optimizer choice from opt_type; learning rates come from global FLAGS.
    if opt_type == 'ADAM':
        train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
            total_loss, global_step)
    elif opt_type == 'Adadelta':
        train_step = tf.train.AdadeltaOptimizer(FLAGS.learning_rate_ad,
                                                rho=.9).minimize(
                                                    total_loss, global_step)
    else:
        train_step = tf.train.MomentumOptimizer(FLAGS.learning_rate,
                                                momentum=0.5,
                                                use_nesterov=True).minimize(
                                                    total_loss, global_step)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')

    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    # change to your directory
    train_data = np.load(
        '/media/data/Datasets/quickdraw/split/quickdraw16_train.npy')
    test_data = np.load(
        '/media/data/Datasets/quickdraw/split/quickdraw16_test.npy')

    def get_feed(train, batch_size=50):
        # Sample a random training batch, or return the full test set.
        # Labels are derived from row index: 8000 train rows per class,
        # 100 test rows per class (one-hot encoded).
        if train:
            idcs = np.random.randint(0, np.shape(train_data)[0], batch_size)
            x = train_data[idcs, :]
            y = np.zeros((batch_size, classes))
            y[np.arange(batch_size), idcs // 8000] = 1

        else:
            x = test_data
            y = np.zeros((np.shape(test_data)[0], classes))
            y[np.arange(np.shape(test_data)[0]),
              np.arange(np.shape(test_data)[0]) // 100] = 1

        return x, y

    x_test, y_test = get_feed(train=False)

    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        _, loss, reg_loss_graph, train_accuracy, train_summary = sess.run(
            [train_step, total_loss, reg_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: FLAGS.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)

            # test_summary, test_accuracy = sess.run([merged, accuracy],
            #                                        feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
            # test_writer.add_summary(test_summary, i)
            # if verbose:
            #     print('step %d: test acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d:\t loss %g,\t reg_loss %g,\t train acc %g' %
                      (i, loss, reg_loss_graph, train_accuracy))

    # Final test accuracy, averaged over 32 fixed batches of 50 samples
    # (covers the first 1600 test rows).
    test_batches = []
    for i in range(32):
        idx = i * 50
        batch_acc = accuracy.eval(
            feed_dict={
                x: x_test[idx:idx + 50, :],
                y_: y_test[idx:idx + 50, :],
                keep_prob: 1.0
            })
        test_batches.append(batch_acc)
    test_acc = np.mean(test_batches)

    #test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    print('final step %d, train accuracy %g, test accuracy %g' %
          (i, train_accuracy, test_acc))
    #sess.close()

    train_writer.close()
    test_writer.close()
    def _build_graph(self,
                     x_train,
                     hm_reg_scale,
                     hm_init_type='random_normal'):
        """Build the optical forward model: phase mask -> PSF -> sensor image.

        Normalizes the input image, constructs a regularized height-map
        phase element, derives its PSF, and produces a sensor output either
        coherently (field propagation) or incoherently (PSF convolution);
        the ``coherent`` switch is hard-coded to False below.

        Args:
            x_train: input image batch tensor.
                # assumes 4-D [batch, height, width, channels] — implied by
                # tf.summary.image usage; TODO confirm.
            hm_reg_scale: scale for the Laplacian L1 height-map regularizer.
            hm_init_type: unused in this implementation (the initializer is
                hard-coded to a near-constant uniform below).

        Returns:
            The normalized float32 sensor output image tensor.
        """
        with tf.device('/device:GPU:0'):
            sensordims = (self.dim, self.dim)
            # Start with input image
            input_img = x_train / tf.reduce_sum(x_train)
            tf.summary.image('input_image', x_train)

            # fftshift(fft2(ifftshift( FIELD ))), zero-centered
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_img)))

            # Build a phase mask, zero-centered; the tiny-range initializer
            # keeps the initial phase profile almost flat.
            height_map_initializer = tf.random_uniform_initializer(
                minval=0.999e-4, maxval=1.001e-4)
            # height_map_initializer=None
            pm = optics.height_map_element(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1],
                wave_lengths=self.wave_length,
                height_map_regularizer=optics.laplace_l1_regularizer(
                    hm_reg_scale),
                height_map_initializer=height_map_initializer,
                name='phase_mask_height',
                refractive_index=self.n)

            # Get ATF and PSF: unit field -> aperture -> phase mask -> iFFT.
            otf = tf.ones(
                [1, self.wave_resolution[0], self.wave_resolution[1], 1])
            otf = optics.circular_aperture(otf, max_val=self.r_NA)
            otf = pm(otf)
            psf = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
            psf = optics.Sensor(input_is_intensities=False,
                                resolution=sensordims)(psf)
            psf /= tf.reduce_sum(psf)  # sum or max?
            psf = tf.cast(psf, tf.float32)
            optics.attach_img('recon_psf', psf)

            # Get the output image; only the incoherent branch runs here.
            coherent = False
            if coherent:
                field = optics.circular_aperture(field, max_val=self.r_NA)
                field = pm(field)
                tf.summary.image('field', tf.square(tf.abs(field)))
                field = optics.fftshift2d_tf(
                    optics.transp_ifft2d(optics.ifftshift2d_tf(field)))

                output_img = optics.Sensor(input_is_intensities=False,
                                           resolution=(sensordims))(field)
            else:
                # Reshape PSF to [h, w, 1, 1] for convolution.
                psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
                output_img = tf.abs(optics.fft_conv2d(input_img, psf))
                output_img = optics.Sensor(input_is_intensities=True,
                                           resolution=(sensordims))(output_img)

            output_img /= tf.reduce_sum(output_img)  # sum or max?
            output_img = tf.cast(output_img, tf.float32)
            # output_img = tf.transpose(output_img, [1,2,0,3]) # (height, width, 1, 1)

            # Attach images to summary
            tf.summary.image('output_image', output_img)
            return output_img
Exemple #7
0
def optical_conv_layer(input_field,
                       hm_reg_scale,
                       r_NA,
                       n=1.48,
                       wavelength=532e-9,
                       activation=None,
                       coherent=False,
                       amplitude_mask=False,
                       zernike=False,
                       fourier=False,
                       binarymask=False,
                       n_modes=1024,
                       freq_range=.5,
                       initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None,
                       name='optical_conv'):
    """Simulate an optical convolution layer with a configurable mask element.

    Builds one of four mask types (amplitude, Zernike-basis phase, Fourier
    element, or plain height-map phase), derives the PSF from the masked
    amplitude transfer function, and applies it to ``input_field`` either
    coherently (Fourier-domain multiplication) or incoherently (padded FFT
    convolution with the PSF).

    Args:
        input_field: input tensor with static shape [batch, height, width,
            channels] (``get_shape().as_list()`` is used, so the spatial
            dims must be known).
        hm_reg_scale: regularizer scale; currently unused — the height-map
            regularizer lines below are commented out / set to None.
        r_NA: numerical-aperture radius for circular aperture clipping.
        n: refractive index of the mask material.
        wavelength: design wavelength in meters.
        activation: optional elementwise nonlinearity applied to the output.
        coherent: if True, multiply the field spectrum by the ATF; otherwise
            convolve with the (intensity) PSF.
        amplitude_mask / zernike / fourier: mutually exclusive mask-type
            switches; if none is set, a plain height-map element is used.
        binarymask: if True, additionally multiply the ATF by a fixed binary
            amplitude mask given as ``binary_mask_np``.
        n_modes: number of Zernike modes loaded from ``zernike_file``.
        freq_range: frequency range for the Fourier element.
        initializer: variable initializer for the mask; defaults to a
            near-constant tiny uniform (almost-flat initial mask).
        zernike_file: path to the precomputed Zernike mode stack (.npy).
        binary_mask_np: numpy array for the optional binary mask.
        name: variable scope name.

    Returns:
        The output image tensor after the simulated optical layer.
    """
    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):

        if initializer is None:
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)

        if amplitude_mask:
            # Build an amplitude mask, zero-centered
            # amplitude_map_initializer=tf.random_uniform_initializer(minval=0.1e-3, maxval=.1e-2)
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1],
                r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')

        else:
            # Build a phase mask, zero-centered
            if zernike:
                # Load the first n_modes Zernike basis maps and resize them
                # to the input's spatial resolution.
                zernike_modes = np.load(zernike_file)[:n_modes, :, :]
                zernike_modes = tf.expand_dims(zernike_modes, -1)
                zernike_modes = tf.image.resize_images(zernike_modes,
                                                       size=(dims[1], dims[2]))
                zernike_modes = tf.squeeze(zernike_modes, -1)
                mask = optics.zernike_element(zernike_modes,
                                              'zernike_element',
                                              wavelength,
                                              n,
                                              r_NA,
                                              zernike_initializer=initializer)
            elif fourier:
                mask = optics.fourier_element(
                    [1, dims[1], dims[2], 1],
                    'fourier_element',
                    wave_lengths=wavelength,
                    refractive_index=n,
                    frequency_range=freq_range,
                    height_map_regularizer=None,
                    height_tolerance=None,  # Default height tolerance is 2 nm.
                )

            else:
                # height_map_initializer=None
                mask = optics.height_map_element(
                    [1, dims[1], dims[2], 1],
                    wave_lengths=wavelength,
                    height_map_initializer=initializer,
                    #height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
                    name='phase_mask_height',
                    refractive_index=n)

        # Get ATF and PSF: unit field, aperture (Zernike path only), mask.
        atf = tf.ones([1, dims[1], dims[2], 1])
        #zernike=True
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        # apply any additional binary amplitude mask [1, dim, dim, 1]
        if binarymask:
            binary_mask = tf.convert_to_tensor(binary_mask_np,
                                               dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            # optics.attach_img('binary_mask', binary_mask)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            optics.attach_img('atf', tf.abs(binary_mask))

        # Inverse-FFT the ATF to the (complex) PSF, then read out intensities.
        # psf = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psfc = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=(dims[1], dims[2]))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # Get the output image
        if coherent:
            input_field = tf.cast(input_field, tf.complex128)
            # Zero-centered fft2 of input field
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
            # field = optics.circular_aperture(field, max_val = r_NA)
            field = atf * field
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1],
                                                   dims[2]))(output_img)
            # does this need padding as well?

            # psfc = tf.expand_dims(tf.expand_dims(tf.squeeze(psfc), -1), -1)
            # padamt = int(dims[1]/2)
            # output_img = optics.fft_conv2d(fftpad(input_field, padamt), fftpad_psf(psfc, padamt), adjoint=False)
            # output_img = fftunpad(output_img, padamt)
            # output_img = optics.Sensor(input_is_intensities=False, resolution=(dims[1],dims[2]))(output_img)

        else:
            # Reshape PSF to [h, w, 1, 1] and convolve with half-size padding
            # to reduce FFT wrap-around.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            # psf_flip = tf.reverse(psf,[0,1])
            # output_img = conv2d(input_field, psf)

            padamt = int(dims[1] / 2)
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(psf, padamt),
                                  adjoint=False))
            output_img = fftunpad(output_img, padamt)

        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)

        return output_img