Example #1
0
    def scale(self, x_obj, w_scale, w_trans):
        """Scale and translate a volume using two successive 2-D ``stn`` passes.

        Pass 1 warps the (y, x) plane with scale ``s`` and offsets (tx, ty).
        The volume is then transposed, and pass 2 warps the z axis with scale
        ``s`` and offset tz, leaving the other in-plane axis unchanged.

        Args:
            x_obj: volume laid out as [batch, y, x, z, 1] (layout taken from the
                comments in this method -- TODO confirm with callers). The
                trailing singleton axis is removed here and restored on return.
            w_scale: per-sample scale. After ``expand_dims`` it is matmul'ed
                with a [1, 6] float32 constant, so it must be a rank-1 float32
                tensor of shape [batch].
            w_trans: per-sample offsets; columns 0, 1 and 2 are used as the x, y
                and z translations respectively.

        Returns:
            The transformed volume, with the trailing singleton axis re-added.
        """
        # Remove only the trailing singleton axis. A bare tf.squeeze() would also
        # drop the batch axis when batch == 1 (and any spatial axis of size 1).
        obj = tf.squeeze(x_obj, axis=-1)
        w_scale = tf.expand_dims(w_scale, -1)  # [batch, 1]
        w_trans_x = tf.expand_dims(w_trans[:, 0], -1)
        w_trans_y = tf.expand_dims(w_trans[:, 1], -1)
        w_trans_z = tf.expand_dims(w_trans[:, 2], -1)

        # NOTE: as in the reference image implementation, y and x are switched;
        # this implementation uses the same layout: [batch, y, x, z].

        # Pass 1: the 6-vector is a 2x3 matrix flattened row-major,
        # [[s, 0, tx], [0, s, ty]] (assuming stn's usual theta layout).
        s_xy = tf.matmul(w_scale,
                         tf.constant([[1, 0, 0, 0, 1, 0]], dtype=tf.float32))
        t_xy = tf.matmul(
            w_trans_x,
            tf.constant([[0, 0, 1, 0, 0, 0]], dtype=tf.float32)) + tf.matmul(
                w_trans_y, tf.constant([[0, 0, 0, 0, 0, 1]], dtype=tf.float32))
        T_xy = s_xy + t_xy
        transformed_xy = stn(obj, T_xy)

        # Transpose [batch, y, x, z] -> [batch, x, z, y] so that pass 2 warps the
        # (x, z) plane with y carried along as channels (assuming stn treats
        # axes 1 and 2 as the spatial plane).
        transposed_transformed_xy = tf.transpose(transformed_xy, [0, 2, 3, 1])

        # Pass 2: [[s, 0, tz], [0, 1, 0]] -- scale/translate z only.
        # tf.ones_like(w_scale) is [batch, 1] taken from the actual batch, so a
        # final batch smaller than self.batch_size no longer breaks the matmul.
        s_zx = tf.matmul(
            w_scale, tf.constant(
                [[1, 0, 0, 0, 0, 0]], dtype=tf.float32)) + tf.matmul(
                    tf.ones_like(w_scale),
                    tf.constant([[0, 0, 0, 0, 1, 0]], dtype=tf.float32))
        t_zx = tf.matmul(w_trans_z,
                         tf.constant([[0, 0, 1, 0, 0, 0]], dtype=tf.float32))
        T_zx = s_zx + t_zx
        transformed_zxy = stn(transposed_transformed_xy, T_zx)

        # Transpose back [batch, x, z, y] -> [batch, y, x, z].
        transformed_xyz = tf.transpose(transformed_zxy, [0, 3, 1, 2])

        # Restore the trailing singleton axis removed at the top.
        return tf.expand_dims(transformed_xyz, -1)
def build_convnet():
    """Build the localization net, spatial transformer and classifier graph.

    Uses the module-level input tensor ``X``, the batch-norm phase flag
    ``phase`` and ``num_classes``, all defined outside this function.

    Returns:
        ``(h_trans, logits)``: the spatial-transformer output and the
        classifier logits computed from it.
    """
    # localization network: predicts the 6 transformation parameters for stn
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')

    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)

    fc1_loc = Dense(pool2_loc_flat,
                    pool2_loc_size,
                    2048,
                    use_relu=False,
                    name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc,
                    512,
                    6,
                    use_relu=False,
                    trans=True,
                    name='fc3_loc')

    print('fc3_loc: {}'.format(fc3_loc.get_shape()))

    # spatial transformer
    h_trans = stn(X, fc3_loc)
    print('h_trans: {}'.format(h_trans.get_shape()))

    # convnet -- classifies the *transformed* input. Feeding X here would
    # discard the transformer output and leave the localization network
    # without any gradient from a loss on `logits`.
    conv1 = Conv2D(h_trans, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')

    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')

    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')

    pool3_flat, pool3_size = Flatten(pool3)

    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    return h_trans, logits
Example #3
0
def focal_loss(target_tensor, theta, org, weights=None, alpha=0.25, gamma=2):
    r"""Focal loss between a thresholded spatial-transformer output and a target.

    ``org`` is warped with ``theta`` via ``stn``. The result is binarised at
    0.5, one-hot encoded with depth 2 and passed through a sigmoid. It is then
    scored against the depth-2 one-hot encoding of ``target_tensor`` with

        FL = -alpha * (z - p)^gamma * log(p) - (1 - alpha) * p^gamma * log(1 - p)

    where p = sigmoid(prediction) and z = target. Positive entries (z > 0)
    only use the first term, negative entries only the second.

    Args:
        target_tensor: target values. They are cast to int32 and one-hot
            encoded with depth 2, so presumably a binary {0, 1} mask -- TODO
            confirm with callers.
        theta: transformation parameters forwarded to ``stn``.
        org: input warped by ``stn``.
        weights: accepted for API compatibility; not used.
        alpha: weight of the positive term.
        gamma: focusing exponent.

    Returns:
        Scalar tensor: the sum of the per-entry focal loss.

    NOTE(review): the ``> 0.5`` comparison and integer cast are not
    differentiable, so this loss provides no gradient w.r.t. ``theta`` or
    ``org``. Confirm this is intended before optimizing it.
    """
    def _as_float_one_hot(int_values):
        # depth-2 one-hot encoding, as float32
        return tf.dtypes.cast(tf.one_hot(int_values, depth=2), tf.float32)

    warped = stn(org, theta)
    prediction_tensor = _as_float_one_hot(tf.to_int32(warped > 0.5))
    target_tensor = _as_float_one_hot(tf.dtypes.cast(target_tensor, tf.int32))

    print("Target tensor shape", target_tensor.get_shape().as_list())
    print("Prediction tensor shape", prediction_tensor.get_shape().as_list())

    p = tf.nn.sigmoid(prediction_tensor)
    zeros = array_ops.zeros_like(p, dtype=p.dtype)
    is_positive = target_tensor > zeros

    # z = 1: coefficient (z - p) on the positive term; 0 elsewhere.
    pos_p_sub = array_ops.where(is_positive, target_tensor - p, zeros)
    # z = 0: coefficient p on the negative term; 0 where z = 1.
    neg_p_sub = array_ops.where(is_positive, zeros, p)

    # Clip before the log so neither term can produce log(0).
    log_p = tf.log(tf.clip_by_value(p, 1e-8, 1.0))
    log_one_minus_p = tf.log(tf.clip_by_value(1.0 - p, 1e-8, 1.0))

    per_entry_cross_ent = (-alpha * (pos_p_sub ** gamma) * log_p
                           - (1 - alpha) * (neg_p_sub ** gamma) * log_one_minus_p)
    return tf.reduce_sum(per_entry_cross_ent)
Example #4
0
def focal_loss_(labels, theta, org, gamma=2.0, alpha=4.0):
    """Softmax focal loss between the spatial-transformer output and ``labels``.

    ``org`` is warped with ``theta`` via ``stn`` and softmax-normalised over
    its last axis; with q = softmax(...) + eps the per-entry loss is

        fl = alpha * labels^2 * (1 - q)^gamma * (-log q)

    (``labels`` enters twice, once in the weight and once in the
    cross-entropy term, so this equals ``labels * ...`` for 0/1 labels).

    Args:
        labels: targets, converted to float32; must broadcast against the
            ``stn`` output.
        theta: transformation parameters forwarded to ``stn``.
        org: input warped by ``stn``.
        gamma: focusing exponent.
        alpha: overall scale of the loss.

    Returns:
        ``fl`` max-reduced over axis 1 (not a scalar).

    NOTE(review): the softmax runs over the last axis of the ``stn`` output.
    If that axis has size 1 (e.g. single-channel images), the softmax is
    identically 1 and the loss is always 0 -- confirm the channel layout.
    """
    logits = stn(org, theta)
    # The +0.5 shift (left over from an attempt to binarise at 0.5, see the
    # history of this function) adds the same constant along the softmax axis,
    # which softmax cancels mathematically; kept to avoid perturbing results.
    logits = tf.cast(logits + 0.5, tf.float32)

    epsilon = 1.e-9

    labels = tf.convert_to_tensor(labels, tf.float32)

    # `axis` replaces the deprecated `dim` keyword (same meaning).
    probs = tf.nn.softmax(logits, axis=-1)
    model_out = tf.add(probs, epsilon)  # avoids log(0)

    ce = tf.multiply(labels, -tf.log(model_out))
    weight = tf.multiply(labels, tf.pow(tf.subtract(1., model_out), gamma))
    fl = tf.multiply(alpha, tf.multiply(weight, ce))
    return tf.reduce_max(fl, axis=1)
def build_convnet():
    """Build the localization net, spatial transformer and classifier graph.

    Uses the module-level input tensor ``X``, the batch-norm phase flag
    ``phase`` and ``num_classes``, all defined outside this function.

    Returns:
        ``(h_trans, logits)``: the spatial-transformer output and the
        classifier logits computed from it.
    """
    # localization network: predicts the 6 transformation parameters for stn
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')

    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)

    fc1_loc = Dense(pool2_loc_flat, pool2_loc_size, 2048, use_relu=False, name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, fc3_loc)

    # convnet -- classifies the *transformed* input. Feeding X here would
    # discard the transformer output and leave the localization network
    # without any gradient from a loss on `logits`.
    conv1 = Conv2D(h_trans, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')

    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')

    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')

    pool3_flat, pool3_size = Flatten(pool3)

    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    return h_trans, logits
Example #6
0
def classification_loss(labels, theta, org):
    """Mean-squared error between the spatial-transformer output and ``labels``.

    Despite the name, this is a regression (MSE) loss. ``org`` is warped with
    ``theta`` via ``stn``, and both the result and ``labels`` are flattened to
    1-D before ``tf.losses.mean_squared_error`` is applied, so they must hold
    the same number of elements.

    Args:
        labels: target tensor with as many elements as the warped output.
        theta: transformation parameters forwarded to ``stn``.
        org: input warped by ``stn``.

    Returns:
        The MSE loss tensor returned by ``tf.losses.mean_squared_error``.
    """
    logits = stn(org, theta)

    flat_logits = tf.reshape(logits, [-1])
    flat_labels = tf.reshape(labels, [-1])

    return tf.losses.mean_squared_error(flat_labels, flat_logits)
Example #7
0
def main():
    """Load the data, build the STN + ConvNet graph, then train or evaluate.

    Relies on module-level configuration defined outside this function (e.g.
    root_dir, VIEW, SAMPLE, X, y, phase, H, W, num_classes, learning_rate,
    batch_size, num_epochs, display_step, save_dir, logs_dir, vis_path,
    RESTORE, MODE). Training updates the globals best_validation_accuracy
    and last_improvement, and reads require_improvement.
    """
    global best_validation_accuracy
    global last_improvement
    global require_improvement

    def _save_image_grid(fig, images, out_path):
        # Draw the first nine 2-D images as a 3x3 grayscale grid and save it.
        plt.clf()
        for j in range(9):
            plt.subplot(3, 3, j + 1)
            plt.imshow(images[j], cmap='gray')
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig(out_path, bbox_inches='tight')

    # load the data
    print("Loading the data...")
    X_train, y_train, X_test, y_test, X_valid, y_valid = load_data(root_dir)

    # let's view a small sample
    if VIEW:
        mask = np.arange(9)
        gd_truth = np.argmax(y_train[mask], axis=1)
        sample = X_train.squeeze()[mask]
        plot_images(sample, gd_truth)

    if SAMPLE:
        mask = np.arange(500)
        X_train = X_train[mask]
        y_train = y_train[mask]

    num_train = X_train.shape[0]

    print("Building ConvNet...")
    # localization network: predicts the 6 transformation parameters for stn
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')

    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)

    fc1_loc = Dense(pool2_loc_flat, pool2_loc_size, 2048, use_relu=False, name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, fc3_loc, out_dims=(H, W))

    # convnet operating on the transformed images
    conv1 = Conv2D(h_trans, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')

    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')

    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')

    pool3_flat, pool3_size = Flatten(pool3)

    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    # define cost function
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(cross_entropy)

    # define optimizer
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step)

    # define accuracy
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # define saver object for storing and retrieving checkpoints
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')  # path for the checkpoint file

    total_batch = int(np.ceil(num_train / float(batch_size)))
    num_iterations = num_epochs * total_batch

    # create summary for loss and acc
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    summary_op = tf.summary.merge_all()

    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    logs_path = os.path.join(logs_dir, 'cluttered_mnist/')

    if not os.path.exists(vis_path):
        os.makedirs(vis_path)

    with tf.Session() as sess:

        if RESTORE:
            # restore checkpoint if it exists
            try:
                print("Trying to restore last checkpoint...")
                last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
                saver.restore(sess, save_path=last_chk_path)
                print("Restored checkpoint from: ", last_chk_path)
            except Exception as err:
                # Best-effort restore (e.g. nothing saved in save_dir yet).
                # `Exception` rather than a bare except, so Ctrl-C and
                # SystemExit still propagate.
                print("Failed to restore checkpoint. Initializing variables instead.")
                print("Reason: {}".format(err))
                sess.run(tf.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        # for tensorboard viewing
        writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        try:
            # for visualization purposes
            fig = plt.figure()

            if MODE == 'train':

                tic = time.time()
                print("Training on {} samples, validating on {} samples".format(len(X_train), len(X_valid)))

                _, batch_indices = generate_batch_indices(X_train)
                # repeat the per-epoch batch index list once per epoch
                batch_indices = batch_indices * num_epochs

                for i in range(num_iterations):

                    # grab the batch index from list
                    idx = batch_indices[i]
                    mask = np.arange(idx[0], idx[1])

                    # slice into batches
                    batch_X_train, batch_y_train = X_train[mask], y_train[mask]

                    # create feed dict
                    train_feed_dict = {X: batch_X_train, y: batch_y_train, phase: True}

                    i_global, _ = sess.run([global_step, optimizer], feed_dict=train_feed_dict)

                    if (i_global % display_step == 0) or (i == num_iterations - 1):

                        # calculate loss and accuracy on training batch
                        train_batch_loss, train_batch_acc, train_summary = sess.run([loss, accuracy, summary_op], feed_dict=train_feed_dict)
                        writer.add_summary(train_summary, i_global)

                        # calculate loss and accuracy on validation batch
                        valid_batch_loss, valid_batch_acc = validate_acc_loss(sess, loss, accuracy, X_valid, y_valid)

                        # check to see if there's an improvement
                        improved_str = ''
                        if valid_batch_acc > best_validation_accuracy:
                            best_validation_accuracy = valid_batch_acc
                            last_improvement = i_global
                            saver.save(sess=sess, save_path=save_path + str(best_validation_accuracy), global_step=i_global)
                            improved_str = '*'

                        print("Iter: {}/{} - loss: {:.4f} - acc: {:.4f} - val_loss: {:.4f} - val_acc: {:.4f} - {}".format(i_global,
                                num_iterations, train_batch_loss, train_batch_acc, valid_batch_loss, valid_batch_acc, improved_str))

                    # if no improvement in a while, stop training
                    if i_global - last_improvement > require_improvement:
                        print("No improvement found in a while, stopping optimization.")
                        break

                    # for plotting: the raw inputs, once
                    if i_global == 1:
                        print("Plotting input imgs...")
                        input_imgs = np.reshape(batch_X_train[:9], [-1, 60, 60])
                        _save_image_grid(fig, input_imgs, os.path.join(vis_path, 'epoch_0.png'))

                    # plotting: the transformer output for this batch.
                    # NOTE(review): this runs h_trans and writes a PNG on every
                    # iteration, which is slow for long runs.
                    transformed_imgs = sess.run(h_trans, feed_dict={X: batch_X_train, phase: True})
                    transformed_imgs = transformed_imgs[0:9].squeeze()
                    _save_image_grid(fig, transformed_imgs,
                                     os.path.join(vis_path, 'epoch_' + str(i_global) + '.png'))

                toc = time.time()
                print("Time: {:.2f}s".format(toc - tic))
                print("Best valid acc: {}".format(best_validation_accuracy))

            else:
                test_accuracy = test_acc(sess, accuracy, X_test, y_test)
                print("Test Set Accuracy: {}".format(test_accuracy))
        finally:
            # flush pending summaries and release the event file
            writer.close()
Example #8
0
def spatial_transformer_layer(name_scope,
                              input_tensor,
                              img_size,
                              kernel_size,
                              pooling=None,
                              strides=(1, 1, 1, 1),
                              pool_strides=(1, 1, 1, 1),
                              activation=tf.nn.relu,
                              use_bn=False,
                              use_mvn=False,
                              is_training=False,
                              use_lrn=False,
                              keep_prob=1.0,
                              dropout_maps=False,
                              init_opt=0,
                              bias_init=0.1):
    """Build a spatial transformer: a small localization net followed by ``stn``.

    The localization net has these stages:
      - a conv layer ('SAME' padding) with bias and optional activation;
      - optional local response normalization;
      - dropout;
      - optional max-pooling ('VALID');
      - two fully connected layers (64 units, then 6), each followed by tanh.
    The 6 outputs are used as ``theta`` to warp ``input_tensor`` with ``stn``.

    Args:
        name_scope: `string` or `VariableScope` opened with ``tf.variable_scope``.
        input_tensor: `4-D Tensor`, assumed to be shaped `[batch_size, Y, X, Z]`.
        img_size: `(height, width)` of the transformer output; ``img_size[0]``
            is the height and ``img_size[1]`` the width.
        kernel_size: conv filter shape
            `[kernel_height, kernel_width, in_channels, out_channels]`.
        pooling: max-pool ``ksize`` (length 4), or None to skip pooling.
        strides: conv strides, length 4.
        pool_strides: max-pool strides, length 4 (only used with ``pooling``).
        activation: activation applied after the conv bias (default
            `tf.nn.relu`); pass a falsy value to skip it.
        use_bn, use_mvn, is_training: accepted for interface compatibility but
            currently ignored -- no batch / mean-variance normalization is applied.
        use_lrn: whether to apply local response normalization.
        keep_prob: dropout keep probability (1.0 keeps every element).
        dropout_maps: if True, whole feature maps are dropped (noise shape
            `[batch, 1, 1, channels]`); otherwise individual elements.
        init_opt: conv weight init stddev scheme:
            0 -> sqrt(2 / prod(kernel_size)),
            1 -> 0.05,
            2 -> min(sqrt(2 / (kh * kw * in_channels)), 0.05).
        bias_init: constant initial value of the conv bias.

    Returns:
        `(stn_output, theta)`, where ``theta`` is the `[batch, 6]` tanh output
        of the last fully connected layer.

    Raises:
        ValueError: if ``init_opt`` is not 0, 1 or 2.
    """
    img_height = img_size[0]
    img_width = img_size[1]

    with tf.variable_scope(name_scope):
        if init_opt == 0:
            # NOTE(review): unlike option 2, out_channels is included in the
            # denominator here -- confirm this is the intended fan.
            stddev = np.sqrt(2.0 / (kernel_size[0] * kernel_size[1] *
                                    kernel_size[2] * kernel_size[3]))
        elif init_opt == 1:
            stddev = 5e-2
        elif init_opt == 2:
            stddev = min(
                np.sqrt(2.0 /
                        (kernel_size[0] * kernel_size[1] * kernel_size[2])),
                5e-2)
        else:
            # Any other value would leave `stddev` unbound below.
            raise ValueError(
                'init_opt must be 0, 1 or 2, got {!r}'.format(init_opt))

        kernel = tf.get_variable(
            'weights',
            kernel_size,
            initializer=tf.random_normal_initializer(stddev=stddev))

        conv = tf.nn.conv2d(input_tensor,
                            kernel,
                            strides,
                            padding='SAME',
                            name='conv')

        bias = tf.get_variable(
            'bias',
            kernel_size[3],
            initializer=tf.constant_initializer(value=bias_init))

        output_tensor = tf.nn.bias_add(conv, bias, name='pre_activation')

        if activation:
            output_tensor = activation(output_tensor, name='activation')

        if use_lrn:
            output_tensor = tf.nn.local_response_normalization(
                output_tensor, name='local_responsive_normalization')

        if dropout_maps:
            # one keep/drop decision per (sample, channel): drops whole maps
            conv_shape = tf.shape(output_tensor)
            n_shape = tf.stack([conv_shape[0], 1, 1, conv_shape[3]])
            output_tensor = tf.nn.dropout(output_tensor,
                                          keep_prob=keep_prob,
                                          noise_shape=n_shape)
        else:
            output_tensor = tf.nn.dropout(output_tensor, keep_prob=keep_prob)

        if pooling:
            output_tensor = tf.nn.max_pool(output_tensor,
                                           ksize=pooling,
                                           strides=pool_strides,
                                           padding='VALID')

        output_tensor = tf.contrib.layers.flatten(output_tensor)

        # NOTE(review): this layer keeps fully_connected's default ReLU, so
        # tanh is applied on top of it.
        output_tensor = tf.contrib.layers.fully_connected(
            output_tensor, 64, scope='fully_connected_layer_1')
        output_tensor = tf.nn.tanh(output_tensor)

        # activation_fn=None: the default ReLU would clamp theta to be
        # non-negative (tanh would then keep it in [0, 1)), making negative
        # scales, shears and translations unreachable.
        output_tensor = tf.contrib.layers.fully_connected(
            output_tensor, 6, activation_fn=None,
            scope='fully_connected_layer_2')
        output_tensor = tf.nn.tanh(output_tensor)

        stn_output = stn(input_fmap=input_tensor,
                         theta=output_tensor,
                         out_dims=(img_height, img_width))

        return stn_output, output_tensor
Example #9
0
# NOTE(review): `theta` must already be defined before this point -- the line
# that loaded it is commented out below and nothing visible here defines it, so
# stn() raises NameError unless it is created earlier in the file.
# theta = graph.get_tensor_by_name("network/stn_0/fully_connected_layer_2/weights:0")
input_tensor = graph.get_tensor_by_name("train_inputs:0")

# Build the transformer op once. Calling stn() inside the loop added a fresh
# copy of the op to the graph on every iteration.
logits = stn(input_tensor, theta)

idx = 0
for i in range(0, 2):
    batch_x, batch_y = data_provider.get_data('validation')
    train_feed_dict = {
        input_tensor: batch_x,
    }

    # theta = tf.eye(3, batch_shape=[4])
    # theta = tf.eye(num_rows=1, num_columns=9, batch_shape=[4])
    # theta = tf.reshape(theta, ([4, -1]))
    # print("Theta shape, ", theta.get_shape().as_list())

    imgs = np.array(sess.run([logits], feed_dict=train_feed_dict))
    batch_y = np.array(batch_y)

    y_pred = imgs.flatten()
    # Named y_true (not y) so it does not overwrite a module-level `y`, which
    # the training code uses as its label tensor.
    y_true = batch_y.flatten()

    # Per-batch mean squared error over the first len(y_true) predictions,
    # vectorized instead of a Python loop that also shadowed `i`.
    n = len(y_true)
    MSE = np.mean((y_true - y_pred[:n]) ** 2)
def main():
    """Load the data, build the STN + ConvNet graph, then train or evaluate.

    Relies on module-level configuration defined outside this function (e.g.
    root_dir, VIEW, SAMPLE, X, y, phase, num_classes, learning_rate,
    batch_size, num_epochs, display_step, save_dir, logs_dir, vis_path,
    RESTORE, MODE). Training updates the globals best_validation_accuracy
    and last_improvement, and reads require_improvement.
    """
    global best_validation_accuracy
    global last_improvement
    global require_improvement

    def _save_image_grid(fig, images, out_path):
        # Draw the first nine 2-D images as a 3x3 grayscale grid and save it.
        plt.clf()
        for j in range(9):
            plt.subplot(3, 3, j + 1)
            plt.imshow(images[j], cmap='gray')
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig(out_path, bbox_inches='tight')

    # load the data
    print("Loading the data...")
    X_train, y_train, X_test, y_test, X_valid, y_valid = load_data(root_dir)

    # let's view a small sample
    if VIEW:
        mask = np.arange(9)
        gd_truth = np.argmax(y_train[mask], axis=1)
        sample = X_train.squeeze()[mask]
        plot_images(sample, gd_truth)

    if SAMPLE:
        mask = np.arange(500)
        X_train = X_train[mask]
        y_train = y_train[mask]

    num_train = X_train.shape[0]

    print("Building ConvNet...")
    # localization network: predicts the 6 transformation parameters for stn
    conv1_loc = Conv2D(X, 1, 5, 32, name='conv1_loc')
    pool1_loc = MaxPooling2D(conv1_loc, use_relu=True, name='pool1_loc')
    conv2_loc = Conv2D(pool1_loc, 32, 5, 64, name='conv2_loc')
    pool2_loc = MaxPooling2D(conv2_loc, use_relu=True, name='pool2_loc')

    pool2_loc_flat, pool2_loc_size = Flatten(pool2_loc)

    fc1_loc = Dense(pool2_loc_flat, pool2_loc_size, 2048, use_relu=False, name='fc1_loc')
    fc2_loc = Dense(fc1_loc, 2048, 512, use_relu=True, name='fc2_loc')
    fc3_loc = Dense(fc2_loc, 512, 6, use_relu=False, trans=True, name='fc3_loc')

    # spatial transformer
    h_trans = stn(X, fc3_loc)

    # convnet -- classifies the *transformed* input. Feeding X here would
    # discard the transformer output and leave the localization network
    # without any gradient from the loss.
    conv1 = Conv2D(h_trans, 1, 5, 32, name='conv1')
    bn1 = BatchNormalization(conv1, phase, name='bn1')
    pool1 = MaxPooling2D(bn1, use_relu=True, name='pool1')

    conv2 = Conv2D(pool1, 32, 5, 64, name='conv2')
    bn2 = BatchNormalization(conv2, phase, name='bn2')
    pool2 = MaxPooling2D(bn2, use_relu=True, name='pool2')

    conv3 = Conv2D(pool2, 64, 3, 128, name='conv3')
    bn3 = BatchNormalization(conv3, phase, name='bn3')
    pool3 = MaxPooling2D(bn3, use_relu=True, name='pool3')

    pool3_flat, pool3_size = Flatten(pool3)

    fc1 = Dense(pool3_flat, pool3_size, 2048, use_relu=False, name='fc1')
    bn4 = BatchNormalization(fc1, phase, use_relu=True, name='bn4')
    fc2 = Dense(bn4, 2048, 512, use_relu=False, name='fc2')
    bn5 = BatchNormalization(fc2, phase, use_relu=True, name='bn5')
    logits = Dense(bn5, 512, num_classes, name='fc3', use_relu=False)

    # define cost function
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(cross_entropy)

    # define optimizer
    global_step = tf.Variable(initial_value=0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step)

    # define accuracy
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # define saver object for storing and retrieving checkpoints
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, 'best_validation')  # path for the checkpoint file

    total_batch = int(np.ceil(num_train / float(batch_size)))
    num_iterations = num_epochs * total_batch

    # create summary for loss and acc
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    summary_op = tf.summary.merge_all()

    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    logs_path = os.path.join(logs_dir, 'cluttered_mnist/')

    if not os.path.exists(vis_path):
        os.makedirs(vis_path)

    with tf.Session() as sess:

        if RESTORE:
            # restore checkpoint if it exists
            try:
                print("Trying to restore last checkpoint ...")
                last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
                saver.restore(sess, save_path=last_chk_path)
                print("Restored checkpoint from:", last_chk_path)
            except Exception as err:
                # Best-effort restore (e.g. nothing saved in save_dir yet).
                # `Exception` rather than a bare except, so Ctrl-C and
                # SystemExit still propagate.
                print("Failed to restore checkpoint. Initializing variables instead.")
                print("Reason: {}".format(err))
                sess.run(tf.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        # for tensorboard viewing
        writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
        try:
            # for visualization purposes
            fig = plt.figure()

            if MODE == 'train':

                tic = time.time()
                print("Training on {} samples, validating on {} samples".format(len(X_train), len(X_valid)))

                _, batch_indices = generate_batch_indices(X_train)
                # repeat the per-epoch batch index list once per epoch
                batch_indices = batch_indices * num_epochs

                for i in range(num_iterations):

                    # grab the batch index from list
                    idx = batch_indices[i]
                    mask = np.arange(idx[0], idx[1])

                    # slice into batches
                    batch_X_train, batch_y_train = X_train[mask], y_train[mask]

                    # create feed dict
                    train_feed_dict = {X: batch_X_train, y: batch_y_train, phase: True}

                    i_global, _ = sess.run([global_step, optimizer], feed_dict=train_feed_dict)

                    if (i_global % display_step == 0) or (i == num_iterations - 1):

                        # calculate loss and accuracy on training batch
                        train_batch_loss, train_batch_acc, train_summary = sess.run([loss, accuracy, summary_op], feed_dict=train_feed_dict)
                        writer.add_summary(train_summary, i_global)

                        # calculate loss and accuracy on validation batch
                        valid_batch_loss, valid_batch_acc = validate_acc_loss(sess, loss, accuracy, X_valid, y_valid)

                        # check to see if there's an improvement
                        improved_str = ''
                        if valid_batch_acc > best_validation_accuracy:
                            best_validation_accuracy = valid_batch_acc
                            last_improvement = i_global
                            saver.save(sess=sess, save_path=save_path + str(best_validation_accuracy), global_step=i_global)
                            improved_str = '*'

                        print("Iter: {}/{} - loss: {:.4f} - acc: {:.4f} - val_loss: {:.4f} - val_acc: {:.4f} - {}".format(i_global,
                                num_iterations, train_batch_loss, train_batch_acc, valid_batch_loss, valid_batch_acc, improved_str))

                    # if no improvement in a while, stop training
                    if i_global - last_improvement > require_improvement:
                        print("No improvement found in a while, stopping optimization.")
                        break

                    # for plotting: the raw inputs, once
                    if i_global == 1:
                        print("Plotting input imgs...")
                        input_imgs = np.reshape(batch_X_train[:9], [-1, 60, 60])
                        _save_image_grid(fig, input_imgs, os.path.join(vis_path, 'epoch_0.png'))

                    # plotting: the transformer output for this batch.
                    # NOTE(review): this runs h_trans and writes a PNG on every
                    # iteration, which is slow for long runs.
                    transformed_imgs = sess.run(h_trans, feed_dict={X: batch_X_train, phase: True})
                    transformed_imgs = transformed_imgs[0:9].squeeze()
                    _save_image_grid(fig, transformed_imgs,
                                     os.path.join(vis_path, 'epoch_' + str(i_global) + '.png'))

                toc = time.time()
                print("Time: {:.2f}s".format(toc - tic))
                print("Best valid acc: {}".format(best_validation_accuracy))

            else:
                test_accuracy = test_acc(sess, accuracy, X_test, y_test)
                print("Test Set Accuracy: {}".format(test_accuracy))
        finally:
            # flush pending summaries and release the event file
            writer.close()