def _build_graph(self):
    from tf_utils.layers import conv2d, max_pool, rescale_bilinear

    def layer_width(layer: int):  # number of channels (features per pixel)
        return min(4 * 4**(layer + 1), 64)

    input_shape = [None] + list(self.input_shape)
    output_shape = input_shape[:3] + [self.class_count]

    # Input image and labels placeholders
    input = tf.placeholder(tf.float32, shape=input_shape)
    target = tf.placeholder(tf.float32, shape=output_shape)

    # Downsampled input (to improve speed at the cost of accuracy)
    h = rescale_bilinear(input, 0.5)

    # Hidden layers
    h = conv2d(h, 3, layer_width(0))
    h = tf.nn.relu(h)
    for l in range(1, self.conv_layer_count):
        h = max_pool(h, 2)
        h = conv2d(h, 3, layer_width(l))
        h = tf.nn.relu(h)

    # Pixelwise softmax classification and label upscaling
    logits = conv2d(h, 1, self.class_count)
    probs = tf.nn.softmax(logits)
    probs = tf.image.resize_bilinear(probs, output_shape[1:3])

    # Loss (cross-entropy; channel 0 is optionally ignored as "unknown")
    clipped_probs = tf.clip_by_value(probs, 1e-10, 1.0)
    ts = lambda x: x[:, :, :, 1:] if self.class0_unknown else x
    cost = -tf.reduce_mean(ts(target) * tf.log(ts(clipped_probs)))

    # Optimization
    optimizer = tf.train.AdamOptimizer(self.learning_rate)
    training_step = optimizer.minimize(cost)

    # Dense predictions and labels
    preds, dense_labels = tf.argmax(probs, 3), tf.argmax(target, 3)

    # Other evaluation measures
    self._n_accuracy = tf.reduce_mean(
        tf.cast(tf.equal(preds, dense_labels), tf.float32))

    return AbstractModel.EssentialNodes(
        input=input,
        target=target,
        probs=probs,
        loss=cost,
        training_step=training_step)
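
# Illustrative sketch (not part of the model above): how the pixelwise
# softmax, bilinear probability upscaling and cross-entropy in _build_graph
# fit together as a standalone function. The shapes and the `class0_unknown`
# masking mirror the code above; `logits` stands in for the 1x1-convolution
# output, and the function name is hypothetical.
import tensorflow as tf


def pixelwise_cross_entropy(logits, target, class0_unknown=False):
    """Cross-entropy between upscaled per-pixel probabilities and one-hot labels.

    logits: [N, h, w, C] low-resolution class scores.
    target: [N, H, W, C] full-resolution one-hot labels.
    """
    probs = tf.nn.softmax(logits)
    # Upscale probabilities to the label resolution.
    probs = tf.image.resize_bilinear(probs, tf.shape(target)[1:3])
    # Clip to avoid log(0) in the cross-entropy.
    probs = tf.clip_by_value(probs, 1e-10, 1.0)
    if class0_unknown:  # channel 0 marks "unknown" pixels; exclude it
        target, probs = target[..., 1:], probs[..., 1:]
    return -tf.reduce_mean(target * tf.log(probs))
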
def _build_graph(self, learning_rate, epoch, is_training):
    # layers_exp and regularization are assumed to be importable sibling
    # modules; only `from layers import conv` appeared in the original.
    import layers_exp
    import regularization
    from layers import conv

    # Input image and labels placeholders
    input_shape = [None] + list(self.input_shape)
    output_shape = [None, self.class_count]
    input = tf.placeholder(tf.float32, shape=input_shape)
    target = tf.placeholder(tf.float32, shape=output_shape)

    # Hidden layers
    h = layers_exp.rbf_resnet(
        input,
        is_training=is_training,
        base_width=self.base_width,
        widening_factor=self.widening_factor,
        group_lengths=self.group_lengths)

    # Global pooling and softmax classification
    h = tf.reduce_mean(h, axis=[1, 2], keep_dims=True)
    logits = conv(h, 1, self.class_count)
    logits = tf.reshape(logits, [-1, self.class_count])
    probs = tf.nn.softmax(logits)

    # Loss
    clipped_probs = tf.clip_by_value(probs, 1e-10, 1.0)
    loss = -tf.reduce_mean(target * tf.log(clipped_probs))

    # Regularization (L2 penalty on all weight variables)
    w_vars = [v for v in tf.global_variables() if 'weights' in v.name]
    loss += self.weight_decay * regularization.l2_regularization(w_vars)

    # Optimization
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    training_step = optimizer.minimize(loss)

    # Dense predictions and labels
    preds, dense_labels = tf.argmax(probs, 1), tf.argmax(target, 1)

    # Other evaluation measures
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(preds, dense_labels), tf.float32))

    # writer = tf.summary.FileWriter('logs', self._sess.graph)

    return AbstractModel.EssentialNodes(
        input=input,
        target=target,
        probs=probs,
        loss=loss,
        training_step=training_step,
        evaluation={'accuracy': accuracy})
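
# `regularization.l2_regularization` is used above but not shown. A minimal
# sketch of a compatible implementation, assuming it simply sums the squared
# L2 norms of the given variables (the name and signature are taken from the
# call site; the body is an assumption):
import tensorflow as tf


def l2_regularization(variables):
    # tf.nn.l2_loss(v) computes sum(v**2) / 2 for each variable.
    return tf.add_n([tf.nn.l2_loss(v) for v in variables])
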
def _build_graph(self, learning_rate, epoch, is_training):
    from tensorflow.contrib import layers
    from tf_utils.losses import multiclass_hinge_loss

    def get_ortho_penalty():
        # Penalizes deviation of each conv kernel (reshaped to a matrix M
        # with one column per output channel) from scaled orthogonality:
        # the sum of squared entries of M^T M - I/c over all conv kernels.
        all_vars = tf.contrib.framework.get_variables('')
        filt = lambda x: 'conv' in x.name and 'weights' in x.name
        weight_vars = list(filter(filt, all_vars))
        loss = tf.constant(0.0)
        for v in weight_vars:
            m = tf.reshape(v, (-1, v.shape[3].value))
            d = tf.matmul(m, m, transpose_a=True) - \
                tf.eye(v.shape[3].value) / v.shape[3].value
            loss += tf.reduce_sum(d**2)
        return loss

    input_shape = [None] + list(self.input_shape)
    output_shape = [None, self.class_count]

    # Input image and labels placeholders
    input = tf.placeholder(tf.float32, shape=input_shape, name='input')
    target = tf.placeholder(tf.float32, shape=output_shape, name='target')

    # L2 regularization
    weight_decay = tf.constant(self.weight_decay, dtype=tf.float32)

    # Hidden layers
    h = input
    with tf.contrib.framework.arg_scope(
            [layers.conv2d],
            kernel_size=5,
            data_format='NHWC',
            padding='SAME',
            activation_fn=tf.nn.relu,
            weights_initializer=layers.variance_scaling_initializer(),
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        h = layers.conv2d(h, 16, scope='convrelu1')
        h = layers.max_pool2d(h, 2, 2, scope='pool1')
        h = layers.conv2d(h, 32, scope='convrelu2')
        h = layers.max_pool2d(h, 2, 2, scope='pool2')
    with tf.contrib.framework.arg_scope(
            [layers.fully_connected],
            activation_fn=tf.nn.relu,
            weights_initializer=layers.variance_scaling_initializer(),
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        h = layers.flatten(h, scope='flatten3')
        h = layers.fully_connected(h, 512, scope='fc3')

    self._print_vars()

    # Softmax classification
    logits = layers.fully_connected(
        h, self.class_count, activation_fn=None, scope='logits')
    probs = tf.nn.softmax(logits, name='probs')

    # Loss (scaled multiclass hinge loss or softmax cross-entropy)
    mhl = lambda t, lo: 0.1 * multiclass_hinge_loss(t, lo)
    sce = tf.losses.softmax_cross_entropy
    loss = (mhl if self.use_multiclass_hinge_loss else sce)(target, logits)
    loss += tf.reduce_sum(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    if self.ortho_penalty > 0:
        loss += self.ortho_penalty * get_ortho_penalty()

    # Optimization
    optimizer = tf.train.AdamOptimizer(learning_rate)
    training_step = optimizer.minimize(loss)

    # Dense predictions and labels
    preds, dense_labels = tf.argmax(probs, 1), tf.argmax(target, 1)

    # Other evaluation measures
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(preds, dense_labels), tf.float32))

    return AbstractModel.EssentialNodes(
        input=input,
        target=target,
        probs=probs,
        loss=loss,
        training_step=training_step,
        evaluation={'accuracy': accuracy})
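
# `multiclass_hinge_loss` comes from tf_utils.losses and is not shown. A
# minimal sketch of one common (Crammer-Singer style) formulation, assuming
# one-hot float targets; illustrative only, not necessarily the library's
# exact definition:
import tensorflow as tf


def multiclass_hinge_loss(target, logits, margin=1.0):
    # Score of the correct class for each example: shape [N].
    correct = tf.reduce_sum(target * logits, axis=1)
    # Best score among the wrong classes (mask out the correct one).
    wrong = tf.reduce_max(logits - target * 1e9, axis=1)
    # Hinge on the margin between the correct and best wrong score.
    return tf.reduce_mean(tf.nn.relu(margin - correct + wrong))
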