Example #1
0
 def block_forward(self,
                   inp,
                   weights,
                   ss_weights,
                   block,
                   reuse,
                   scope,
                   block_last_layer=False):
     """Run one two-layer ResNet block for the meta-train phase.

     The main path applies two conv layers. Each kernel is first combined
     with `ss_weights` by `self.process_ss_weights`, and each bias is read
     from `ss_weights`. The shortcut path applies the unprocessed
     '<block>_conv_res' kernel through `resnet_nob_conv_block`. The two
     paths are summed.

     Args:
       inp: input feature maps.
       weights: mapping of base ResNet weights, keyed by
         '<block>_conv1', '<block>_conv2' and '<block>_conv_res'.
       ss_weights: mapping of scaling weights; also holds the biases under
         '<block>_bias1' and '<block>_bias2'.
       block: name of the block being processed (used as a key prefix).
       reuse: whether to reuse the batch-norm weights.
       scope: label prefix identifying the layer being processed.
       block_last_layer: if True, apply a 2x2, stride-2 max pool with
         'SAME' padding after the residual sum.
     Return:
       The processed feature maps.
     """
     layer_scope = scope + block
     out = inp
     for layer in range(2):
         kernel_key = block + ('_conv%d' % (layer + 1))
         bias_key = block + ('_bias%d' % (layer + 1))
         scaled_kernel = self.process_ss_weights(weights, ss_weights,
                                                 kernel_key)
         out = resnet_conv_block(out, scaled_kernel, ss_weights[bias_key],
                                 reuse, layer_scope + str(layer))
     # The shortcut uses the base residual kernel, not a processed one.
     shortcut = resnet_nob_conv_block(inp, weights[block + '_conv_res'],
                                      reuse, layer_scope + 'res')
     out = out + shortcut
     if block_last_layer:
         # Window and strides are [1, 2, 2, 1]; presumably NHWC layout,
         # so this pools over the two middle axes -- TODO confirm.
         out = tf.nn.max_pool(out, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
     # keep_prob=1 keeps every unit, so this dropout leaves values unchanged.
     out = tf.nn.dropout(out, keep_prob=1)
     return out
Example #2
0
 def pretrain_block_forward(self, inp, weights, block, reuse, scope):
     """Run one three-layer ResNet block for the pre-train phase.

     The main path applies three conv layers whose kernels and biases are
     read directly from `weights` ('<block>_convN' / '<block>_biasN',
     N = 1..3). The shortcut path applies '<block>_conv_res' through
     `resnet_nob_conv_block`. The sum is max-pooled and passed through
     dropout.

     Args:
       inp: input feature maps.
       weights: mapping of ResNet weights keyed by '<block>_convN',
         '<block>_biasN' and '<block>_conv_res'.
       block: name of the block being processed (used as a key prefix).
       reuse: whether to reuse the batch-norm weights.
       scope: label prefix identifying the layer being processed.
     Return:
       The processed feature maps.
     """
     layer_scope = scope + block
     out = inp
     for layer in range(3):
         suffix = '%d' % (layer + 1)
         out = resnet_conv_block(out,
                                 weights[block + '_conv' + suffix],
                                 weights[block + '_bias' + suffix],
                                 reuse,
                                 layer_scope + str(layer))
     shortcut = resnet_nob_conv_block(inp, weights[block + '_conv_res'],
                                      reuse, layer_scope + 'res')
     out = out + shortcut
     # 2x2 window, stride 2, no padding.
     out = tf.nn.max_pool(out, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
     # The keep probability comes from the pretrain_dropout_keep flag.
     out = tf.nn.dropout(out, keep_prob=FLAGS.pretrain_dropout_keep)
     return out
Example #3
0
 def block_forward(self, inp, weights, ss_weights, block, reuse, scope):
     """Run one three-layer ResNet block for the meta-train phase.

     Each of the three conv kernels ('<block>_conv1'..'<block>_conv3') is
     combined with `ss_weights` by `self.process_ss_weights`. The matching
     biases are read from `ss_weights` ('<block>_bias1'..'<block>_bias3').
     The unprocessed '<block>_conv_res' kernel forms the shortcut. The
     residual sum is max-pooled and passed through dropout.

     Args:
       inp: input feature maps.
       weights: mapping of base ResNet weights.
       ss_weights: mapping of scaling weights, also holding the biases.
       block: name of the block being processed (used as a key prefix).
       reuse: whether to reuse the batch-norm weights.
       scope: label prefix identifying the layer being processed.
     Return:
       The processed feature maps.
     """
     prefix = scope + block
     feats = inp
     for position, tag in enumerate(('1', '2', '3')):
         kernel = self.process_ss_weights(weights, ss_weights,
                                          block + '_conv' + tag)
         feats = resnet_conv_block(feats, kernel,
                                   ss_weights[block + '_bias' + tag],
                                   reuse, prefix + str(position))
     feats = feats + resnet_nob_conv_block(inp, weights[block + '_conv_res'],
                                           reuse, prefix + 'res')
     # 2x2 window, stride 2, no padding.
     feats = tf.nn.max_pool(feats, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
     # keep_prob=1 keeps every unit, so this dropout leaves values unchanged.
     feats = tf.nn.dropout(feats, keep_prob=1)
     return feats
Example #4
0
 def pretrain_block_forward(self, inp, weights, block, reuse, scope):
     """Run one three-layer ResNet block for the pre-train phase.

     Kernels and biases come straight from `weights`
     ('<block>_conv1'..'<block>_conv3', '<block>_bias1'..'<block>_bias3').
     The '<block>_conv_res' kernel forms the shortcut. The residual sum is
     max-pooled, then dropout is applied with the keep probability from
     FLAGS.pretrain_dropout_keep.

     Args:
       inp: input feature maps.
       weights: mapping of ResNet weights.
       block: name of the block being processed (used as a key prefix).
       reuse: whether to reuse the batch-norm weights.
       scope: label prefix identifying the layer being processed.
     Return:
       The processed feature maps.
     """
     prefix = scope + block
     feats = inp
     for position, tag in enumerate(('1', '2', '3')):
         feats = resnet_conv_block(feats,
                                   weights[block + '_conv' + tag],
                                   weights[block + '_bias' + tag],
                                   reuse,
                                   prefix + str(position))
     feats = feats + resnet_nob_conv_block(inp, weights[block + '_conv_res'],
                                           reuse, prefix + 'res')
     # 2x2 window, stride 2, no padding.
     feats = tf.nn.max_pool(feats, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
     feats = tf.nn.dropout(feats, keep_prob=FLAGS.pretrain_dropout_keep)
     return feats