def _get_training_queue(self, batch_size, num_threads=4):
    """Build the TF1 input pipeline yielding (ground-truth, convolved) batches.

    Reads quickdraw PNGs from a fixed directory, extracts a centered
    ``dim x dim`` glimpse, convolves it with the PSF loaded from
    ``self.psf_file`` (adjoint FFT convolution), and batches the pairs.

    Args:
        batch_size: number of image pairs per batch.
        num_threads: enqueuing threads for ``tf.train.batch``.

    Returns:
        ``(image_batch, convolved_img_batch)``, each shaped
        ``[batch_size, dim, dim, 1]``.
    """
    dim = self.dim
    file_list = tf.matching_files(
        '/media/data/onn/quickdraw16_192/im_*.png')
    filename_queue = tf.train.string_input_producer(file_list)
    image_reader = tf.WholeFileReader()
    _, image_file = image_reader.read(filename_queue)
    image = tf.image.decode_png(image_file, channels=1, dtype=tf.uint8)
    image = tf.cast(image, tf.float32)  # Shape [height, width, 1]
    image = tf.expand_dims(image, 0)
    image /= 255.

    # Get the ratio of the patch size to the smallest side of the image.
    # (Only used by the randomized-offset variant kept commented below.)
    img_height_width = tf.cast(tf.shape(image)[1:3], tf.float32)
    size_ratio = dim / tf.reduce_min(img_height_width)

    # Extract a glimpse from the image. minval == maxval == 0 pins the
    # glimpse to the image center; the commented line is the random crop.
    # offset_center = tf.random_uniform([1,2], minval=0.0 + size_ratio/2, maxval=1.0-size_ratio/2, dtype=tf.float32)
    offset_center = tf.random_uniform([1, 2], minval=0, maxval=0,
                                      dtype=tf.float32)
    offset_center = offset_center * img_height_width
    image = tf.image.extract_glimpse(image, size=[dim, dim],
                                     offsets=offset_center,
                                     centered=True, normalized=False)
    image = tf.squeeze(image, 0)

    convolved_image = tf.expand_dims(image, 0)
    psf = tf.convert_to_tensor(np.load(self.psf_file), tf.float32)
    psf /= tf.reduce_sum(psf)  # normalize PSF energy to 1
    optics.attach_img(
        'gt_psf', tf.expand_dims(tf.expand_dims(tf.squeeze(psf), 0), -1))
    psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
    # psf = tf.transpose(psf, [1,2,0,3])

    # Pad before the FFT convolution to avoid circular wrap-around.
    pad = int(dim / 2)
    convolved_image = tf.abs(
        optics.fft_conv2d(fftpad(convolved_image, pad),
                          fftpad_psf(psf, pad), adjoint=True))
    convolved_image = fftunpad(convolved_image, pad)
    convolved_image = tf.squeeze(convolved_image, axis=0)
    convolved_image /= tf.reduce_sum(convolved_image)

    # BUG FIX: num_threads was previously ignored (hard-coded 4 here).
    image_batch, convolved_img_batch = tf.train.batch(
        [image, convolved_image],
        shapes=[[dim, dim, 1], [dim, dim, 1]],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=4 * batch_size)
    return image_batch, convolved_img_batch
def _get_data_loss(self, model_output, ground_truth):
    """Mean absolute error between the model output and the normalized target.

    The ground truth is rescaled to unit sum before comparison; both
    tensors are cast to float32. Summary images of prediction and target
    are attached under the 'data_loss' name scope.
    """
    prediction = tf.cast(model_output, tf.float32)
    target = tf.cast(ground_truth, tf.float32)
    # model_output /= tf.reduce_max(model_output)
    # Normalize the target so it integrates to one.
    target /= tf.reduce_sum(target)
    with tf.name_scope('data_loss'):
        optics.attach_img('model_output', prediction)
        optics.attach_img('ground_truth', target)
    return tf.reduce_mean(tf.abs(prediction - target))
def _optical_conv_layer(self, input_field, hm_reg_scale, activation=None, coherent=False, name='optical_conv'):
    """Simulate one optical convolution layer via a learnable phase mask.

    Builds a height-map phase mask, derives the PSF from its OTF, and
    produces the output image either coherently (field propagation through
    the mask) or incoherently (FFT convolution of the input with the PSF).

    Args:
        input_field: input image/field tensor; cast to complex128 here.
        hm_reg_scale: regularizer scale (currently unused — the
            regularizer line is commented out below).
        activation: optional elementwise nonlinearity applied at the end.
        coherent: if True, propagate the complex field through the mask;
            otherwise convolve intensities with the PSF.
        name: variable scope name.

    Returns:
        The output image tensor.
    """
    with tf.variable_scope(name):
        sensordims = self.wave_resolution
        input_field = tf.cast(input_field, tf.complex128)
        # Zero-centered fft2 of input field
        field = optics.fftshift2d_tf(optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
        # Build a phase mask, zero-centered. Near-constant init (~1e-4 m
        # height) starts the mask close to a flat plate.
        height_map_initializer=tf.random_uniform_initializer(minval=0.999e-4, maxval=1.001e-4)
        # height_map_initializer=None
        pm = optics.height_map_element([1,self.wave_resolution[0],self.wave_resolution[1],1],
                                       wave_lengths=self.wavelength,
                                       height_map_initializer=height_map_initializer,
                                       name='phase_mask_height',
                                       refractive_index=self.n)
        # height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
        # Get ATF and PSF: apply NA-limited aperture and mask to a unit
        # field, then inverse-FFT to obtain the (zero-centered) PSF.
        otf = tf.ones([1,self.wave_resolution[0],self.wave_resolution[1],1])
        otf = optics.circular_aperture(otf, max_val = self.r_NA)
        otf = pm(otf)
        psf = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
        psf = optics.Sensor(input_is_intensities=False, resolution=sensordims)(psf)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_img('psf', psf)
        # Get the output image
        if coherent:
            # Propagate the complex field through aperture and mask, then
            # back to the spatial domain.
            field = optics.circular_aperture(field, max_val = self.r_NA)
            field = pm(field)
            tf.summary.image('field', tf.square(tf.abs(field)))
            output_img = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
        else:
            # Incoherent path: reshape PSF to [H, W, 1, 1] and convolve.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            output_img = tf.abs(optics.fft_conv2d(input_field, psf))
        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)
        return output_img
def train(params, summary_every=100, print_every=250, save_every=1000, verbose=True):
    """Train the tiled-convolution quickdraw classifier (params-object variant).

    Builds the TF1 graph (tiled conv layer -> 4x4 spatial split -> per-class
    score), trains with the optimizer selected by ``params.opt_type``, logs
    summaries to TensorBoard, periodically checkpoints and validates, and
    finally runs ``test_simulation``.

    Args:
        params: configuration object (num_classes, padamt, isNonNeg,
            doNonnegReg, tiling_factor, tile_size, kernel_size, doMean,
            opt_type, learning_rate(_ad), dropout, log_dir, num_iters).
        summary_every: iterations between TensorBoard summary writes.
        print_every: iterations between console loss/accuracy prints.
        save_every: iterations between checkpoint saves (+ validation).
        verbose: whether to print progress.

    NOTE(review): relies on module-level ``get_feed`` and
    ``test_simulation`` defined elsewhere in this file.
    """
    # Unpack params
    classes = params.num_classes

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring guarantees non-negative weights when isNonNeg is set.
        return tf.square(input_tensor) if params.isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    with tf.name_scope('input'):
        # the quickdraw image size is 28 * 28
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.int64, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        # pad the image with params.padamt pixels on each spatial border
        paddings = tf.constant([[
            0,
            0,
        ], [params.padamt, params.padamt], [params.padamt, params.padamt],
                                [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

    # if not isNonNeg and not doNonnegReg:
    #     x_image -= tf.reduce_mean(x_image)

    # nonneg regularizer: ramp the regularization weight up over training.
    global_step = tf.Variable(0, trainable=False)
    if params.doNonnegReg:
        # TODO: need to design start and end learning rate to get a valid decaying learning rate
        reg_scale = tf.train.polynomial_decay(0.,
                                              global_step,
                                              decay_steps=8000,
                                              end_learning_rate=10000.)
        psf_reg = optics_alt.nonneg_regularizer(reg_scale)
    else:
        psf_reg = None

    # build model
    # single tiled convolutional layer
    h_conv1 = optics_alt.tiled_conv_layer(x_image,
                                          params.tiling_factor,
                                          params.tile_size,
                                          params.kernel_size,
                                          name='h_conv1',
                                          nonneg=params.isNonNeg,
                                          regularizer=psf_reg)
    optics.attach_img("h_conv1", h_conv1)

    # each split is of size (None, 39, 156, 1)
    split_1d = tf.split(h_conv1, num_or_size_splits=4, axis=1)
    # calculating output scores (16, None, 39, 39, 1): one 4x4 grid cell
    # per class, pooled below into a per-class logit.
    h_conv_split = tf.concat([
        tf.split(split_1d[0], num_or_size_splits=4, axis=2),
        tf.split(split_1d[1], num_or_size_splits=4, axis=2),
        tf.split(split_1d[2], num_or_size_splits=4, axis=2),
        tf.split(split_1d[3], num_or_size_splits=4, axis=2)
    ], 0)

    # Pool each grid cell to a scalar logit (mean or max).
    if params.doMean:
        y_out = tf.transpose(tf.reduce_mean(h_conv_split, axis=[2, 3, 4]))
    else:
        y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))
    tf.summary.image('output', tf.reshape(y_out, [-1, 4, 4, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_data_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                                  logits=y_out)
        data_loss = tf.reduce_mean(total_data_loss)
        reg_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = tf.add(data_loss, reg_loss)
        tf.summary.scalar('data_loss', data_loss)
        tf.summary.scalar('reg_loss', reg_loss)
        tf.summary.scalar('total_loss', total_loss)

    if params.opt_type == 'ADAM':
        train_step = tf.train.AdamOptimizer(params.learning_rate).minimize(
            total_loss, global_step)
    elif params.opt_type == 'Adadelta':
        train_step = tf.train.AdadeltaOptimizer(params.learning_rate_ad,
                                                rho=.9).minimize(
                                                    total_loss, global_step)
    else:
        train_step = tf.train.MomentumOptimizer(params.learning_rate,
                                                momentum=0.5,
                                                use_nesterov=True).minimize(
                                                    total_loss, global_step)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(params.log_dir + '/train',
                                         sess.graph)
    test_writer = tf.summary.FileWriter(params.log_dir + '/test')
    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(params.log_dir, 'model.ckpt')

    for i in range(params.num_iters):
        x_train, y_train = get_feed(train=True, num_classes=classes)
        _, loss, reg_loss_graph, train_accuracy, train_summary = sess.run(
            [train_step, total_loss, reg_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: params.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)
            # validation set: evaluated only at checkpoint steps.
            x_valid, y_valid = get_feed(train=False, batch_size=10)
            test_summary, test_accuracy = sess.run([merged, accuracy],
                                                   feed_dict={
                                                       x: x_valid,
                                                       y_: y_valid,
                                                       keep_prob: 1.0
                                                   })
            test_writer.add_summary(test_summary, i)
            if verbose:
                print('step %d: validation acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d:\t loss %g,\t reg_loss %g,\t train acc %g' %
                      (i, loss, reg_loss_graph, train_accuracy))

    test_simulation(x, y_, keep_prob, accuracy, train_accuracy, num_iter=10)
    #sess.close()
    train_writer.close()
    test_writer.close()
def train(params, summary_every=100, print_every=250, save_every=1000, verbose=True):
    """Train the tiled-convolution quickdraw classifier (dict-params variant).

    Builds the TF1 graph (tiled conv layer -> 4x4 spatial split -> per-class
    score), trains with the optimizer selected by ``opt_type`` (rates taken
    from module-level ``FLAGS``), logs summaries, checkpoints periodically,
    and reports final batched test accuracy over 32 batches of 50.

    Args:
        params: dict of hyperparameters (see ``params.get`` calls below).
        summary_every: iterations between TensorBoard summary writes.
        print_every: iterations between console loss/accuracy prints.
        save_every: iterations between checkpoint saves.
        verbose: whether to print progress.

    NOTE(review): relies on module-level ``FLAGS`` for learning rates,
    dropout, log_dir and num_iters.
    """
    # Unpack params
    wavelength = params.get('wavelength', 532e-9)
    isNonNeg = params.get('isNonNeg', False)
    numIters = params.get('numIters', 1000)
    activation = params.get('activation', tf.nn.relu)
    opt_type = params.get('opt_type', 'ADAM')

    # switches
    doMultichannelConv = params.get('doMultichannelConv', False)
    doMean = params.get('doMean', False)
    doOpticalConv = params.get('doOpticalConv', True)
    doAmplitudeMask = params.get('doAmplitudeMask', False)
    doZernike = params.get('doZernike', False)
    doNonnegReg = params.get('doNonnegReg', False)

    z_modes = params.get('z_modes', 1024)
    convdim1 = params.get('convdim1', 100)

    classes = 16
    cdim1 = params.get('cdim1', classes)

    padamt = params.get('padamt', 0)
    dim = params.get('dim', 60)

    tiling_factor = params.get('tiling_factor', 5)
    tile_size = params.get('tile_size', 56)
    kernel_size = params.get('kernel_size', 7)

    # constraint helpers
    def nonneg(input_tensor):
        # Squaring guarantees non-negative weights when isNonNeg is set.
        return tf.square(input_tensor) if isNonNeg else input_tensor

    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True))

    # input placeholders
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.int64, shape=[None, classes])
        keep_prob = tf.placeholder(tf.float32)

        x_image = tf.reshape(x, [-1, 28, 28, 1])
        paddings = tf.constant([[
            0,
            0,
        ], [padamt, padamt], [padamt, padamt], [0, 0]])
        x_image = tf.pad(x_image, paddings)
        # x_image = tf.image.resize_nearest_neighbor(x_image, size=(dim, dim))
        tf.summary.image('input', x_image, 3)

    # if not isNonNeg and not doNonnegReg:
    #     x_image -= tf.reduce_mean(x_image)

    # nonneg regularizer: ramp the regularization weight up over training.
    global_step = tf.Variable(0, trainable=False)
    if doNonnegReg:
        reg_scale = tf.train.polynomial_decay(0.,
                                              global_step,
                                              decay_steps=8000,
                                              end_learning_rate=10000.)
        psf_reg = optics_alt.nonneg_regularizer(reg_scale)
    else:
        psf_reg = None

    # build model
    # single tiled convolutional layer
    h_conv1 = optics_alt.tiled_conv_layer(x_image,
                                          tiling_factor,
                                          tile_size,
                                          kernel_size,
                                          name='h_conv1',
                                          nonneg=isNonNeg,
                                          regularizer=psf_reg)
    optics.attach_img("h_conv1", h_conv1)

    split_1d = tf.split(h_conv1, num_or_size_splits=4, axis=1)
    # calculating output scores: one 4x4 grid cell per class, pooled
    # below into a per-class logit.
    h_conv_split = tf.concat([
        tf.split(split_1d[0], num_or_size_splits=4, axis=2),
        tf.split(split_1d[1], num_or_size_splits=4, axis=2),
        tf.split(split_1d[2], num_or_size_splits=4, axis=2),
        tf.split(split_1d[3], num_or_size_splits=4, axis=2)
    ], 0)

    if doMean:
        y_out = tf.transpose(tf.reduce_mean(h_conv_split, axis=[2, 3, 4]))
    else:
        y_out = tf.transpose(tf.reduce_max(h_conv_split, axis=[2, 3, 4]))
    tf.summary.image('output', tf.reshape(y_out, [-1, 4, 4, 1]), 3)

    # loss, train, acc
    with tf.name_scope('cross_entropy'):
        total_data_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                                  logits=y_out)
        data_loss = tf.reduce_mean(total_data_loss)
        reg_loss = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = tf.add(data_loss, reg_loss)
        tf.summary.scalar('data_loss', data_loss)
        tf.summary.scalar('reg_loss', reg_loss)
        tf.summary.scalar('total_loss', total_loss)

    if opt_type == 'ADAM':
        train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
            total_loss, global_step)
    elif opt_type == 'Adadelta':
        train_step = tf.train.AdadeltaOptimizer(FLAGS.learning_rate_ad,
                                                rho=.9).minimize(
                                                    total_loss, global_step)
    else:
        train_step = tf.train.MomentumOptimizer(FLAGS.learning_rate,
                                                momentum=0.5,
                                                use_nesterov=True).minimize(
                                                    total_loss, global_step)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_out, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    losses = []

    # tensorboard setup
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                         sess.graph)
    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
    tf.global_variables_initializer().run()

    # add ops to save and restore all the variables
    saver = tf.train.Saver(max_to_keep=2)
    save_path = os.path.join(FLAGS.log_dir, 'model.ckpt')

    # change to your directory
    train_data = np.load(
        '/media/data/Datasets/quickdraw/split/quickdraw16_train.npy')
    test_data = np.load(
        '/media/data/Datasets/quickdraw/split/quickdraw16_test.npy')

    def get_feed(train, batch_size=50):
        # One-hot labels derived from row index: train rows are grouped in
        # runs of 8000 per class, test rows in runs of 100 per class.
        if train:
            idcs = np.random.randint(0, np.shape(train_data)[0], batch_size)
            x = train_data[idcs, :]
            y = np.zeros((batch_size, classes))
            y[np.arange(batch_size), idcs // 8000] = 1
        else:
            x = test_data
            y = np.zeros((np.shape(test_data)[0], classes))
            y[np.arange(np.shape(test_data)[0]),
              np.arange(np.shape(test_data)[0]) // 100] = 1
        return x, y

    x_test, y_test = get_feed(train=False)

    for i in range(FLAGS.num_iters):
        x_train, y_train = get_feed(train=True)
        _, loss, reg_loss_graph, train_accuracy, train_summary = sess.run(
            [train_step, total_loss, reg_loss, accuracy, merged],
            feed_dict={
                x: x_train,
                y_: y_train,
                keep_prob: FLAGS.dropout
            })
        losses.append(loss)

        if i % summary_every == 0:
            train_writer.add_summary(train_summary, i)

        if i > 0 and i % save_every == 0:
            # print("Saving model...")
            saver.save(sess, save_path, global_step=i)

            # test_summary, test_accuracy = sess.run([merged, accuracy],
            #                                        feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
            # test_writer.add_summary(test_summary, i)
            # if verbose:
            #     print('step %d: test acc %g' % (i, test_accuracy))

        if i % print_every == 0:
            if verbose:
                print('step %d:\t loss %g,\t reg_loss %g,\t train acc %g' %
                      (i, loss, reg_loss_graph, train_accuracy))

    # Evaluate test accuracy in batches of 50 (32 batches = 1600 samples).
    # BUG FIX: this loop previously reused `i`, clobbering the training-step
    # counter so the final print reported step 31 instead of the last step.
    test_batches = []
    for batch_idx in range(32):
        idx = batch_idx * 50
        batch_acc = accuracy.eval(
            feed_dict={
                x: x_test[idx:idx + 50, :],
                y_: y_test[idx:idx + 50, :],
                keep_prob: 1.0
            })
        test_batches.append(batch_acc)
    test_acc = np.mean(test_batches)
    #test_acc = accuracy.eval(feed_dict={x: x_test, y_: y_test, keep_prob: 1.0})
    print('final step %d, train accuracy %g, test accuracy %g' %
          (i, train_accuracy, test_acc))
    #sess.close()
    train_writer.close()
    test_writer.close()
def _build_graph(self, x_train, hm_reg_scale, hm_init_type='random_normal'):
    """Build the optical forward model: phase mask -> PSF -> sensor image.

    Normalizes the input image, builds a regularized height-map phase
    mask, derives the PSF from the masked OTF, convolves the input with
    it (incoherent path; the coherent path is hard-disabled below), and
    returns the normalized sensor image.

    Args:
        x_train: input image batch tensor.
        hm_reg_scale: scale for the Laplacian L1 height-map regularizer.
        hm_init_type: unused here; initializer is hard-coded below.

    Returns:
        float32 output image tensor, normalized to unit sum.
    """
    with tf.device('/device:GPU:0'):
        sensordims = (self.dim, self.dim)
        # Start with input image, normalized to unit sum.
        input_img = x_train / tf.reduce_sum(x_train)
        tf.summary.image('input_image', x_train)

        # fftshift(fft2(ifftshift( FIELD ))), zero-centered
        field = optics.fftshift2d_tf(
            optics.transp_fft2d(optics.ifftshift2d_tf(input_img)))

        # Build a phase mask, zero-centered. Near-constant init (~1e-4 m)
        # starts the mask close to a flat plate.
        height_map_initializer = tf.random_uniform_initializer(
            minval=0.999e-4, maxval=1.001e-4)
        # height_map_initializer=None
        pm = optics.height_map_element(
            [1, self.wave_resolution[0], self.wave_resolution[1], 1],
            wave_lengths=self.wave_length,
            height_map_regularizer=optics.laplace_l1_regularizer(
                hm_reg_scale),
            height_map_initializer=height_map_initializer,
            name='phase_mask_height',
            refractive_index=self.n)

        # Get ATF and PSF: unit field through NA-limited aperture and
        # mask, inverse-FFT back to a zero-centered PSF, read on sensor.
        otf = tf.ones(
            [1, self.wave_resolution[0], self.wave_resolution[1], 1])
        otf = optics.circular_aperture(otf, max_val=self.r_NA)
        otf = pm(otf)
        psf = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(otf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=sensordims)(psf)
        psf /= tf.reduce_sum(psf)  # sum or max?
        psf = tf.cast(psf, tf.float32)
        optics.attach_img('recon_psf', psf)

        # Get the output image. Coherent propagation is hard-disabled;
        # only the incoherent PSF-convolution branch runs.
        coherent = False
        if coherent:
            field = optics.circular_aperture(field, max_val=self.r_NA)
            field = pm(field)
            tf.summary.image('field', tf.square(tf.abs(field)))
            field = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(sensordims))(field)
        else:
            # Reshape PSF to [H, W, 1, 1] and convolve with the input.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            output_img = tf.abs(optics.fft_conv2d(input_img, psf))
            output_img = optics.Sensor(input_is_intensities=True,
                                       resolution=(sensordims))(output_img)

        output_img /= tf.reduce_sum(output_img)  # sum or max?
        output_img = tf.cast(output_img, tf.float32)
        # output_img = tf.transpose(output_img, [1,2,0,3]) # (height, width, 1, 1)

        # Attach images to summary
        tf.summary.image('output_image', output_img)
        return output_img
def optical_conv_layer(input_field, hm_reg_scale, r_NA, n=1.48, wavelength=532e-9,
                       activation=None, coherent=False, amplitude_mask=False,
                       zernike=False, fourier=False, binarymask=False,
                       n_modes=1024, freq_range=.5, initializer=None,
                       zernike_file='zernike_volume_256.npy',
                       binary_mask_np=None, name='optical_conv'):
    """Simulate an optical convolution layer with a configurable mask.

    Builds one of several mask parameterizations (amplitude, Zernike
    basis, Fourier element, or raw height map), derives the PSF from the
    masked ATF, then produces the output either coherently (field times
    ATF in frequency space) or incoherently (padded FFT convolution of
    the input with the PSF).

    Args:
        input_field: input tensor of shape [batch, H, W, C].
        hm_reg_scale: regularizer scale (currently unused — the
            regularizer line is commented out below).
        r_NA: aperture radius (numerical-aperture cutoff).
        n: refractive index of the mask material.
        wavelength: design wavelength in meters.
        activation: optional elementwise nonlinearity applied at the end.
        coherent: coherent field propagation vs incoherent convolution.
        amplitude_mask / zernike / fourier: mask parameterization
            switches (checked in that priority order).
        binarymask: multiply the ATF by an extra fixed binary mask.
        n_modes: number of Zernike modes to load.
        freq_range: frequency range for the Fourier element.
        initializer: mask variable initializer; near-constant by default.
        zernike_file: .npy file holding the Zernike mode volume.
        binary_mask_np: numpy array for the binary mask (if binarymask).
        name: variable scope name.

    Returns:
        The output image tensor.
    """
    dims = input_field.get_shape().as_list()
    with tf.variable_scope(name):
        if initializer is None:
            # Near-constant init (~1e-4 m) starts close to a flat plate.
            initializer = tf.random_uniform_initializer(minval=0.999e-4,
                                                        maxval=1.001e-4)
        if amplitude_mask:
            # Build an amplitude mask, zero-centered
            # amplitude_map_initializer=tf.random_uniform_initializer(minval=0.1e-3, maxval=.1e-2)
            mask = optics.amplitude_map_element(
                [1, dims[1], dims[2], 1],
                r_NA,
                amplitude_map_initializer=initializer,
                name='amplitude_mask')
        else:
            # Build a phase mask, zero-centered
            if zernike:
                # Parameterize the phase mask in a (resized) Zernike basis.
                zernike_modes = np.load(zernike_file)[:n_modes, :, :]
                zernike_modes = tf.expand_dims(zernike_modes, -1)
                zernike_modes = tf.image.resize_images(zernike_modes,
                                                       size=(dims[1], dims[2]))
                zernike_modes = tf.squeeze(zernike_modes, -1)
                mask = optics.zernike_element(zernike_modes,
                                              'zernike_element', wavelength,
                                              n, r_NA,
                                              zernike_initializer=initializer)
            elif fourier:
                mask = optics.fourier_element(
                    [1, dims[1], dims[2], 1],
                    'fourier_element',
                    wave_lengths=wavelength,
                    refractive_index=n,
                    frequency_range=freq_range,
                    height_map_regularizer=None,
                    height_tolerance=None,  # Default height tolerance is 2 nm.
                )
            else:
                # height_map_initializer=None
                mask = optics.height_map_element(
                    [1, dims[1], dims[2], 1],
                    wave_lengths=wavelength,
                    height_map_initializer=initializer,
                    #height_map_regularizer=optics.laplace_l1_regularizer(hm_reg_scale),
                    name='phase_mask_height',
                    refractive_index=n)

        # Get ATF and PSF: apply mask (and, for Zernike, the NA-limited
        # aperture) to a unit field, then inverse-FFT to obtain the PSF.
        atf = tf.ones([1, dims[1], dims[2], 1])
        #zernike=True
        if zernike:
            atf = optics.circular_aperture(atf, max_val=r_NA)
        atf = mask(atf)

        # apply any additional binary amplitude mask [1, dim, dim, 1]
        if binarymask:
            binary_mask = tf.convert_to_tensor(binary_mask_np,
                                               dtype=tf.float32)
            binary_mask = tf.expand_dims(tf.expand_dims(binary_mask, 0), -1)
            # optics.attach_img('binary_mask', binary_mask)
            atf = atf * tf.cast(binary_mask, tf.complex128)
            optics.attach_img('atf', tf.abs(binary_mask))

        # psf = optics.fftshift2d_tf(optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psfc = optics.fftshift2d_tf(
            optics.transp_ifft2d(optics.ifftshift2d_tf(atf)))
        psf = optics.Sensor(input_is_intensities=False,
                            resolution=(dims[1], dims[2]))(psfc)
        psf /= tf.reduce_sum(psf)  # conservation of energy
        psf = tf.cast(psf, tf.float32)
        optics.attach_summaries('psf', psf, True, True)

        # Get the output image
        if coherent:
            input_field = tf.cast(input_field, tf.complex128)
            # Zero-centered fft2 of input field
            field = optics.fftshift2d_tf(
                optics.transp_fft2d(optics.ifftshift2d_tf(input_field)))
            # field = optics.circular_aperture(field, max_val = r_NA)
            # Multiply in frequency space = convolution in image space.
            field = atf * field
            tf.summary.image('field', tf.log(tf.square(tf.abs(field))))
            output_img = optics.fftshift2d_tf(
                optics.transp_ifft2d(optics.ifftshift2d_tf(field)))
            output_img = optics.Sensor(input_is_intensities=False,
                                       resolution=(dims[1],
                                                   dims[2]))(output_img)
            # does this need padding as well?
            # psfc = tf.expand_dims(tf.expand_dims(tf.squeeze(psfc), -1), -1)
            # padamt = int(dims[1]/2)
            # output_img = optics.fft_conv2d(fftpad(input_field, padamt), fftpad_psf(psfc, padamt), adjoint=False)
            # output_img = fftunpad(output_img, padamt)
            # output_img = optics.Sensor(input_is_intensities=False, resolution=(dims[1],dims[2]))(output_img)
        else:
            # Incoherent path: reshape PSF to [H, W, 1, 1] and convolve,
            # padding first to avoid circular wrap-around.
            psf = tf.expand_dims(tf.expand_dims(tf.squeeze(psf), -1), -1)
            # psf_flip = tf.reverse(psf,[0,1])
            # output_img = conv2d(input_field, psf)
            padamt = int(dims[1] / 2)
            output_img = tf.abs(
                optics.fft_conv2d(fftpad(input_field, padamt),
                                  fftpad_psf(psf, padamt), adjoint=False))
            output_img = fftunpad(output_img, padamt)

        # Apply nonlinear activation
        if activation is not None:
            output_img = activation(output_img)

        return output_img