def scale(self, x_obj, w_scale, w_trans):
    """Scale and translate a 3-D object using two chained 2-D ``stn`` warps.

    ``stn`` is applied twice: first with a transform that carries the scale
    on both diagonal entries plus the x and y translations, then -- after a
    transpose -- with a transform that carries the scale on the first
    diagonal entry only (the second is fixed to 1) plus the z translation.

    Args:
        x_obj: object volume with a trailing singleton dimension.  The inline
            comments describe the layout as ``[batch, y, x, z]`` once that
            dimension is removed -- TODO confirm against callers.
        w_scale: per-sample scale; expanded here with a trailing axis so it
            can be used as the left operand of ``tf.matmul``.
        w_trans: per-sample translation; columns 0, 1 and 2 are read as the
            x, y and z offsets respectively.

    Returns:
        The transformed volume with the trailing dimension added back.
    """
    # Remove only the trailing dimension.  A bare tf.squeeze() would also
    # drop the batch axis (or any other axis) whenever its size is 1.
    obj = tf.squeeze(x_obj, axis=-1)
    w_scale = tf.expand_dims(w_scale, -1)
    w_trans_x = tf.expand_dims(w_trans[:, 0], -1)
    w_trans_y = tf.expand_dims(w_trans[:, 1], -1)
    w_trans_z = tf.expand_dims(w_trans[:, 2], -1)

    # !!! note that in the reference implementation for images, y and x are
    # switched; it is the same here: [batch, y, x, z]

    # transform the x and y dims first
    s_xy = tf.matmul(
        w_scale, tf.constant([[1, 0, 0, 0, 1, 0]], dtype=tf.float32))
    t_xy = tf.matmul(
        w_trans_x,
        tf.constant([[0, 0, 1, 0, 0, 0]], dtype=tf.float32)) + tf.matmul(
            w_trans_y, tf.constant([[0, 0, 0, 0, 0, 1]], dtype=tf.float32))
    T_xy = s_xy + t_xy
    transformed_xy = stn(obj, T_xy)

    # reorder the axes so that z comes first: [batch, x, z, y]
    transposed_transformed_xy = tf.transpose(transformed_xy, [0, 2, 3, 1])

    # transform the z dim: scale on the first diagonal entry, identity (1)
    # on the second so the other axis is not scaled a second time
    s_zx = tf.matmul(
        w_scale,
        tf.constant([[1, 0, 0, 0, 0, 0]], dtype=tf.float32)) + tf.matmul(
            tf.ones([self.batch_size, 1]),
            tf.constant([[0, 0, 0, 0, 1, 0]], dtype=tf.float32))
    t_zx = tf.matmul(
        w_trans_z, tf.constant([[0, 0, 1, 0, 0, 0]], dtype=tf.float32))
    T_zx = s_zx + t_zx
    transformed_zxy = stn(transposed_transformed_xy, T_zx)

    # back to the original order: [batch, y, x, z]
    transformed_xyz = tf.transpose(transformed_zxy, [0, 3, 1, 2])
    # add the last dimension back
    transformed_xyz = tf.expand_dims(transformed_xyz, -1)
    print(transformed_xyz)
    return transformed_xyz
def build_convnet():
    """Assemble the localization branch, the ``stn`` output and the classifier.

    Reads the module-level names ``X``, ``phase`` and ``num_classes``.

    NOTE(review): the classifier below is fed ``X`` directly rather than
    ``h_trans``, so ``logits`` does not depend on the localization branch --
    confirm this is intended.

    Returns:
        tuple: ``(h_trans, logits)`` -- the result of ``stn(X, theta)`` and the
        output of the final ``fc3`` dense layer.
    """
    # localization network: two conv/pool stages followed by three dense
    # layers, the last of which emits the 6 values handed to stn()
    loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    loc = MaxPooling2D(loc, use_relu=True, name='pool1_loc')
    loc = Conv2D(loc, 32, 5, 64, name='conv2_loc')
    loc = MaxPooling2D(loc, use_relu=True, name='pool2_loc')
    loc_flat, loc_flat_size = Flatten(loc)
    loc = Dense(loc_flat, loc_flat_size, 2048, use_relu=False, name='fc1_loc')
    loc = Dense(loc, 2048, 512, use_relu=True, name='fc2_loc')
    theta = Dense(loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')
    print('fc3_loc: {}'.format(theta.get_shape()))

    # spatial transformer
    h_trans = stn(X, theta)
    print('h_trans: {}'.format(h_trans.get_shape()))

    # convnet: three conv -> batch-norm -> pool stages
    conv_stages = (
        ('conv1', 'bn1', 'pool1', 1, 5, 32),
        ('conv2', 'bn2', 'pool2', 32, 5, 64),
        ('conv3', 'bn3', 'pool3', 64, 3, 128),
    )
    net = X
    for conv_name, bn_name, pool_name, n_in, k, n_out in conv_stages:
        net = Conv2D(net, n_in, k, n_out, name=conv_name)
        net = BatchNormalization(net, phase, name=bn_name)
        net = MaxPooling2D(net, use_relu=True, name=pool_name)

    # dense head: the ReLU is applied by the batch-norm layers, not by Dense
    net_flat, net_flat_size = Flatten(net)
    net = Dense(net_flat, net_flat_size, 2048, use_relu=False, name='fc1')
    net = BatchNormalization(net, phase, use_relu=True, name='bn4')
    net = Dense(net, 2048, 512, use_relu=False, name='fc2')
    net = BatchNormalization(net, phase, use_relu=True, name='bn5')
    logits = Dense(net, 512, num_classes, name='fc3', use_relu=False)

    return h_trans, logits
def focal_loss(target_tensor, theta, org, weights=None, alpha=0.25, gamma=2):
    r"""Compute a focal loss between a warped, binarised input and its target.

    ``org`` is warped with ``stn(org, theta)``, thresholded at 0.5 and one-hot
    encoded with depth 2; ``target_tensor`` is cast to int32 and one-hot
    encoded the same way.  With p = sigmoid(one-hot prediction) and
    z = one-hot target, the per-entry loss is

        FL = -alpha * (z - p)^gamma * log(p)
             - (1 - alpha) * p^gamma * log(1 - p)

    where the first term is only active for entries with z > 0 and the second
    only for the remaining entries.

    Args:
        target_tensor: classification targets; cast to int32 before the
            one-hot encoding.
        theta: transformation parameters passed to ``stn``.
        org: input passed to ``stn``.
        weights: accepted for interface compatibility; not used.
        alpha: focal loss alpha hyper-parameter.
        gamma: focal loss gamma hyper-parameter.

    Returns:
        loss: A (scalar) tensor, the sum of the per-entry loss values.
    """
    warped = stn(org, theta)
    # NOTE(review): the threshold + integer cast below is presumably not
    # differentiable, so this loss may provide no gradient w.r.t. ``theta``
    # -- confirm how it is meant to be used.
    pred = tf.to_int32(warped > 0.5)
    pred = tf.one_hot(pred, depth=2)
    pred = tf.dtypes.cast(pred, tf.float32)

    target = tf.dtypes.cast(target_tensor, tf.int32)
    target = tf.one_hot(target, depth=2)
    target = tf.dtypes.cast(target, tf.float32)

    pred = tf.convert_to_tensor(pred, tf.float32)
    target = tf.convert_to_tensor(target, tf.float32)
    print("Target tensor shape", target.get_shape().as_list())
    print("Prediction tensor shape", pred.get_shape().as_list())

    p = tf.nn.sigmoid(pred)
    zeros = array_ops.zeros_like(p, dtype=p.dtype)
    is_positive = target > zeros

    # Positive entries (z = 1): only the first term contributes, with
    # coefficient z - p.
    pos_coeff = array_ops.where(is_positive, target - p, zeros)
    # Negative entries (z = 0): only the second term contributes, with
    # coefficient p.
    neg_coeff = array_ops.where(is_positive, zeros, p)

    # Clip before the log so that log(0) never occurs.
    log_p = tf.log(tf.clip_by_value(p, 1e-8, 1.0))
    log_not_p = tf.log(tf.clip_by_value(1.0 - p, 1e-8, 1.0))

    pos_term = -alpha * (pos_coeff ** gamma) * log_p
    neg_term = (1 - alpha) * (neg_coeff ** gamma) * log_not_p
    return tf.reduce_sum(pos_term - neg_term)
def focal_loss_(labels, theta, org, gamma=2.0, alpha=4.0):
    """Softmax-based focal loss on the ``stn``-warped input.

    The warped input is shifted by 0.5, passed through a softmax over the
    last axis, and compared with ``labels`` using a cross-entropy term that
    is weighted by ``(1 - p) ** gamma`` and scaled by ``alpha``.

    Args:
        labels: target values; converted to a float32 tensor.
        theta: transformation parameters passed to ``stn``.
        org: input passed to ``stn``.
        gamma: exponent of the ``(1 - p)`` modulating factor.
        alpha: scalar multiplier applied to the loss.

    Returns:
        The element-wise focal loss reduced with a maximum over axis 1.
    """
    epsilon = 1.e-9  # keeps log() away from zero

    warped = stn(org, theta)
    shifted = tf.cast(warped + 0.5, tf.float32)

    labels = tf.convert_to_tensor(labels, tf.float32)
    shifted = tf.convert_to_tensor(shifted, tf.float32)

    probs = tf.nn.softmax(shifted, dim=-1) + epsilon
    cross_entropy = labels * -tf.log(probs)
    modulator = labels * tf.pow(1. - probs, gamma)
    focal = alpha * (modulator * cross_entropy)
    return tf.reduce_max(focal, axis=1)
def build_convnet():
    """Build the localization branch, the ``stn`` output and the classifier.

    Uses the module-level names ``X``, ``phase`` and ``num_classes``.

    NOTE(review): the classifier is fed ``X`` directly, not ``h_trans``, so
    ``logits`` does not depend on the localization branch -- confirm this is
    intended.

    Returns:
        tuple: ``(h_trans, logits)``.
    """
    # localization network
    loc = X
    for conv_name, pool_name, n_in, k, n_out in (
            ('conv1_loc', 'pool1_loc', 1, 5, 32),
            ('conv2_loc', 'pool2_loc', 32, 5, 64)):
        loc = Conv2D(loc, n_in, k, n_out, name=conv_name)
        loc = MaxPooling2D(loc, use_relu=True, name=pool_name)
    loc_flat, loc_flat_size = Flatten(loc)
    loc = Dense(loc_flat, loc_flat_size, 2048, use_relu=False, name='fc1_loc')
    loc = Dense(loc, 2048, 512, use_relu=True, name='fc2_loc')
    theta = Dense(loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, theta)

    # convnet
    net = X
    for conv_name, bn_name, pool_name, n_in, k, n_out in (
            ('conv1', 'bn1', 'pool1', 1, 5, 32),
            ('conv2', 'bn2', 'pool2', 32, 5, 64),
            ('conv3', 'bn3', 'pool3', 64, 3, 128)):
        net = Conv2D(net, n_in, k, n_out, name=conv_name)
        net = BatchNormalization(net, phase, name=bn_name)
        net = MaxPooling2D(net, use_relu=True, name=pool_name)

    net_flat, net_flat_size = Flatten(net)
    net = Dense(net_flat, net_flat_size, 2048, use_relu=False, name='fc1')
    net = BatchNormalization(net, phase, use_relu=True, name='bn4')
    net = Dense(net, 2048, 512, use_relu=False, name='fc2')
    net = BatchNormalization(net, phase, use_relu=True, name='bn5')
    logits = Dense(net, 512, num_classes, name='fc3', use_relu=False)

    return h_trans, logits
def classification_loss(labels, theta, org):
    """Mean squared error between the ``stn``-warped input and the labels.

    Both tensors are flattened to 1-D before the comparison.

    Args:
        labels: target values.
        theta: transformation parameters passed to ``stn``.
        org: input passed to ``stn``.

    Returns:
        The loss produced by ``tf.losses.mean_squared_error``.
    """
    logits = stn(org, theta)
    flat_logits = tf.reshape(logits, [-1])
    flat_labels = tf.reshape(labels, [-1])
    return tf.losses.mean_squared_error(flat_labels, flat_logits)
def main():
    """Train or evaluate the spatial-transformer ConvNet.

    Builds the graph (localization network -> ``stn`` -> classifier on the
    transformed images).  When the module-level ``MODE`` is ``'train'`` it
    runs the training loop with periodic validation, checkpointing of the
    best model, early stopping and image dumps; otherwise it reports the
    accuracy on the test set.

    Relies on names defined outside this function: the placeholders ``X``,
    ``y`` and ``phase``; the settings ``H``, ``W``, ``root_dir``,
    ``save_dir``, ``logs_dir``, ``vis_path``, ``batch_size``, ``num_epochs``,
    ``num_classes``, ``learning_rate``, ``display_step``, ``VIEW``,
    ``SAMPLE``, ``RESTORE`` and ``MODE``.  Updates the globals
    ``best_validation_accuracy`` and ``last_improvement``.
    """
    global best_validation_accuracy
    global last_improvement
    global require_improvement

    def _save_grid(fig, images, filename):
        """Draw up to nine images on a 3x3 grid and save it under vis_path."""
        plt.clf()
        # a short batch may hold fewer than nine images
        for j in range(min(9, len(images))):
            plt.subplot(3, 3, j + 1)
            plt.imshow(images[j], cmap='gray')
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig(vis_path + filename, bbox_inches='tight')

    # load the data
    print("Loading the data...")
    X_train, y_train, X_test, y_test, X_valid, y_valid = load_data(root_dir)

    # let's view a small sample
    if VIEW:
        mask = np.arange(9)
        gd_truth = np.argmax(y_train[mask], axis=1)
        sample = X_train.squeeze()[mask]
        plot_images(sample, gd_truth)

    # optionally train on the first 500 samples only
    if SAMPLE:
        mask = np.arange(500)
        X_train = X_train[mask]
        y_train = y_train[mask]

    num_train = X_train.shape[0]

    print("Building ConvNet...")

    # localization network
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')
    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)
    fc1_loc = Dense(pool2_loc_flat, pool2_loc_size, 2048, use_relu=False, name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, fc3_loc, out_dims=(H, W))

    # convnet (consumes the transformed images)
    conv1 = Conv2D(h_trans, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')
    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')
    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')
    pool3_flat, pool3_size = Flatten(pool3)
    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    # define cost function
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(cross_entropy)

    # define optimizer
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step)

    # define accuracy
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # define saver object for storing and retrieving checkpoints
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')  # path for the checkpoint file

    total_batch = int(np.ceil(num_train / float(batch_size)))
    num_iterations = num_epochs * total_batch

    # create summary for loss and acc
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    summary_op = tf.summary.merge_all()

    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    logs_path = os.path.join(logs_dir, 'cluttered_mnist/')

    if not os.path.exists(vis_path):
        os.makedirs(vis_path)

    with tf.Session() as sess:
        if RESTORE:
            # restore checkpoint if it exists
            try:
                print("Trying to restore last checkpoint...")
                last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
                saver.restore(sess, save_path=last_chk_path)
                print("Restored checkpoint from: ", last_chk_path)
            except Exception:
                # Best effort: any restore failure falls back to a fresh
                # initialisation.  ``Exception`` (rather than a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                print("Failed to restore checkpoint. Initializing variables instead.")
                sess.run(tf.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        # for tensorboard viewing
        writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

        # for visualization purposes
        fig = plt.figure()

        if MODE == 'train':
            tic = time.time()
            print("Training on {} samples, validating on {} samples".format(len(X_train), len(X_valid)))

            _, batch_indices = generate_batch_indices(X_train)
            # repeat the per-epoch index list once per epoch
            batch_indices = batch_indices * num_epochs

            # NOTE(review): the nesting of the early-stopping check and of the
            # image dumps below was reconstructed from a whitespace-flattened
            # source; they are placed at loop level here -- confirm.
            for i in range(num_iterations):
                # grab the batch index from list
                idx = batch_indices[i]
                mask = np.arange(idx[0], idx[1])

                # slice into batches
                batch_X_train, batch_y_train = X_train[mask], y_train[mask]

                # create feed dict
                train_feed_dict = {X: batch_X_train, y: batch_y_train, phase: True}

                i_global, _ = sess.run([global_step, optimizer], feed_dict=train_feed_dict)

                if (i_global % display_step == 0) or (i == num_iterations - 1):
                    # calculate loss and accuracy on training batch
                    train_batch_loss, train_batch_acc, train_summary = sess.run(
                        [loss, accuracy, summary_op], feed_dict=train_feed_dict)
                    writer.add_summary(train_summary, i_global)

                    # calculate loss and accuracy on validation batch
                    valid_batch_loss, valid_batch_acc = validate_acc_loss(
                        sess, loss, accuracy, X_valid, y_valid)

                    # check to see if there's an improvement
                    improved_str = ''
                    if valid_batch_acc > best_validation_accuracy:
                        best_validation_accuracy = valid_batch_acc
                        last_improvement = i_global
                        saver.save(sess=sess,
                                   save_path=save_path+str(best_validation_accuracy),
                                   global_step=i_global)
                        improved_str = '*'

                    print("Iter: {}/{} - loss: {:.4f} - acc: {:.4f} - val_loss: {:.4f} - val_acc: {:.4f} - {}".format(
                        i_global, num_iterations, train_batch_loss, train_batch_acc,
                        valid_batch_loss, valid_batch_acc, improved_str))

                # if no improvement in a while, stop training
                if i_global - last_improvement > require_improvement:
                    print("No improvement found in a while, stopping optimization.")
                    break

                # dump the untransformed inputs once, as a reference image
                if i_global == 1:
                    print("Plotting input imgs...")
                    input_imgs = batch_X_train[:9]
                    input_imgs = np.reshape(input_imgs, [-1, 60, 60])
                    _save_grid(fig, input_imgs, 'epoch_0.png')

                # dump the transformer output for the current batch
                thetas = sess.run(h_trans, feed_dict={X: batch_X_train, phase: True})
                thetas = thetas[0:9].squeeze()
                _save_grid(fig, thetas, 'epoch_' + str(i_global) + '.png')

            toc = time.time()
            print("Time: {:.2f}s".format(toc-tic))
            print("Best valid acc: {}".format(best_validation_accuracy))
        else:
            test_accuracy = test_acc(sess, accuracy, X_test, y_test)
            print("Test Set Accuracy: {}".format(test_accuracy))

        # flush pending summaries and release the event file
        writer.close()
def spatial_transformer_layer(name_scope,
                              input_tensor,
                              img_size,
                              kernel_size,
                              pooling=None,
                              strides=[1, 1, 1, 1],
                              pool_strides=[1, 1, 1, 1],
                              activation=tf.nn.relu,
                              use_bn=False,
                              use_mvn=False,
                              is_training=False,
                              use_lrn=False,
                              keep_prob=1.0,
                              dropout_maps=False,
                              init_opt=0,
                              bias_init=0.1):
    """Define a spatial transformer network layer.

    A small localization network (one convolution, optional activation / LRN /
    dropout / max-pooling, then two fully connected layers of 64 and 6 units,
    each followed by ``tanh``) predicts 6 transformation parameters, which are
    passed to ``stn`` together with ``input_tensor``.

    Args:
        name_scope: `string` or `VariableScope`, the variable scope to open.
        input_tensor: input to the convolution and to ``stn``; the original
            documentation describes it as a `4-D Tensor` shaped
            `[batch_size, Y, X, Z]`.
        img_size: sequence ``[image_height, image_width]``; forwarded to
            ``stn`` as ``out_dims``.
        kernel_size: sequence of 4 ints, the shape of the convolution kernel;
            its last entry is also the size of the bias.
        pooling: ``ksize`` for ``tf.nn.max_pool``; pooling is skipped when
            falsy.
        strides: strides of the convolution.  (The list defaults are never
            mutated here.)
        pool_strides: strides of the max-pooling.
        activation: activation applied after the bias; skipped when falsy.
        use_bn: accepted but not used by this function.
        use_mvn: accepted but not used by this function.
        is_training: accepted but not used by this function.
        use_lrn: `bool`, whether to apply local response normalization.
        keep_prob: dropout keep probability.
        dropout_maps: `bool`, if true the dropout noise is shared over the
            two middle axes (noise shape ``[batch, 1, 1, channels]``),
            otherwise single elements are dropped.
        init_opt: selects the stddev of the kernel initializer:
            0 -> ``sqrt(2 / prod(kernel_size))``, 1 -> ``5e-2``,
            2 -> ``min(sqrt(2 / prod(kernel_size[:3])), 5e-2)``.
        bias_init: constant initial value of the bias.

    Returns:
        tuple: ``(stn_output, theta)`` -- the output of ``stn`` and the
        predicted parameters.

    Raises:
        ValueError: if ``init_opt`` is not 0, 1 or 2.
    """
    img_height = img_size[0]
    img_width = img_size[1]

    with tf.variable_scope(name_scope):
        # 2.0 (not 2) keeps the division a float division under Python 2,
        # where 2 / <int> would truncate to 0 and yield stddev == 0.
        if init_opt == 0:
            stddev = np.sqrt(2.0 / (kernel_size[0] * kernel_size[1] *
                                    kernel_size[2] * kernel_size[3]))
        elif init_opt == 1:
            stddev = 5e-2
        elif init_opt == 2:
            stddev = min(
                np.sqrt(2.0 / (kernel_size[0] * kernel_size[1] *
                               kernel_size[2])),
                5e-2)
        else:
            raise ValueError(
                'init_opt must be 0, 1 or 2, got {!r}'.format(init_opt))

        kernel = tf.get_variable(
            'weights', kernel_size,
            initializer=tf.random_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_tensor, kernel, strides, padding='SAME',
                            name='conv')
        bias = tf.get_variable(
            'bias', kernel_size[3],
            initializer=tf.constant_initializer(value=bias_init))
        output_tensor = tf.nn.bias_add(conv, bias, name='pre_activation')

        if activation:
            output_tensor = activation(output_tensor, name='activation')

        if use_lrn:
            output_tensor = tf.nn.local_response_normalization(
                output_tensor, name='local_responsive_normalization')

        if dropout_maps:
            # one dropout decision per (sample, channel)
            conv_shape = tf.shape(output_tensor)
            n_shape = tf.stack([conv_shape[0], 1, 1, conv_shape[3]])
            output_tensor = tf.nn.dropout(output_tensor, keep_prob=keep_prob,
                                          noise_shape=n_shape)
        else:
            output_tensor = tf.nn.dropout(output_tensor, keep_prob=keep_prob)

        if pooling:
            output_tensor = tf.nn.max_pool(output_tensor, ksize=pooling,
                                           strides=pool_strides,
                                           padding='VALID')

        output_tensor = tf.contrib.layers.flatten(output_tensor)

        # NOTE(review): no ``activation_fn`` is passed to fully_connected, so
        # whatever its default activation is gets applied before the tanh
        # below -- confirm this is intended for the transformer parameters.
        output_tensor = tf.contrib.layers.fully_connected(
            output_tensor, 64, scope='fully_connected_layer_1')
        output_tensor = tf.nn.tanh(output_tensor)

        output_tensor = tf.contrib.layers.fully_connected(
            output_tensor, 6, scope='fully_connected_layer_2')
        output_tensor = tf.nn.tanh(output_tensor)

        stn_output = stn(input_fmap=input_tensor, theta=output_tensor,
                         out_dims=(img_height, img_width))

        return stn_output, output_tensor
# theta = graph.get_tensor_by_name("network/stn_0/fully_connected_layer_2/weights:0")
input_tensor = graph.get_tensor_by_name("train_inputs:0")
idx = 0  # not used in the visible code; kept in case later code reads it

# Alternative definitions of theta that were tried:
# theta = tf.eye(3, batch_shape=[4])
# theta = tf.eye(num_rows=1, num_columns=9, batch_shape=[4])
# theta = tf.reshape(theta, ([4, -1]))
# print("Theta shape, ", theta.get_shape().as_list())

# Build the transformer op once.  Calling stn() inside the loop would add a
# fresh copy of the op to the graph on every iteration.
logits = stn(input_tensor, theta)

# Evaluate two validation batches; MSE holds the value of the last one.
for i in range(0, 2):
    batch_x, batch_y = data_provider.get_data('validation')
    train_feed_dict = {
        input_tensor: batch_x,
    }

    imgs = sess.run([logits], feed_dict=train_feed_dict)
    imgs = np.array(imgs)
    batch_y = np.array(batch_y)

    y_pred = imgs.flatten()
    y = batch_y.flatten()
    n = len(y)

    # Mean squared error over the n label entries.  Only the first n
    # predictions are compared, as in the element-wise loop this replaces;
    # float64 accumulation avoids float32 round-off in the sum.
    MSE = np.mean(np.square(y - y_pred[:n]), dtype=np.float64)
def main():
    """Train or evaluate the ConvNet and dump spatial-transformer outputs.

    Builds a localization network with its ``stn`` output and a classifier.
    When the module-level ``MODE`` is ``'train'`` it runs the training loop
    with periodic validation, checkpointing of the best model, early stopping
    and image dumps; otherwise it reports the accuracy on the test set.

    NOTE(review): the classifier here is fed ``X`` directly, not ``h_trans``,
    so the loss does not depend on the localization branch and ``h_trans`` is
    only used for the image dumps -- confirm this is intended.

    Relies on names defined outside this function: the placeholders ``X``,
    ``y`` and ``phase``; the settings ``root_dir``, ``save_dir``,
    ``logs_dir``, ``vis_path``, ``batch_size``, ``num_epochs``,
    ``num_classes``, ``learning_rate``, ``display_step``, ``VIEW``,
    ``SAMPLE``, ``RESTORE`` and ``MODE``.  Updates the globals
    ``best_validation_accuracy`` and ``last_improvement``.
    """
    global best_validation_accuracy
    global last_improvement
    global require_improvement

    def _save_grid(fig, images, filename):
        """Draw up to nine images on a 3x3 grid and save it under vis_path."""
        plt.clf()
        # a short batch may hold fewer than nine images
        for j in range(min(9, len(images))):
            plt.subplot(3, 3, j + 1)
            plt.imshow(images[j], cmap='gray')
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig(vis_path + filename, bbox_inches='tight')

    # load the data
    print("Loading the data...")
    X_train, y_train, X_test, y_test, X_valid, y_valid = load_data(root_dir)

    # let's view a small sample
    if VIEW:
        mask = np.arange(9)
        gd_truth = np.argmax(y_train[mask], axis=1)
        sample = X_train.squeeze()[mask]
        plot_images(sample, gd_truth)

    # optionally train on the first 500 samples only
    if SAMPLE:
        mask = np.arange(500)
        X_train = X_train[mask]
        y_train = y_train[mask]

    num_train = X_train.shape[0]

    print("Building ConvNet...")

    # localization network
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')
    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)
    fc1_loc = Dense(pool2_loc_flat, pool2_loc_size, 2048, use_relu=False, name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, fc3_loc)

    # convnet
    conv1 = Conv2D(X, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')
    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')
    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')
    pool3_flat, pool3_size = Flatten(pool3)
    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    # define cost function
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(cross_entropy)

    # define optimizer
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step)

    # define accuracy
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # define saver object for storing and retrieving checkpoints
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')  # path for the checkpoint file

    total_batch = int(np.ceil(num_train / float(batch_size)))
    num_iterations = num_epochs * total_batch

    # create summary for loss and acc
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    summary_op = tf.summary.merge_all()

    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    logs_path = os.path.join(logs_dir, 'cluttered_mnist/')

    if not os.path.exists(vis_path):
        os.makedirs(vis_path)

    with tf.Session() as sess:
        if RESTORE:
            # restore checkpoint if it exists
            try:
                print("Trying to restore last checkpoint ...")
                last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
                saver.restore(sess, save_path=last_chk_path)
                print("Restored checkpoint from:", last_chk_path)
            except Exception:
                # Best effort: any restore failure falls back to a fresh
                # initialisation.  ``Exception`` (rather than a bare except)
                # lets KeyboardInterrupt / SystemExit propagate.
                print("Failed to restore checkpoint. Initializing variables instead.")
                sess.run(tf.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        # for tensorboard viewing
        writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

        # for visualization purposes
        fig = plt.figure()

        if MODE == 'train':
            tic = time.time()
            print("Training on {} samples, validating on {} samples".format(len(X_train), len(X_valid)))

            _, batch_indices = generate_batch_indices(X_train)
            # repeat the per-epoch index list once per epoch
            batch_indices = batch_indices * num_epochs

            # NOTE(review): the nesting of the early-stopping check and of the
            # image dumps below was reconstructed from a whitespace-flattened
            # source; they are placed at loop level here -- confirm.
            for i in range(num_iterations):
                # grab the batch index from list
                idx = batch_indices[i]
                mask = np.arange(idx[0], idx[1])

                # slice into batches
                batch_X_train, batch_y_train = X_train[mask], y_train[mask]

                # create feed dict
                train_feed_dict = {X: batch_X_train, y: batch_y_train, phase: True}

                i_global, _ = sess.run([global_step, optimizer], feed_dict=train_feed_dict)

                if (i_global % display_step == 0) or (i == num_iterations - 1):
                    # calculate loss and accuracy on training batch
                    train_batch_loss, train_batch_acc, train_summary = sess.run(
                        [loss, accuracy, summary_op], feed_dict=train_feed_dict)
                    writer.add_summary(train_summary, i_global)

                    # calculate loss and accuracy on validation batch
                    valid_batch_loss, valid_batch_acc = validate_acc_loss(
                        sess, loss, accuracy, X_valid, y_valid)

                    # check to see if there's an improvement
                    improved_str = ''
                    if valid_batch_acc > best_validation_accuracy:
                        best_validation_accuracy = valid_batch_acc
                        last_improvement = i_global
                        saver.save(sess=sess,
                                   save_path=save_path+str(best_validation_accuracy),
                                   global_step=i_global)
                        improved_str = '*'

                    print("Iter: {}/{} - loss: {:.4f} - acc: {:.4f} - val_loss: {:.4f} - val_acc: {:.4f} - {}".format(
                        i_global, num_iterations, train_batch_loss, train_batch_acc,
                        valid_batch_loss, valid_batch_acc, improved_str))

                # if no improvement in a while, stop training
                if i_global - last_improvement > require_improvement:
                    print("No improvement found in a while, stopping optimization.")
                    break

                # dump the untransformed inputs once, as a reference image
                if i_global == 1:
                    print("Plotting input imgs...")
                    input_imgs = batch_X_train[:9]
                    input_imgs = np.reshape(input_imgs, [-1, 60, 60])
                    _save_grid(fig, input_imgs, 'epoch_0.png')

                # dump the transformer output for the current batch
                thetas = sess.run(h_trans, feed_dict={X: batch_X_train, phase: True})
                thetas = thetas[0:9].squeeze()
                _save_grid(fig, thetas, 'epoch_' + str(i_global) + '.png')

            toc = time.time()
            print("Time: {:.2f}s".format(toc-tic))
            print("Best valid acc: {}".format(best_validation_accuracy))
        else:
            test_accuracy = test_acc(sess, accuracy, X_test, y_test)
            print("Test Set Accuracy: {}".format(test_accuracy))

        # flush pending summaries and release the event file
        writer.close()