def residual_block_first(self, x, out_channel, strides, name='unit'):
    """Build the first residual unit of a group and accumulate its FLOPs.

    The input is batch-normalized and ReLU-activated once, and that single
    activation feeds both the shortcut and the residual branch.

    Args:
        x: input tensor.
        out_channel: number of output channels of the unit.
        strides: spatial stride of the unit (also used as the kernel size of
            the shortcut projection).
        name: variable scope of the unit.

    Returns:
        The sum of the residual branch and the shortcut branch.
    """
    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s' % scope.name)

        # Shared pre-activation; its cost is measured on the incoming tensor.
        self._flops += self._get_bn_flops(x) + self._get_relu_flops(x)
        act = utils._bn(x, self.is_train, self._global_step, name='bn_1')
        act = utils._relu(act, name='relu_1')
        in_channel = act.get_shape().as_list()[-1]

        # Shortcut: projection when the width changes; otherwise identity
        # (stride 1) or non-overlapping max-pooling (stride > 1).
        if in_channel != out_channel:
            # NOTE(review): the projection kernel size equals `strides`, while
            # other units in this file use a 1x1 projection. Kept as-is since it
            # determines the shape of the 'shortcut' variable.
            self._flops += self._get_conv_flops(act, strides, out_channel, strides)
            shortcut = utils._conv(act, strides, out_channel, strides, name='shortcut')
        elif strides == 1:
            shortcut = tf.identity(act)
        else:
            window = [1, strides, strides, 1]
            shortcut = tf.nn.max_pool(act, window, [1, strides, strides, 1], 'VALID')
            self._flops += self._get_data_size(act)

        # Residual: 3x3 conv (strided) -> bn -> relu -> 3x3 conv.
        self._flops += self._get_conv_flops(act, 3, out_channel, strides)
        residual = utils._conv(act, 3, out_channel, strides, name='conv_1')
        self._flops += self._get_bn_flops(residual) + self._get_relu_flops(residual)
        residual = utils._bn(residual, self.is_train, self._global_step, name='bn_2')
        residual = utils._relu(residual, name='relu_2')
        self._flops += self._get_conv_flops(residual, 3, out_channel, 1)
        residual = utils._conv(residual, 3, out_channel, 1, name='conv_2')

        # Merge; the addition is charged as the output's data size.
        self._flops += self._get_data_size(residual)
        merged = residual + shortcut
    return merged
def residual_block(self, x, name='unit'):
    """Build a stride-1, same-width residual unit and accumulate its FLOPs.

    The residual branch is two (bn -> relu -> 3x3 conv) stages; the shortcut is
    the unmodified input.

    Args:
        x: input tensor; its last static dimension sets the conv width.
        name: variable scope of the unit.

    Returns:
        The residual branch output plus the input tensor.
    """
    num_channel = x.get_shape().as_list()[-1]
    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s' % scope.name)
        shortcut = x

        # Two identical pre-activation conv stages: bn_k -> relu_k -> conv_k.
        for stage in (1, 2):
            x = utils._bn(x, self.is_train, self._global_step, name='bn_%d' % stage)
            x = utils._relu(x, name='relu_%d' % stage)
            x = utils._conv(x, 3, num_channel, 1, name='conv_%d' % stage)

        # Both stages are charged using the branch output tensor.
        conv_cost = self._get_conv_flops(x, 3, num_channel, 1)
        bn_cost = self._get_bn_flops(x)
        relu_cost = self._get_relu_flops(x)
        self._flops += 2 * conv_cost + 2 * bn_cost + 2 * relu_cost

        # Merge
        self._flops += self._get_data_size(x)
        x = x + shortcut
    return x
def _relu(self, x, name="relu"):
    """Return ``utils._relu(x, 0.0, name)``.

    The 0.0 second argument is presumably the leak slope (i.e. a plain ReLU);
    confirm against utils._relu.
    """
    activated = utils._relu(x, 0.0, name)
    return activated
def build_model(self):
    """Build the residual/inception classifier graph, accuracy and loss.

    Sets self._logits, self.probs, self.preds, self.acc and self.loss.
    """
    print('Building model')
    # Init. conv.
    print('\tBuilding unit: conv1')
    conv_1 = utils._conv(self._images, 3, 16, 1, name='conv_1')
    conv_1_bn = utils._bn(conv_1, self.is_train, self._global_step, name='conv_1_bn')
    conv1_relu = utils._relu(conv_1_bn, name='conv1_relu')

    # Residual Blocks
    # _residual_mod(input, filter_size, kernel_num, has_side_conv, is_stride,
    #               is_train, global_step, name=basename)
    with tf.variable_scope('conv2'):
        conv2_1 = utils._residual_mod(conv1_relu, 3, 64, True, False,
                                      self.is_train, self._global_step, name='conv2_1')
        conv2_2 = utils._residual_mod(conv2_1, 3, 64, False, False,
                                      self.is_train, self._global_step, name='conv2_2')
    with tf.variable_scope('conv3'):
        conv3 = utils._residual_mod(conv2_2, 3, 128, False, True,
                                    self.is_train, self._global_step, name='conv3')
    # _inception1(input, filter_size, kernel_num, is_train, global_step, name=basename)
    with tf.variable_scope('inception'):
        inception = utils._inception1(conv3, [1, 3, 3, 3, 1], [128, 64, 128, 64, 128],
                                      self.is_train, self._global_step, name='inception')
    with tf.variable_scope('conv4'):
        conv4 = utils._residual_mod(inception, 3, 256, False, True,
                                    self.is_train, self._global_step, name='conv4')
        conv4_ave_pool = utils._avg_pool(conv4, 'VALID', name='conv4_ave_pool')

    # Logit
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s' % scope.name)
        conv4_ave_pool_shape = conv4_ave_pool.get_shape().as_list()
        dim_conv4_ave_pool = (conv4_ave_pool_shape[1] * conv4_ave_pool_shape[2]
                              * conv4_ave_pool_shape[3])
        # -1 lets TensorFlow infer the batch dimension, so the graph also
        # works when the static batch size is unknown.
        x = tf.reshape(conv4_ave_pool, [-1, dim_conv4_ave_pool])
        x = utils._fc(x, self._hp.num_classes)
        # Function-call form prints the same on Python 2 and parses on Python 3.
        print(x.get_shape())

    self._logits = x
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    # 1.0 for each correct prediction, 0.0 otherwise. tf.cast works for any
    # batch size and TF version (tf.select was removed in TF 1.0).
    correct = tf.cast(tf.equal(self.preds, self._labels), tf.float32)
    self.acc = tf.reduce_mean(correct, name='acc')
    # Keyword arguments are mandatory for this op since TF 1.0.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss, name='cross_entropy')
def _relu(self, x, name="relu"):
    """Apply ``utils._relu(x, 0.0, name)`` and register its cost.

    Calls ``self._add_flops_weights`` with the key
    "<current variable scope name>/<name>", ``self._get_data_size`` of the
    output, and 0 (presumably the weight count).
    """
    out = utils._relu(x, 0.0, name)
    flops = self._get_data_size(out)
    self._add_flops_weights("/".join([tf.get_variable_scope().name, name]), flops, 0)
    return out
def build_model(self): print('Building model') # Init. conv. print('\tBuilding unit: init_conv') x = utils._conv(self._images, 3, 16, 1, name='init_conv') # Residual Blocks filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k] strides = [1, 2, 2] for i in range(1, 4): # First residual unit with tf.variable_scope('unit_%d_0' % i) as scope: print('\tBuilding residual unit: %s' % scope.name) x = utils._bn(x, self.is_train, self._global_step, name='bn_1') x = utils._relu(x, name='relu_1') # Shortcut if filters[i - 1] == filters[i]: if strides[i - 1] == 1: shortcut = tf.identity(x) else: shortcut = tf.nn.max_pool( x, [1, strides[i - 1], strides[i - 1], 1], [1, strides[i - 1], strides[i - 1], 1], 'VALID') else: shortcut = utils._conv(x, 1, filters[i], strides[i - 1], name='shortcut') # Residual x = utils._conv(x, 3, filters[i], strides[i - 1], name='conv_1') x = utils._bn(x, self.is_train, self._global_step, name='bn_2') x = utils._relu(x, name='relu_2') x = utils._conv(x, 3, filters[i], 1, name='conv_2') # Merge x = x + shortcut # Other residual units for j in range(1, self._hp.num_residual_units): with tf.variable_scope('unit_%d_%d' % (i, j)) as scope: print('\tBuilding residual unit: %s' % scope.name) # Shortcut shortcut = x # Residual x = utils._bn(x, self.is_train, self._global_step, name='bn_1') x = utils._relu(x, name='relu_1') x = utils._conv(x, 3, filters[i], 1, name='conv_1') x = utils._bn(x, self.is_train, self._global_step, name='bn_2') x = utils._relu(x, name='relu_2') x = utils._conv(x, 3, filters[i], 1, name='conv_2') # Merge x = x + shortcut # Last unit with tf.variable_scope('unit_last') as scope: print('\tBuilding unit: %s' % scope.name) x = utils._bn(x, self.is_train, self._global_step) x = utils._relu(x) x = tf.reduce_mean(x, [1, 2]) # Logit with tf.variable_scope('logits') as scope: print('\tBuilding unit: %s' % scope.name) x_shape = x.get_shape().as_list() x = tf.reshape(x, [-1, x_shape[1]]) x = utils._fc(x, self._hp.num_classes) 
self._logits = x # Probs & preds & acc self.probs = tf.nn.softmax(x, name='probs') self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds')) ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32) zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32) correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros) self.acc = tf.reduce_mean(correct, name='acc') tf.summary.scalar('accuracy', self.acc) # Loss & acc loss = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=x, labels=self._labels) self.loss = tf.reduce_mean(loss, name='cross_entropy') tf.summary.scalar('cross_entropy', self.loss)
def build_model(self):
    # Build a WRN whose later layers can be split into groups. Three levels of
    # split tensors are created under the "group" variable scope (each only
    # when its self._hp.ngroupsN > 1, otherwise set to None). They are passed
    # to the residual units of groups 2/3 and to the classifier. Stores
    # probs/preds/acc/loss on self and adds scalar summaries.
    print('Building model')
    filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
    strides = [1, 2, 2]

    with tf.variable_scope("group"):
        # Level 1: split over the classes (q1) and over filters[3] channels (p1).
        if self._hp.ngroups1 > 1:
            self.split_q1 = utils._get_split_q(self._hp.ngroups1, self._hp.num_classes, name='split_q1')
            self.split_p1 = utils._get_split_q(self._hp.ngroups1, filters[3], name='split_p1')
            tf.summary.histogram("group/split_p1/", self.split_p1)
            tf.summary.histogram("group/split_q1/", self.split_q1)
        else:
            self.split_q1 = None
            self.split_p1 = None

        # Level 2: q2 is merged from p1; p2 is over filters[2]; r21/r22 over filters[3].
        # NOTE(review): self.split_p1 is None when ngroups1 <= 1; presumably
        # configs guarantee ngroups1 > 1 whenever ngroups2 > 1 — confirm.
        if self._hp.ngroups2 > 1:
            self.split_q2 = utils._merge_split_q(
                self.split_p1,
                utils._get_even_merge_idxs(self._hp.ngroups1, self._hp.ngroups2),
                name='split_q2')
            self.split_p2 = utils._get_split_q(self._hp.ngroups2, filters[2], name='split_p2')
            self.split_r21 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r21')
            self.split_r22 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r22')
            tf.summary.histogram("group/split_q2/", self.split_q2)
            tf.summary.histogram("group/split_p2/", self.split_p2)
            tf.summary.histogram("group/split_r21/", self.split_r21)
            tf.summary.histogram("group/split_r22/", self.split_r22)
        else:
            self.split_p2 = None
            self.split_q2 = None
            self.split_r21 = None
            self.split_r22 = None

        # Level 3: q3 is merged from p2; p3 is over filters[1]; r31/r32 over filters[2].
        # NOTE(review): same assumption as above — ngroups2 > 1 whenever ngroups3 > 1.
        if self._hp.ngroups3 > 1:
            self.split_q3 = utils._merge_split_q(
                self.split_p2,
                utils._get_even_merge_idxs(self._hp.ngroups2, self._hp.ngroups3),
                name='split_q3')
            self.split_p3 = utils._get_split_q(self._hp.ngroups3, filters[1], name='split_p3')
            self.split_r31 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r31')
            self.split_r32 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r32')
            tf.summary.histogram("group/split_q3/", self.split_q3)
            tf.summary.histogram("group/split_p3/", self.split_p3)
            tf.summary.histogram("group/split_r31/", self.split_r31)
            tf.summary.histogram("group/split_r32/", self.split_r32)
        else:
            self.split_p3 = None
            self.split_q3 = None
            self.split_r31 = None
            self.split_r32 = None

    # Init. conv.
    print('\tBuilding unit: init_conv')
    x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')

    # Group 1 is never split.
    x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
    x = self._residual_block(x, name='unit_1_1')

    # Group 2 uses the level-3 split tensors.
    x = self._residual_block_first(x, filters[2], strides[1], input_q=self.split_p3, output_q=self.split_q3, split_r=self.split_r31, name='unit_2_0')
    x = self._residual_block(x, split_q=self.split_q3, split_r=self.split_r32, name='unit_2_1')

    # Group 3 uses the level-2 split tensors.
    x = self._residual_block_first(x, filters[3], strides[2], input_q=self.split_p2, output_q=self.split_q2, split_r=self.split_r21, name='unit_3_0')
    x = self._residual_block(x, split_q=self.split_q2, split_r=self.split_r22, name='unit_3_1')

    # Last unit
    with tf.variable_scope('unit_last') as scope:
        print('\tBuilding unit: %s' % scope.name)
        x = utils._bn(x, self.is_train, self._global_step)
        x = utils._relu(x)
        x = tf.reduce_mean(x, [1, 2])

    # Logit
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s' % scope.name)
        x_shape = x.get_shape().as_list()
        x = tf.reshape(x, [-1, x_shape[1]])
        # Dropout is applied only when the classifier is split (level 1).
        if self.split_p1 is not None and self.split_q1 is not None:
            x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout')
        x = self._fc(x, self._hp.num_classes, input_q=self.split_p1, output_q=self.split_q1)

    self._logits = x

    # Probs & preds & acc
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    # NOTE(review): ones/zeros are sized by self._hp.batch_size, so accuracy
    # assumes every batch has exactly that many examples.
    ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32)
    zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32)
    correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros)
    self.acc = tf.reduce_mean(correct, name='acc')
    tf.summary.scalar('accuracy', self.acc)

    # Loss & acc
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss)
    tf.summary.scalar('cross_entropy', self.loss)
def build_model(self):
    """Build the residual/inception classifier graph (64/256-channel variant).

    Sets self._logits, self.probs, self.preds, self.acc and self.loss.
    """
    print('Building model')
    # Init. conv.
    print('\tBuilding unit: conv1')
    conv1 = utils._conv(self._images, 3, 64, 1, name='conv1')
    conv1_bn = utils._bn(conv1, self.is_train, self._global_step, name='conv1_bn')
    conv1_relu = utils._relu(conv1_bn, name='conv1_relu')

    # Residual Blocks
    # _basic_conv(x, filter_size, out_channel, strides, is_train, global_step,
    #             pad='SAME', name='conv')
    with tf.variable_scope('conv2'):
        # Unit 1: two-conv branch plus a 1x1 side conv on the input.
        conv2_1a = utils._basic_conv(conv1_relu, 3, 64, 1, self.is_train,
                                     self._global_step, name='conv2_1a')
        conv2_1a_2_3x3 = utils._conv(conv2_1a, 3, 64, 1, name='conv2_1a_2_3x3')
        conv2_1b_1x1 = utils._conv(conv1_relu, 1, 64, 1, name='conv2_1b_1x1')
        conv2_res2_1 = tf.add(conv2_1a_2_3x3, conv2_1b_1x1, name='conv2_res2_1')
        # Unit 2: bn -> relu -> two convs, added to the unit-1 output.
        conv2_2a_1_bn = utils._bn(conv2_res2_1, self.is_train, self._global_step,
                                  name='conv2_2a_1_bn')
        conv2_2a_1_relu = utils._relu(conv2_2a_1_bn, name='conv2_2a_1_relu')
        conv2_2a_1_3x3 = utils._basic_conv(conv2_2a_1_relu, 3, 64, 1, self.is_train,
                                           self._global_step, name='conv2_2a_1_3x3')
        conv2_2a_2_3x3 = utils._conv(conv2_2a_1_3x3, 3, 64, 1, name='conv2_2a_2_3x3')
        conv2_res2_2 = tf.add(conv2_2a_2_3x3, conv2_res2_1, name='conv2_res2_2')
        conv2_res2_2_bn = utils._bn(conv2_res2_2, self.is_train, self._global_step,
                                    name='conv2_res2_2_bn')
        conv2_res2_2_relu = utils._relu(conv2_res2_2_bn, name='conv2_res2_2_relu')

    with tf.variable_scope('conv3'):
        # Stride-2 unit with a strided 1x1 side conv.
        conv3_1a_1_3x3 = utils._basic_conv(conv2_res2_2_relu, 3, 256, 2, self.is_train,
                                           self._global_step, name='conv3_1a_1_3x3')
        conv3_1a_2_3x3 = utils._conv(conv3_1a_1_3x3, 3, 256, 1, name='conv3_1a_2_3x3')
        conv3_1b_1x1 = utils._conv(conv2_res2_2_relu, 1, 256, 2, name='conv3_1b_1x1')
        conv3_res3_1 = tf.add(conv3_1a_2_3x3, conv3_1b_1x1, name='conv3_res3_1')
        conv3_res3_2a_bn = utils._bn(conv3_res3_1, self.is_train, self._global_step,
                                     name='conv3_res3_2a_bn')
        conv3_res3_2a_relu = utils._relu(conv3_res3_2a_bn, name='conv3_res3_2a_relu')

    # _inception1(input, filter_size, kernel_num, is_train, global_step, name=basename)
    with tf.variable_scope('inception'):
        inception = utils._inception1(conv3_res3_2a_relu, [1, 3, 3, 3, 1],
                                      [256, 128, 256, 128, 256], self.is_train,
                                      self._global_step, name='inception')
        # Residual connection around the inception module.
        inception_add = tf.add(inception, conv3_res3_1, name='inception_add')
        inception_bn = utils._bn(inception_add, self.is_train, self._global_step,
                                 name='inception_bn')
        inception_relu = utils._relu(inception_bn, name='inception_relu')

    with tf.variable_scope('conv4'):
        conv4_1a_1_3x3 = utils._basic_conv(inception_relu, 3, 256, 2, self.is_train,
                                           self._global_step, name='conv4_1a_1_3x3')
        conv4_1a_2_3x3 = utils._conv(conv4_1a_1_3x3, 3, 256, 1, name='conv4_1a_2_3x3')
        conv4_1b_1x1 = utils._conv(inception_relu, 1, 256, 2, name='conv4_1b_1x1')
        conv4_res4_1 = tf.add(conv4_1a_2_3x3, conv4_1b_1x1, name='conv4_res4_1')
        conv4_res4_2a_1_bn = utils._bn(conv4_res4_1, self.is_train, self._global_step,
                                       name='conv4_res4_2a_1_bn')
        conv4_res4_2a_1_relu = utils._relu(conv4_res4_2a_1_bn, name='conv4_res4_2a_1_relu')
        conv4_res4_2a_2_3x3 = utils._basic_conv(conv4_res4_2a_1_relu, 3, 256, 1,
                                                self.is_train, self._global_step,
                                                name='conv4_res4_2a_2_3x3')
        conv4_res4_2a_2_3x3_2 = utils._conv(conv4_res4_2a_2_3x3, 3, 256, 1,
                                            name='conv4_res4_2a_2_3x3_2')
        conv4_res4_2 = tf.add(conv4_res4_2a_2_3x3_2, conv4_res4_1, name='conv4_res4_2')
        conv4_res4_bn = utils._bn(conv4_res4_2, self.is_train, self._global_step,
                                  name='conv4_res4_bn')
        conv4_res4_relu = utils._relu(conv4_res4_bn, name='conv4_res4_relu')
        conv4_ave_pool = utils._avg_pool(conv4_res4_relu, 'VALID', name='conv4_ave_pool')

    # Logit
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s' % scope.name)
        conv4_ave_pool_shape = conv4_ave_pool.get_shape().as_list()
        dim_conv4_ave_pool = (conv4_ave_pool_shape[1] * conv4_ave_pool_shape[2]
                              * conv4_ave_pool_shape[3])
        # -1 lets TensorFlow infer the batch dimension, so the graph also
        # works when the static batch size is unknown.
        x = tf.reshape(conv4_ave_pool, [-1, dim_conv4_ave_pool])
        x = utils._fc(x, self._hp.num_classes)
        # Function-call form prints the same on Python 2 and parses on Python 3.
        print(x.get_shape())

    self._logits = x
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    # 1.0 for each correct prediction, 0.0 otherwise; independent of batch size.
    correct = tf.cast(tf.equal(self.preds, self._labels), tf.float32)
    self.acc = tf.reduce_mean(correct, name='acc')
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss, name='cross_entropy')
def _relu(self, x, name="relu"):
    """Apply ``utils._relu(x, 0.0, name)`` and print the result's static shape."""
    out = utils._relu(x, 0.0, name)
    shape = out.get_shape().as_list()
    print('%s: %s' % (name, str(shape)))
    return out
def build_model(self):
    """Build the split wide-residual network, its logits, accuracy and loss.

    Groups 1 and 2 are plain residual units. The third group is built with
    split blocks: units before the midpoint use self._split1 and the rest use
    self._split2. The classifier is split according to self._split3. Sets
    self._logits, self.probs, self.preds, self.acc and self.loss, and adds
    scalar summaries for accuracy and cross-entropy.
    """
    print('Building model')
    # Init. conv.
    print('\tBuilding unit: init_conv')
    x = utils._conv(self._images, 3, 16, 1, name='init_conv')

    # Residual Blocks
    filters = [16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
    strides = [1, 2, 2]
    num_units = self._hp.num_residual_units

    # Split layers' widths are scaled by sqrt(2 / (1/|split1| + 1/|split2|)).
    split_mul = np.sqrt(2/(1.0/len(self._split1)+1.0/len(self._split2)))
    print('Multiply split layers\' channels by %f' % split_mul)
    filter2_split1 = self._split_channels(filters[1], self._split1)
    filter3_split1 = self._split_channels(int(split_mul*filters[2]), self._split1)
    filter3_split2 = self._split_channels(int(split_mul*filters[2]), self._split2)
    filter3_split3 = self._split_channels(int(split_mul*filters[2]), self._split3)

    x = self.residual_block_first(x, filters[0], strides[0], 'unit_1_0')
    for j in range(1, num_units):
        x = self.residual_block(x, 'unit_1_%d' % j)
    x = self.residual_block_first(x, filters[1], strides[1], 'unit_2_0')
    for j in range(1, num_units):
        x = self.residual_block(x, 'unit_2_%d' % j)

    # Split the first half of the 3rd residual group into _split1 and the
    # second half into _split2. Floor division keeps the Python 2 split point;
    # `/` is true division on Python 3 and moves it when num_units is odd.
    half = num_units // 2
    x = self.residual_block_first_split(x, filter2_split1, filter3_split1, strides[2], 'unit_3_0')
    for j in range(1, num_units):
        unit_filters = filter3_split1 if j < half else filter3_split2
        x = self.residual_block_split(x, unit_filters, 'unit_3_%d' % j)

    # Last unit
    with tf.variable_scope('unit_last') as scope:
        print('\tBuilding unit: %s' % scope.name)
        # Charge bn, relu and the spatial mean on the tensor they actually
        # process (before pooling). Measuring after reduce_mean undercounted
        # the cost by the spatial size.
        self._flops += self._get_bn_flops(x) + self._get_relu_flops(x) + self._get_data_size(x)
        x = utils._bn(x, self.is_train, self._global_step)
        x = utils._relu(x)
        x = tf.reduce_mean(x, [1, 2])

    # Logit
    # Split the last fc layer into _split3
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s' % scope.name)
        x_shape = x.get_shape().as_list()
        x = tf.reshape(x, [-1, x_shape[1]])
        x = self.fc_split(x, filter3_split3, self._split3)
        if not self._hp.no_logit_map:
            # Reorder the logit columns by self._logit_map.
            x = tf.transpose(tf.gather(tf.transpose(x), self._logit_map))

    self._logits = x

    # Probs & preds & acc
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    # tf.select / tf.scalar_summary were removed in TF 1.0; the cast below also
    # works for any batch size.
    correct = tf.cast(tf.equal(self.preds, self._labels), tf.float32)
    self.acc = tf.reduce_mean(correct, name='acc')
    tf.summary.scalar('accuracy', self.acc)

    # Loss & acc
    # Keyword arguments are mandatory for this op since TF 1.0.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss, name='cross_entropy')
    tf.summary.scalar('cross_entropy', self.loss)
def build_model(self):
    """Build a WRN whose split layers are initialized from pre-computed params.

    unit_2_x is built from self._hp.split_params when ngroups3 != 1, and
    unit_3_x when ngroups2 != 1; otherwise the plain residual blocks are used.
    The classifier is always built from split_params['logits']: one fc layer
    per split, whose outputs are concatenated and permuted back to class
    order. Sets self._logits, self.probs, self.preds, self.acc and self.loss,
    and adds scalar summaries.
    """
    print('Building model')
    filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
    strides = [1, 2, 2]

    # Init. conv.
    print('\tBuilding unit: init_conv')
    x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')

    # unit_1_x
    x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
    x = self._residual_block(x, name='unit_1_1')

    # unit_2_x
    if self._hp.ngroups3 == 1:
        x = self._residual_block_first(x, filters[2], strides[1], name='unit_2_0')
        x = self._residual_block(x, name='unit_2_1')
    else:
        x = self._build_split_unit_first(x, filters[2], strides[1], 'unit_2_0')
        x = self._build_split_unit(x, filters[2], 'unit_2_1')

    # unit_3_x
    if self._hp.ngroups2 == 1:
        x = self._residual_block_first(x, filters[3], strides[2], name='unit_3_0')
        x = self._residual_block(x, name='unit_3_1')
    else:
        x = self._build_split_unit_first(x, filters[3], strides[2], 'unit_3_0')
        x = self._build_split_unit(x, filters[3], 'unit_3_1')

    # Last unit
    with tf.variable_scope('unit_last') as scope:
        print('\tBuilding unit: %s' % scope.name)
        x = utils._bn(x, self.is_train, self._global_step)
        x = utils._relu(x)
        x = tf.reduce_mean(x, [1, 2])

    # Logit
    logits_params = self._hp.split_params['logits']
    logits_weights = logits_params['weights']
    logits_biases = logits_params['biases']
    logits_input_perms = logits_params['input_perms']
    logits_output_perms = logits_params['output_perms']
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s - %d split' % (scope.name, len(logits_weights)))
        x_list = []
        for i, (w, b, p) in enumerate(
                zip(logits_weights, logits_biases, logits_input_perms)):
            # Unpacking also checks that w is 2-D; only the output width is used.
            _, out_dim = w.shape
            # Select this split's input feature columns (indices p).
            x_split = tf.transpose(tf.gather(tf.transpose(x), p))
            x_split = self._fc_with_init(x_split, out_dim, init_w=w, init_b=b,
                                         name='split%d' % (i + 1))
            x_list.append(x_split)
        x = tf.concat(x_list, 1)
        # Column k of the concatenation is class output_forward_idx[k]; gather
        # with the inverse mapping to restore class order 0..num_classes-1.
        output_forward_idx = list(np.concatenate(logits_output_perms))
        output_inverse_idx = [
            output_forward_idx.index(cls) for cls in range(self._hp.num_classes)
        ]
        x = tf.transpose(tf.gather(tf.transpose(x), output_inverse_idx))

    self._logits = x

    # Probs & preds & acc
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    # 1.0 for each correct prediction, 0.0 otherwise; independent of batch size.
    correct = tf.cast(tf.equal(self.preds, self._labels), tf.float32)
    self.acc = tf.reduce_mean(correct, name='acc')
    tf.summary.scalar('accuracy', self.acc)

    # Loss & acc
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss)
    tf.summary.scalar('cross_entropy', self.loss)


def _build_split_unit_first(self, x, out_channel, strides, name):
    """Build a (possibly strided) split residual unit from split_params[name].

    Reads the 'shortcut', 'conv1', 'conv2', 'p_perms', 'q_perms' and
    'r_perms' entries. Inside variable scope `name` it builds:
    relu_2(bn_2(conv_2(relu_1(bn_1(conv_1(x))))) + shortcut(x)).
    The shortcut uses (p, q) perms, conv_1 uses (p, r) and conv_2 uses (r, q).
    """
    params = self._hp.split_params[name]
    shortcut_kernel = params['shortcut']
    conv1_kernel = params['conv1']
    conv2_kernel = params['conv2']
    p_perms = params['p_perms']
    q_perms = params['q_perms']
    r_perms = params['r_perms']
    with tf.variable_scope(name):
        shortcut = self._conv_split(x, out_channel, strides, shortcut_kernel,
                                    p_perms, q_perms, name='shortcut')
        x = self._conv_split(x, out_channel, strides, conv1_kernel,
                             p_perms, r_perms, name='conv_1')
        x = self._bn(x, name='bn_1')
        x = self._relu(x, name='relu_1')
        x = self._conv_split(x, out_channel, 1, conv2_kernel,
                             r_perms, q_perms, name='conv_2')
        x = self._bn(x, name='bn_2')
        x = x + shortcut
        x = self._relu(x, name='relu_2')
    return x


def _build_split_unit(self, x, out_channel, name):
    """Build a stride-1 split residual unit with an identity shortcut.

    Reads the 'conv1', 'conv2', 'p_perms' and 'r_perms' entries of
    split_params[name]. Inside variable scope `name` it builds:
    relu_2(bn_2(conv_2(relu_1(bn_1(conv_1(x))))) + x).
    conv_1 uses (p, r) perms and conv_2 uses (r, p).
    """
    params = self._hp.split_params[name]
    conv1_kernel = params['conv1']
    conv2_kernel = params['conv2']
    p_perms = params['p_perms']
    r_perms = params['r_perms']
    with tf.variable_scope(name):
        shortcut = x
        x = self._conv_split(x, out_channel, 1, conv1_kernel,
                             p_perms, r_perms, name='conv_1')
        x = self._bn(x, name='bn_1')
        x = self._relu(x, name='relu_1')
        x = self._conv_split(x, out_channel, 1, conv2_kernel,
                             r_perms, p_perms, name='conv_2')
        x = self._bn(x, name='bn_2')
        x = x + shortcut
        x = self._relu(x, name='relu_2')
    return x