Exemple #1
0
    def __init__(self, ckpt_path):
        """Build an Inception-v4 feature extractor and load its weights.

        Creates the image placeholder, constructs Inception-v4 in inference
        mode, restores every variable except those under
        ``InceptionV4/Logits`` from ``<ckpt_path>/inception_v4.ckpt`` and
        opens a session that is kept on ``self.sess``.

        Args:
            ckpt_path: Directory that contains ``inception_v4.ckpt``.
        """
        self.batch_size = 10
        self.image_size = 299
        number_classes = 1001

        self.input_batch = tf.placeholder(
            tf.float32, shape=[None, self.image_size, self.image_size, 3])

        # Only the end points are consumed below; the logits tensor itself is
        # not needed by this class.
        with slim.arg_scope(inception_utils.inception_arg_scope()):
            _, end_points = inception_v4.inception_v4(
                self.input_batch,
                num_classes=number_classes,
                is_training=False)

        # The logits layer is excluded from the restore and no initializer is
        # run for it, so only tensors upstream of it (such as 'global_pool')
        # are safe to evaluate in this session.
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['InceptionV4/Logits'])
        init_fn = slim.assign_from_checkpoint_fn(
            os.path.join(ckpt_path, 'inception_v4.ckpt'),
            variables_to_restore,
            ignore_missing_vars=True)

        self.feature_global = end_points['global_pool']

        self.sess = tf.Session()
        init_fn(self.sess)

        self.dist = DistanceBuilder()
Exemple #2
0
 def __call__(self, x_input):
     """Build the ensemble graph for ``x_input`` and return the mean logits.

     Six classifiers (VGG-16, ResNet-v2-152, Inception-v3, Inception-v4,
     Inception-ResNet-v2 and AlexNet) are applied to differently
     preprocessed copies of ``x_input``; their logits are stacked and
     averaged.  The first call creates the variables; later calls switch
     the current variable scope to reuse mode.

     Args:
         x_input: Image batch tensor.  The slices below take
             ``FLAGS.batch_size`` rows from 4-D tensors.  Pixel values are
             presumably RGB in [0, 255] (TODO confirm against callers).

     Returns:
         Tensor with the element-wise mean of the six logits tensors.
     """
     if self.build:
         tf.get_variable_scope().reuse_variables()
     else:
         self.build = True

     batch_size = FLAGS.batch_size
     # These networks are built with one more output than `num_classes`.
     classes_plus_one = self.num_classes + 1

     def strip_first_class(logits):
         # Keep columns 1..num_classes.  Column 0 is presumably the
         # background class of the slim checkpoints (TODO confirm).
         return tf.slice(logits, [0, 1], [batch_size, self.num_classes])

     inception_images = (x_input / 255.0 - 0.5) * 2
     resized_224 = tf.image.resize_images(x_input, [224, 224])

     # VGG-16: per-channel mean subtraction on the 224x224 input.
     vgg_images = resized_224 - tf.constant([123.68, 116.78, 103.94])
     with slim.arg_scope(vgg.vgg_arg_scope()):
         logits_vgg16, _ = self.network_fn_vgg16(
             vgg_images,
             num_classes=self.num_classes,
             is_training=False)

     # ResNet-v2-152: 224x224 input rescaled to [-1, 1].
     resnet_images = (resized_224 / 255.0 - 0.5) * 2
     with slim.arg_scope(resnet_v2.resnet_arg_scope()):
         logits_res, _ = self.network_fn_res(
             resnet_images,
             num_classes=classes_plus_one,
             is_training=False)
     # Flatten to [batch, classes]; the width follows `num_classes` instead
     # of a hard-coded 1001 so it stays consistent with the call above.
     logits_res = strip_first_class(
         tf.reshape(logits_res, (-1, classes_plus_one)))

     with slim.arg_scope(inception_utils.inception_arg_scope()):
         logits_incepv3, _ = self.network_fn_incepv3(
             inception_images,
             num_classes=classes_plus_one,
             is_training=False)
     logits_incepv3 = strip_first_class(logits_incepv3)

     with slim.arg_scope(inception_utils.inception_arg_scope()):
         logits_incepv4, _ = self.network_fn_incepv4(
             inception_images,
             num_classes=classes_plus_one,
             is_training=False)
     logits_incepv4 = strip_first_class(logits_incepv4)

     with slim.arg_scope(
             inception_resnet_v2.inception_resnet_v2_arg_scope()):
         logits_incep_res, _ = self.network_fn_incep_res(
             inception_images,
             num_classes=classes_plus_one,
             is_training=False)
     logits_incep_res = strip_first_class(logits_incep_res)

     # AlexNet: resize to 256x256, reverse the channel axis, subtract the
     # stored mean image, then cut a 227x227 window starting at (14, 14).
     alex_images = tf.image.resize_images(x_input, [256, 256])
     alex_images = tf.reverse(alex_images, axis=[-1])
     # NOTE(review): the mean file is read from disk on every call.
     alex_mean_npy = np.load('model/alex_mean.npy').swapaxes(0, 1).swapaxes(
         1, 2).astype(np.float32)
     alex_images = alex_images - tf.constant(alex_mean_npy)
     alex_images = tf.slice(alex_images, [0, 14, 14, 0],
                            [batch_size, 227, 227, 3])
     _, logits_alex = self.network_fn_alex(alex_images)

     logits = [
         logits_vgg16, logits_res, logits_incepv3, logits_incepv4,
         logits_incep_res, logits_alex
     ]
     return tf.reduce_mean(tf.stack(logits), 0)
Exemple #3
0
def init():
    """Prepare the global Inception-v3 serving session.

    Clones the TensorFlow models repository, adds its slim directory to
    ``sys.path``, parses ``--model_name`` and ``--labels_dir`` from the
    command line, builds an Inception-v3 graph with one output per label and
    restores the weights of the named model into a new session.

    Sets the module globals ``g_tf_sess``, ``probabilities`` (the argmax over
    the logits), ``label_dict`` and ``input_images``.
    """
    global g_tf_sess, probabilities, label_dict, input_images

    # The slim `nets` package is only importable once the repository has been
    # cloned and its path registered, hence the late import further down.
    subprocess.run(["git", "clone", "https://github.com/tensorflow/models/"])
    sys.path.append("./models/research/slim")

    arg_parser = argparse.ArgumentParser(
        description="Start a tensorflow model serving")
    arg_parser.add_argument('--model_name', dest="model_name", required=True)
    arg_parser.add_argument('--labels_dir', dest="labels_dir", required=True)
    cli_args, _unknown = arg_parser.parse_known_args()

    from nets import inception_v3, inception_utils

    label_dict = get_class_label_dict(cli_args.labels_dir)
    n_classes = len(label_dict)

    tf.disable_v2_behavior()
    with slim.arg_scope(inception_utils.inception_arg_scope()):
        input_images = tf.placeholder(
            tf.float32, [1, image_size, image_size, num_channel])
        logits, _ = inception_v3.inception_v3(
            input_images, num_classes=n_classes, is_training=False)
        probabilities = tf.argmax(logits, 1)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    g_tf_sess = tf.Session(config=session_config)
    g_tf_sess.run(tf.global_variables_initializer())
    g_tf_sess.run(tf.local_variables_initializer())

    # Overwrite the freshly initialised variables with the stored weights.
    checkpoint_path = Model.get_model_path(cli_args.model_name)
    restorer = tf.train.Saver()
    restorer.restore(g_tf_sess, checkpoint_path)
    def CreateFrozenGraph(checkpointFile, outputFile):
        """Freeze an Inception-ResNet-v2 checkpoint into a GraphDef file.

        Builds the network in a fresh graph behind a resize/rescale
        front end, restores ``checkpointFile``, converts the variables to
        constants and writes the serialized graph to ``outputFile``.  The
        nodes ``processed_images``, ``predictions`` and ``embeddings`` are
        kept as outputs.

        Args:
            checkpointFile: Checkpoint path handed to ``Saver.restore``.
            outputFile: Destination path of the serialized frozen graph.
        """
        kept_nodes = ["processed_images", "predictions", "embeddings"]
        side = ImageNetEmbedder.EMBED_SIZE

        graph = tf.Graph()
        with tf.Session(graph=graph) as sess:
            with slim.arg_scope(inception_arg_scope()):
                raw_images = tf.placeholder(
                    tf.float32, [None, None, None, 3], 'input_images')
                resized = tf.image.resize_images(raw_images, (side, side))
                # Inception net preprocessing: (x - 0.5) * 2.
                # NOTE(review): this presumably expects inputs already
                # scaled to [0, 1] -- confirm against callers.
                centered = tf.subtract(resized, 0.5)
                scaled = tf.multiply(centered, 2.0)
                net_input = tf.identity(scaled, name="processed_images")

                _, end_points = inception_resnet_v2.inception_resnet_v2(
                    net_input,
                    is_training=False,
                    create_aux_logits=False)

                # These identities exist only to give the outputs stable
                # node names for the freeze step below.
                tf.identity(end_points['Predictions'], name="predictions")
                tf.identity(end_points['PreLogitsFlatten'], name="embeddings")

            tf.train.Saver().restore(sess, checkpointFile)

            frozen = tf.graph_util.convert_variables_to_constants(
                sess, graph.as_graph_def(), kept_nodes)
            with tf.gfile.GFile(outputFile, "wb") as out_file:
                out_file.write(frozen.SerializeToString())
Exemple #5
0
def model(images, weight_decay=1e-5, is_training=True):
    """Run Inception-ResNet-v2 (without a classification head) on ``images``.

    The images are passed through ``mean_image_subtraction`` first.  Every
    end point is printed before the results are returned.

    Args:
        images: Input image tensor.
        weight_decay: Weight decay forwarded to ``inception_arg_scope``.
        is_training: Forwarded to the network builder.

    Returns:
        The ``(logits, end_points)`` pair produced by the network builder.
    """
    normalized = mean_image_subtraction(images)
    scope = inception_arg_scope(weight_decay=weight_decay)
    with slim.arg_scope(scope):
        logits, end_points = inception_resnet_v2(
            normalized, num_classes=None, is_training=is_training)

    for name, tensor in end_points.items():
        print(name, tensor)
    return logits, end_points
Exemple #6
0
def GoogLeNet(num_classes, weight_decay=0.0, is_training=False):
  """Return a function that builds Inception-v1 (GoogLeNet) on a batch.

  Args:
    num_classes: Number of output classes passed to the network builder.
    weight_decay: Weight decay used for the Inception arg scope.
    is_training: Whether the network is built in training mode.

  Returns:
    ``network_fn(images, **kwargs)``, which builds the network inside the
    Inception arg scope and returns whatever ``inception_v1`` returns.
    Extra keyword arguments are forwarded to the builder; previously they
    were accepted but silently ignored.  If the builder exposes
    ``default_image_size``, it is copied onto ``network_fn``.
  """
  func = inception_v1.inception_v1
  arg_scope = inception_utils.inception_arg_scope(weight_decay=weight_decay)

  def network_fn(images, **kwargs):
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training, **kwargs)

  if hasattr(func, 'default_image_size'):
    network_fn.default_image_size = func.default_image_size
  return network_fn
Exemple #7
0
def classify(checkpoints_dir, images, logo_names=None, reuse=False):
    """Classify encoded JPEG images with the project's Inception-v4 variant.

    Args:
        checkpoints_dir: Checkpoint file, or a directory from which the
            latest checkpoint is picked.
        images: Iterable of encoded JPEG inputs for ``tf.image.decode_jpeg``.
        logo_names: Names of the output heads to evaluate; defaults to
            ``[""]``.  This project's ``inception_v4`` is expected to return
            logits that can be indexed by these names.
        reuse: Forwarded to the network builder.

    Returns:
        Dict mapping each logo name to its softmax output.  If the name
        ``""`` was requested, its entry is replaced by the indices of the
        five highest-scoring classes per image, best first.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if logo_names is None:
        logo_names = [""]

    image_size = inception.inception_v4.default_image_size

    processed_image_list = []
    for encoded in images:
        decoded = tf.image.decode_jpeg(encoded, channels=3)
        processed_image_list.append(
            inception_preprocessing.preprocess_image(decoded,
                                                     image_size,
                                                     image_size,
                                                     is_training=False))

    with slim.arg_scope(inception_utils.inception_arg_scope()):
        logits, _ = inception.inception_v4(processed_image_list,
                                           reuse=reuse,
                                           is_training=False,
                                           logo_names=logo_names)
    probabilities = [tf.nn.softmax(logits[name]) for name in logo_names]

    if tf.gfile.IsDirectory(checkpoints_dir):
        checkpoints_dir = tf.train.latest_checkpoint(checkpoints_dir)

    init_fn = slim.assign_from_checkpoint_fn(
        checkpoints_dir,
        slim.get_model_variables('InceptionV4'),
        ignore_missing_vars=True)
    with tf.Session() as sess:
        init_fn(sess)
        # Only the probabilities are needed; the decoded and preprocessed
        # images are evaluated implicitly as their inputs.
        output_probabilities = sess.run(probabilities)

    output_dict = dict(zip(logo_names, output_probabilities))

    # Top-5 ranking is applied only to the default (unnamed) head, and only
    # when that head was requested, so other name lists no longer fail.
    if "" in output_dict:
        ranked = np.argsort(output_dict[""], axis=1)[:, ::-1]
        output_dict[""] = ranked[:, :5]
    return output_dict
def endpoints(image, is_training):
    """Build Inception-v4 on ``image`` and return its end points.

    Subtracts ``_RGB_MEAN`` from the input, runs the network without a
    classification head and stores the spatial mean of ``Mixed_7d`` under
    both ``'model_output'`` and ``'global_pool'``.

    Args:
        image: Rank-4 image tensor.
        is_training: Forwarded to the network builder.

    Returns:
        Tuple ``(end_points, 'InceptionV4')``.

    Raises:
        ValueError: If ``image`` is not rank 4.
    """
    rank = image.get_shape().ndims
    if rank != 4:
        raise ValueError('Input must be of size [batch, height, width, 3]')

    rgb_mean = tf.constant(_RGB_MEAN, dtype=tf.float32, shape=(1, 1, 1, 3))
    centered = image - rgb_mean

    scope = inception_arg_scope(batch_norm_decay=0.9, weight_decay=0.0002)
    with tf.contrib.slim.arg_scope(scope):
        _, end_points = inception_v4(
            centered, num_classes=None, is_training=is_training)

    # Average over the two spatial axes of the last mixed block.
    pooled = tf.reduce_mean(
        end_points['Mixed_7d'], [1, 2], name='pool5', keep_dims=False)
    end_points['model_output'] = pooled
    end_points['global_pool'] = pooled

    return end_points, 'InceptionV4'
Exemple #9
0
 def __init__(self, num_classes):
     """Store the class count and the network builders of the ensemble.

     Args:
         num_classes: Number of classes, stored as ``self.num_classes``.
     """
     self.num_classes = num_classes

     # Only builder references are stored here.  A slim arg_scope affects
     # ops created while it is active, so wrapping these assignments in
     # one had no effect; the scope has to be entered where the builders
     # are actually called.
     self.network_fn_incep_res = inception_resnet_v2.inception_resnet_v2
     self.network_fn_vgg16 = vgg.vgg_16
     self.network_fn_res = resnet_v2.resnet_v2_152
     self.network_fn_incepv3 = inception_v3.inception_v3
     self.network_fn_incepv4 = inception_v4.inception_v4
     self.network_fn_alex = AlexNet()

     # Flipped to True by the first graph build so later builds can reuse
     # the existing variables.
     self.build = False
        im3 = im[-260:,0:260,:]
        im4 = im[-260:,-260:,:]
        im5 = im[19:279,19:279,:]
        
        imtemp = [cv2.resize(ims,(input_size[0:2])) for ims in (im1,im2,im3,im4,im5)]
        [img.append(ims) for ims in imtemp]
        
    except:
        print('Exception found: Image not read...')
        pass 
    
    img = np.asarray(img,dtype=np.float32)/255.0
    return img

# ----- Construct network 1 ----- #
with slim.arg_scope(inception_utils.inception_arg_scope()):
    logits,endpoints = inception_v4(data_in1,
                                num_classes=numclasses1,
                                is_training=is_training,
                                scope='herbarium')
    herbarium_embs = endpoints['PreLogitsFlatten']  
    herbarium_bn = tf.layers.batch_normalization(herbarium_embs, training=is_train)
    herbarium_feat = tf.contrib.layers.fully_connected(
                    inputs=herbarium_bn,
                    num_outputs=500,
                    activation_fn=None,
                    normalizer_fn=None,
                    trainable=True,
                    scope='herbarium'
            )
    herbarium_feat = tf.math.l2_normalize(
def eval(gpu, img_q, res_q, end_eval):
    """Evaluate Inception-v4 on queued images until ``end_eval`` is set.

    Builds the network in a new graph, restores
    ``checkpoints/inception_v4.ckpt``, then consumes
    ``(processed_images, label)`` items from ``img_q`` while counting top-1
    and top-5 hits.  When the loop ends, an ``EvalRes(c1, c5, img_cnt)`` is
    put on ``res_q``.

    Args:
        gpu: GPU index; used for the name scope and as the only visible
            device of the session.
        img_q: Queue yielding ``(processed_images, label)`` tuples.
        res_q: Queue that receives the final ``EvalRes``.
        end_eval: Event-like object; the loop stops once ``is_set()`` is true.
    """
    global counter

    ###############
    # Build graph #
    ###############
    with tf.Graph().as_default():
        print('loading graph on /gpu:%d' % gpu)
        with tf.name_scope('tower-%d' % gpu):
            eval_input = tf.placeholder(tf.float32, [1, height, width, 3])
            with slim.arg_scope(inception_arg_scope()):
                # NOTE(review): every call builds into a fresh graph, so
                # reuse=True on later calls looks suspicious -- confirm how
                # `counter` is shared between workers.
                with counter_lock:
                    reuse = None if counter == 0 else True
                    counter += 1
                logits, _ = inception_v4(eval_input, num_classes,
                                         reuse=reuse, scope='InceptionV4',
                                         is_training=False)
            probabilities = tf.nn.softmax(logits)

        config = tf.ConfigProto()
        config.gpu_options.visible_device_list = str(gpu)
        config.gpu_options.allow_growth = True
        # config.log_device_placement = True
        # config.allow_soft_placement = True
        sess = tf.Session(config=config)

        init_fn = slim.assign_from_checkpoint_fn(
            'checkpoints/inception_v4.ckpt',
            slim.get_model_variables('InceptionV4'))
        init_fn(sess)

    ##############
    # Eval image #
    ##############
    st = time.time()
    img_cnt = 0
    c1, c5 = 0, 0
    try:
        while not end_eval.is_set():
            try:
                # A short blocking wait replaces the non-blocking poll so an
                # empty queue no longer spins a CPU core; the stop flag is
                # still re-checked after every timeout.
                processed_images, label = img_q.get(True, 0.05)
            except Q.Empty:
                continue
            if fetch_test:
                # Only measure input throughput; skip inference.
                img_cnt += 1
                sys.stdout.write('\r{:07d}/{:07d}'.format(img_cnt, img_tlt))
                sys.stdout.flush()
                continue

            prob = sess.run(probabilities, feed_dict={
                eval_input: processed_images
            })[0]
            # Class indices ordered by descending probability; the stable
            # sort keeps the lower index first on ties.
            ranked = sorted(range(len(prob)), key=lambda i: -prob[i])

            if label == ranked[0]:
                c1 += 1
            if label in ranked[0:5]:
                c5 += 1
            img_cnt += 1
            # float() keeps the ratios fractional under Python 2 as well.
            seen = float(img_cnt)
            print('gpu %d images: %d\ttop 1: %0.4f\ttop 5: %0.4f' % (
                gpu, img_cnt, c1 / seen, c5 / seen))
    finally:
        # Release the session's resources even if the loop raises.
        sess.close()
    res_q.put(EvalRes(c1, c5, img_cnt))
    print('gpu %d finish in %.4f sec' % (gpu, time.time()-st))
    imresize = lambda x, dim: scipy.misc.imresize(x, dim)

from PIL import Image
from nets.inception_v4 import inception_v4
from nets.inception_utils import inception_arg_scope

# Short alias for the TF-Slim API.
slim = tf.contrib.slim

# Input resolution and number of network outputs used for the graph below.
height, width = 299, 299
num_classes = 1001
# Hit counters; presumably top-1 / top-5, updated by the loop that follows
# this setup (not fully visible here).
c1, c5 = 0, 0
# One entry per line of the validation list file.
lines = np.loadtxt('imagenet/val.txt', str, delimiter='\n')

with tf.Graph().as_default():
    # The graph takes a single image per run (batch dimension fixed to 1).
    eval_inputs = tf.placeholder(tf.float32, [1, height, width, 3])
    with slim.arg_scope(inception_arg_scope()):
        logits, _ = inception_v4(eval_inputs, num_classes, is_training=False)
    predictions = tf.argmax(logits, 1)
    probabilities = tf.nn.softmax(logits)

    # Callable that loads the InceptionV4 model variables from the checkpoint
    # into a session.
    init_fn = slim.assign_from_checkpoint_fn(
        'checkpoints/inception_v4.ckpt',
        slim.get_model_variables('InceptionV4'))

    # Let GPU memory grow on demand instead of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init_fn(sess)

    for idx, line in enumerate(lines):
    def __init__(self,
                 batch,
                 iterbatch,
                 numclasses,
                 image_dir_parent_train,
                 image_dir_parent_test,
                 train_file,
                 test_file,
                 input_size,
                 checkpoint_model,
                 learning_rate,
                 save_dir,
                 max_iter,
                 val_freq,
                 val_iter):
        """
        Wrapper module for multitask network using inceptionV4

        The constructor stores its arguments, builds the train/test
        databases, constructs an Inception-v4 graph with two extra linear
        heads (``Family`` with 248 outputs, ``Genus`` with 1780 outputs),
        defines losses, accuracies and gradient-accumulation ops, and then
        runs the complete training loop before returning.

        Args:
            batch: Number of images per micro-batch (first placeholder dim).
            iterbatch: Micro-batches accumulated before each optimizer step.
            numclasses: Number of outputs of the main (species) head.
            image_dir_parent_train: Image source directory for training.
            image_dir_parent_test: Image source directory for validation.
            train_file: Database file handed to the training database.
            test_file: Database file handed to the validation database.
            input_size: Tuple appended to ``(batch,)`` to form the input
                shape; its first two entries are used as the preprocessing
                height and width.
            checkpoint_model: Checkpoint used to initialise the restored
                variables.
            learning_rate: Learning rate of the Adam optimizer.
            save_dir: Directory that receives the saved checkpoints.
            max_iter: Training runs for iterations ``0..max_iter`` inclusive.
            val_freq: Validation runs whenever ``i % val_freq == 0``.
            val_iter: Number of validation micro-batches per validation run.
        """
        self.batch = batch
        self.iterbatch = iterbatch
        self.image_dir_parent_train = image_dir_parent_train
        self.image_dir_parent_test = image_dir_parent_test
        self.train_file = train_file
        self.test_file = test_file
        self.input_size = input_size
        self.numclasses = numclasses
        self.checkpoint_model = checkpoint_model
        self.learning_rate = learning_rate
        self.save_dir = save_dir
        self.max_iter = max_iter
        self.val_freq = val_freq
        self.val_iter = val_iter
        
        logfile = holmeslog()
        logfile.logging('Initiating database...')
        
        # construct datalist and cursor
        self.train_database = database_module(
                image_source_dir = self.image_dir_parent_train,
                database_file = self.train_file,
                batch = self.batch,
                input_size = self.input_size,
                numclasses = self.numclasses,
                shuffle = True)

        # NOTE(review): the validation database is shuffled as well.
        self.test_database = database_module(
                image_source_dir = self.image_dir_parent_test,
                database_file = self.test_file,
                batch = self.batch,
                input_size = self.input_size,
                numclasses = self.numclasses,
                shuffle = True)
        

        logfile.logging('Initiating tensors...')
        # Image batch plus one integer label vector per task.
        x = tf.placeholder(tf.float32,(self.batch,) + self.input_size)
        y1 = tf.placeholder(tf.int32,(self.batch,))
        y2 = tf.placeholder(tf.int32,(self.batch,))
        y3 = tf.placeholder(tf.int32,(self.batch,))
        # 10000 classes for species (classid)
        y_onehot1 = tf.one_hot(y1,self.numclasses)
        # 248 classes for family
        y_onehot2 = tf.one_hot(y2,248)
        # 1780 classes for genus
        y_onehot3 = tf.one_hot(y3,1780)
        # Boolean switch fed at run time; selects preprocessing and averaging.
        self.is_training = tf.placeholder(tf.bool)

        # Per-image preprocessing, applied over the batch with tf.map_fn.
        train_preproc = lambda xi: inception_preprocessing.preprocess_image(
                xi,self.input_size[0],self.input_size[1],is_training=True)
        
        def data_in_train():
            return tf.map_fn(fn = train_preproc,elems = x,dtype=np.float32)
        
        test_preproc = lambda xi: inception_preprocessing.preprocess_image(
                xi,self.input_size[0],self.input_size[1],is_training=False)        
        
        def data_in_test():
            return tf.map_fn(fn = test_preproc,elems = x,dtype=np.float32)
        
        # Pick the training or evaluation preprocessing branch at run time.
        data_in = tf.cond(
                self.is_training,
                true_fn = data_in_train,
                false_fn = data_in_test
                )

        logfile.logging('Constructing network...')        

        with slim.arg_scope(inception_utils.inception_arg_scope()):
            logits,endpoints = inception_v4(data_in,
                                            num_classes=self.numclasses,
                                            is_training=self.is_training)
            
            # Two extra linear heads on the flattened pre-logits features.
            logits_family = slim.fully_connected(endpoints['PreLogitsFlatten'],248,activation_fn=None,
                                        scope='Family')
            logits_genus = slim.fully_connected(endpoints['PreLogitsFlatten'],1780,activation_fn=None,
                                        scope='Genus')
            
        # One softmax cross-entropy term per head, plus the auxiliary head.
        with tf.name_scope("cross_entropy"): 
            with tf.name_scope("auxloss"):
                self.auxloss = tf.reduce_mean(
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                                logits=endpoints['AuxLogits'], labels=y_onehot1))
            with tf.name_scope("logits_loss"):
                self.loss = tf.reduce_mean(
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                                logits=logits, labels=y_onehot1))

            with tf.name_scope("logits_loss_family"):
                self.loss_family = tf.reduce_mean(
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                                logits=logits_family, labels=y_onehot2))

            with tf.name_scope("logits_loss_genus"):
                self.loss_genus = tf.reduce_mean(
                        tf.nn.softmax_cross_entropy_with_logits_v2(
                                logits=logits_genus, labels=y_onehot3))

            # Only regularization losses registered under 'InceptionV4' are
            # summed here.
            with tf.name_scope("L2_reg_loss"):
                self.regularization_loss = tf.add_n(tf.losses.get_regularization_losses(scope='InceptionV4'))
            with tf.name_scope("total_loss"):
                self.totalloss = self.loss + self.loss_family + self.loss_genus + self.auxloss + self.regularization_loss
           

        # Per-batch accuracy of each head (argmax of logits vs. label).
        with tf.name_scope("accuracy"):
            with tf.name_scope('accuracy_species'):
                prediction = tf.argmax(logits,1)
                match = tf.equal(prediction,tf.argmax(y_onehot1,1))
                self.accuracy = tf.reduce_mean(tf.cast(match,tf.float32))  
            with tf.name_scope('accuracy_family'):
                prediction2 = tf.argmax(logits_family,1)
                match = tf.equal(prediction2,tf.argmax(y_onehot2,1))            
                self.accuracy_family = tf.reduce_mean(tf.cast(match,tf.float32)) 
            with tf.name_scope('accuracy_genus'):       
                prediction3 = tf.argmax(logits_genus,1)
                match = tf.equal(prediction3,tf.argmax(y_onehot3,1))            
                self.accuracy_genus = tf.reduce_mean(tf.cast(match,tf.float32))             
            
        # NOTE(review): the -12 split presumably separates the head/logit
        # variables from the convolutional trunk -- confirm the count.
        # var_list_conv / var_list_fc are not used again in this method.
        self.var_list = [v for v in tf.trainable_variables()]
        self.var_list_conv = self.var_list[0:-12]
        self.var_list_fc = self.var_list[-12:]

        # All trainable variables are optimised.
        self.var_list_train = self.var_list
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        

        # only load the conv layer but not the logits
        # (implemented by dropping the last 12 entries of the restore list)
        self.variables_to_restore = slim.get_variables_to_restore()
        restore_fn = slim.assign_from_checkpoint_fn(
            self.checkpoint_model, self.variables_to_restore[0:-12])   
        
        with tf.name_scope("train"):
            # the training involves iter_batch updating
            # Running sums of loss/accuracy over the micro-batches of one step.
            loss_accumulator = tf.Variable(0.0, trainable=False)
            acc_accumulator = tf.Variable(0.0, trainable=False)
            acc_accumulator_family = tf.Variable(0.0, trainable=False)
            acc_accumulator_genus = tf.Variable(0.0, trainable=False)
            
            self.collect_loss = loss_accumulator.assign_add(self.totalloss)
            self.collect_acc = acc_accumulator.assign_add(self.accuracy)
            self.collect_acc_family = acc_accumulator_family.assign_add(self.accuracy_family)
            self.collect_acc_genus = acc_accumulator_genus.assign_add(self.accuracy_genus)

                        
            # Averages divide by the number of micro-batches of the current
            # phase: iterbatch when training, val_iter when validating.
            self.average_loss = tf.cond(self.is_training,
                                        lambda: loss_accumulator / self.iterbatch,
                                        lambda: loss_accumulator / self.val_iter)
            self.average_acc = tf.cond(self.is_training,
                                       lambda: acc_accumulator / self.iterbatch,
                                       lambda: acc_accumulator / self.val_iter)
            self.average_acc_family = tf.cond(self.is_training,
                                       lambda: acc_accumulator_family / self.iterbatch,
                                       lambda: acc_accumulator_family / self.val_iter)
            self.average_acc_genus = tf.cond(self.is_training,
                                       lambda: acc_accumulator_genus / self.iterbatch,
                                       lambda: acc_accumulator_genus / self.val_iter)

            
            # Ops that reset the metric accumulators.
            self.zero_op_loss = tf.assign(loss_accumulator,0.0)
            self.zero_op_acc = tf.assign(acc_accumulator,0.0)
            self.zero_op_acc_family = tf.assign(acc_accumulator_family,0.0)
            self.zero_op_acc_genus = tf.assign(acc_accumulator_genus,0.0)

            # One non-trainable gradient buffer per trained variable, plus
            # the ops that reset those buffers.
            self.accum_train = [tf.Variable(tf.zeros_like(
                    tv.initialized_value()), trainable=False) for tv in self.var_list_train]                                        
            self.zero_ops_train = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_train]
            # Gradient computation and accumulation are tied to the
            # UPDATE_OPS collection gathered above.
            with tf.control_dependencies(self.update_ops):
                optimizer = tf.train.AdamOptimizer(self.learning_rate)
                gradient = optimizer.compute_gradients(self.totalloss,self.var_list_train)
                gradient_only = [gc[0] for gc in gradient]
                # Clip each micro-batch's gradients by global norm 1.25.
                gradient_only,_ = tf.clip_by_global_norm(gradient_only,1.25)
                
                self.accum_train_ops = [self.accum_train[i].assign_add(gc) for i,gc in enumerate(gradient_only)]
            # Applies the summed (not averaged) accumulated gradients; this
            # op is created outside the control_dependencies block above.
            self.train_step = optimizer.apply_gradients(
                    [(self.accum_train[i], gc[1]) for i, gc in enumerate(gradient)])
           

        # Checkpoints hold the trainable variables plus the moving mean and
        # variance variables.
        var_list = tf.trainable_variables()
        g_list = tf.global_variables()
        bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
        bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
        var_list += bn_moving_vars
        saver = tf.train.Saver(var_list=var_list, max_to_keep=0)

        tf.summary.scalar('loss',self.average_loss) 
        tf.summary.scalar('accuracy',self.average_acc) 
        tf.summary.scalar('accuracy_f',self.average_acc_family)
        tf.summary.scalar('accuracy_g',self.average_acc_genus)
#        self.merged = tf.summary.merge_all()
        # NOTE(review): merge() is given two lists (one per scope filter)
        # rather than one flat list of summaries -- confirm this is intended.
        self.merged = tf.summary.merge([tf.get_collection(tf.GraphKeys.SUMMARIES,'accuracy'),
                                        tf.get_collection(tf.GraphKeys.SUMMARIES,'loss')])
        tensorboar_dir = '%s_tensorboard'%current_time()
        writer_train = tf.summary.FileWriter(tensorboar_dir+'/train')
        writer_test = tf.summary.FileWriter(tensorboar_dir+'/test')

        logfile.logging('Commencing training...') 
        # Best validation species accuracy seen so far.
        val_best = 0.0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            writer_train.add_graph(sess.graph)
            writer_test.add_graph(sess.graph)
            
            # Overwrite the initial values with the checkpoint weights.
            restore_fn(sess)
            #saver.restore(sess, r"checkpoints_clef3/best.ckpt")
            
            for i in range(self.max_iter+1):
                try:
                    # Reset the gradient buffers and metric accumulators.
                    sess.run(self.zero_ops_train)
                    sess.run([self.zero_op_acc_family,self.zero_op_acc_genus,self.zero_op_acc,self.zero_op_loss])
                    
                    # validations
                    if i % self.val_freq == 0:                        
                        print('Start:%f'%sess.run(loss_accumulator))
                        for j in range(self.val_iter):
                            img,lbl1,lbl2,lbl3 = self.test_database.read_batch()
                            sess.run(
                                        [self.collect_loss,self.collect_acc,self.collect_acc_family,self.collect_acc_genus],
                                        feed_dict = {x : img,
                                                     y1 : lbl1,
                                                     y2 : lbl2,
                                                     y3 : lbl3,
                                                     self.is_training : False
                                        }                                  
                                    )
                            print('[%i]:%f'%(j,sess.run(loss_accumulator)))
                        print('End:%f'%sess.run(loss_accumulator))  
                        # Read the averaged validation metrics and summaries.
                        s,self.netLoss,self.netAccuracy,self.netAccuracyFamily,self.netAccuracyGenus = sess.run(
                                [self.merged,self.average_loss,self.average_acc,self.average_acc_family,self.average_acc_genus],
                                    feed_dict = {
                                            self.is_training : False
                                    }                            
                                ) 
                        writer_test.add_summary(s, i)
                        logfile.logging('[Valid] Epoch:%i Iter:%i Loss:%f, Accuracy:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracy)) 
                        logfile.logging('[Valid] Epoch:%i Iter:%i Loss:%f, Accuracy Family:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracyFamily)) 
                        logfile.logging('[Valid] Epoch:%i Iter:%i Loss:%f, Accuracy Genus:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracyGenus)) 


                       #print('[Valid] Iter:%i Loss:%f, Accuracy:%f'%(i,self.netLoss,self.netAccuracy))
                        # Clear the accumulators again so the validation
                        # sums do not leak into the training metrics.
                        sess.run([self.zero_op_acc_family,self.zero_op_acc_genus,self.zero_op_acc,self.zero_op_loss])
                        
                        # Keep the checkpoint with the best species accuracy.
                        if self.netAccuracy > val_best:
                            val_best = self.netAccuracy
                            saver.save(sess, os.path.join(self.save_dir,'best.ckpt'))
                            logfile.logging('Model saved')


                    # training
                    # Accumulate metrics and gradients over iterbatch
                    # micro-batches before the single optimizer step below.
                    for j in range(self.iterbatch):
                        img,lbl1,lbl2,lbl3 = self.train_database.read_batch()
    
                        sess.run(
                                    [self.collect_loss,self.collect_acc,self.collect_acc_family,self.collect_acc_genus,self.accum_train_ops],
                                    feed_dict = {x : img,
                                                     y1 : lbl1,
                                                     y2 : lbl2,
                                                     y3 : lbl3,
                                                 self.is_training : True
                                    }                                
                                )
                        
                    s,self.netLoss,self.netAccuracy,self.netAccuracyFamily,self.netAccuracyGenus = sess.run(
                            [self.merged,self.average_loss,self.average_acc,self.average_acc_family,self.average_acc_genus],
                                feed_dict = {
                                        self.is_training : True
                                }                            
                            ) 
                    writer_train.add_summary(s, i)
                    
                    
                    # Apply the accumulated gradients.
                    sess.run(
                            [self.train_step]
                            )
                        
                    logfile.logging('[Train] Epoch:%i Iter:%i Loss:%f, Accuracy:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracy))
                    logfile.logging('[Train] Epoch:%i Iter:%i Loss:%f, Accuracy Family:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracyFamily)) 
                    logfile.logging('[Train] Epoch:%i Iter:%i Loss:%f, Accuracy Genus:%f'%(self.train_database.epoch,i,self.netLoss,self.netAccuracyGenus)) 

                    #print('[Train] Iter:%i Loss:%f, Accuracy:%f'%(i,self.netLoss,self.netAccuracy))
                    
                    # Periodic snapshot every 5000 iterations (including 0).
                    if i % 5000 == 0:
                        saver.save(sess, os.path.join(self.save_dir,'%06i.ckpt'%i)) 
                    
                except KeyboardInterrupt:
                    # Ctrl-C ends training; the final checkpoint below is
                    # still written.
                    logfile.logging('Interrupt detected. Ending...')
                    break
            saver.save(sess, os.path.join(self.save_dir,'final.ckpt')) 
            logfile.logging('Model saved')
Exemple #14
0
    def __init__(self,
                 batch,
                 iterbatch,
                 numclasses1,
                 numclasses2,
                 image_dir_parent_train,
                 image_dir_parent_test,
                 train_file1,
                 train_file2,
                 test_file1,
                 test_file2,
                 input_size,
                 checkpoint_model1,
                 checkpoint_model2,
                 learning_rate,
                 save_dir,
                 max_iter,
                 val_freq,
                 val_iter):
        """Build the two-branch triplet-loss graph and run the training loop.

        The constructor does everything: it creates the train/test database
        readers, builds two Inception-v4 branches (variable scopes
        'herbarium' and 'field'), projects each branch's 'PreLogitsFlatten'
        endpoint to an L2-normalised 500-d embedding, trains both branches
        with a semi-hard triplet loss over the concatenated embeddings, and
        writes checkpoints to ``save_dir``. It returns only after
        ``max_iter + 1`` iterations or a KeyboardInterrupt.

        Args:
            batch: batch size handed to both ``database_module`` readers.
            iterbatch: number of mini-batches whose (clipped) gradients are
                accumulated before one optimizer step; also the divisor of
                the reported training loss.
            numclasses1: ``num_classes`` of the 'herbarium' branch.
            numclasses2: ``num_classes`` of the 'field' branch.
            image_dir_parent_train: image source directory for training.
            image_dir_parent_test: image source directory for validation.
            train_file1, train_file2: database files for the training reader.
            test_file1, test_file2: database files for the validation reader.
            input_size: tuple appended to ``(None,)`` to form the image
                placeholder shape; its first two entries are passed to
                ``inception_preprocessing.preprocess_image``.
            checkpoint_model1: checkpoint restored into the 'herbarium' branch.
            checkpoint_model2: checkpoint restored into the 'field' branch.
            learning_rate: Adam learning rate for the last ten variables of
                each branch and the stand-alone batch-norm variables; all
                earlier variables are trained with ``0.1 * learning_rate``.
            save_dir: directory that receives the '.ckpt' files.
            max_iter: index of the last training iteration (inclusive).
            val_freq: a validation pass runs whenever ``i % val_freq == 0``.
            val_iter: number of validation batches per validation pass; also
                the divisor of the reported validation loss.
        """
        self.batch = batch
        self.iterbatch = iterbatch
        self.image_dir_parent_train = image_dir_parent_train
        self.image_dir_parent_test = image_dir_parent_test
        self.train_file1 = train_file1
        self.train_file2 = train_file2
        self.test_file1 = test_file1
        self.test_file2 = test_file2
        self.input_size = input_size
        self.numclasses1 = numclasses1
        self.numclasses2 = numclasses2
        self.checkpoint_model1 = checkpoint_model1
        self.checkpoint_model2 = checkpoint_model2
        self.learning_rate = learning_rate
        self.save_dir = save_dir
        self.max_iter = max_iter
        self.val_freq = val_freq
        self.val_iter = val_iter

        # ----- Database module ----- #
        self.train_database = database_module(
                image_source_dir = self.image_dir_parent_train,
                database_file1 = self.train_file1,
                database_file2 = self.train_file2,
                batch = self.batch,
                input_size = self.input_size,
                numclasses1 = self.numclasses1,
                numclasses2 = self.numclasses2,
                shuffle = True)

        self.test_database = database_module(
                image_source_dir = self.image_dir_parent_test,
                database_file1 = self.test_file1,
                database_file2 = self.test_file2,
                batch = self.batch,
                input_size = self.input_size,
                numclasses1 = self.numclasses1,
                numclasses2 = self.numclasses2,
                shuffle = True)

        # ----- Tensors ------ #
        print('Initiating tensors...')
        x1 = tf.placeholder(tf.float32,(None,) + self.input_size)
        x2 = tf.placeholder(tf.float32,(None,) + self.input_size)
        # NOTE(review): the next four placeholders are never fed; each name is
        # rebound to a real tensor further down. They are kept only so that the
        # auto-numbered names of the placeholders created after them do not
        # shift in the exported graph.
        herbarium_embs = tf.placeholder(tf.float32)
        field_embs = tf.placeholder(tf.float32)
        feat_concat = tf.placeholder(tf.float32, shape=[None, 500])
        lbl_concat = tf.placeholder(tf.float32)
        y1 = tf.placeholder(tf.int32, (None,))
        y2 = tf.placeholder(tf.int32, (None,))
        # Two separate boolean switches: `self.is_training` selects the
        # preprocessing branch, the Inception mode and the loss divisor, while
        # `is_train` drives the two extra batch-normalization layers. Both must
        # be fed whenever the networks are evaluated.
        self.is_training = tf.placeholder(tf.bool)
        is_train = tf.placeholder(tf.bool, name="is_training")

        # ----- Image pre-processing methods ----- #
        train_preproc = lambda xi: inception_preprocessing.preprocess_image(
                xi,self.input_size[0],self.input_size[1],is_training=True)

        test_preproc = lambda xi: inception_preprocessing.preprocess_image(
                xi,self.input_size[0],self.input_size[1],is_training=False)

        # Per-image preprocessing is mapped over the batch dimension.
        def data_in_train1():
            return tf.map_fn(fn = train_preproc,elems = x1,dtype=np.float32)

        def data_in_test1():
            return tf.map_fn(fn = test_preproc,elems = x1,dtype=np.float32)

        def data_in_train2():
            return tf.map_fn(fn = train_preproc,elems = x2,dtype=np.float32)

        def data_in_test2():
            return tf.map_fn(fn = test_preproc,elems = x2,dtype=np.float32)

        data_in1 = tf.cond(
                self.is_training,
                true_fn = data_in_train1,
                false_fn = data_in_test1
                )

        data_in2 = tf.cond(
                self.is_training,
                true_fn = data_in_train2,
                false_fn = data_in_test2
                )

        print('Constructing network...')
        # ----- Network 1 construction ----- #
        with slim.arg_scope(inception_utils.inception_arg_scope()):
            logits,endpoints = inception_v4(data_in1,
                                            num_classes=self.numclasses1,
                                            is_training=self.is_training,
                                            scope='herbarium'
                                            )
            herbarium_embs = endpoints['PreLogitsFlatten']
            herbarium_bn = tf.layers.batch_normalization(herbarium_embs, training=is_train)
            # Linear projection to the 500-d embedding space.
            herbarium_feat = tf.contrib.layers.fully_connected(
                            inputs=herbarium_bn,
                            num_outputs=500,
                            activation_fn=None,
                            normalizer_fn=None,
                            trainable=True,
                            scope='herbarium'
                    )
            herbarium_feat = tf.math.l2_normalize(
                                                herbarium_feat,
                                                axis=1
                                            )

        # ----- Network 2 construction ----- #
        with slim.arg_scope(inception_utils.inception_arg_scope()):
            logits2,endpoints2 = inception_v4(data_in2,
                                            num_classes=self.numclasses2,
                                            is_training=self.is_training,
                                            scope='field'
                                            )
            field_embs = endpoints2['PreLogitsFlatten']
            field_bn = tf.layers.batch_normalization(field_embs, training=is_train)
            # Linear projection to the 500-d embedding space.
            field_feat = tf.contrib.layers.fully_connected(
                            inputs=field_bn,
                            num_outputs=500,
                            activation_fn=None,
                            normalizer_fn=None,
                            trainable=True,
                            scope='field'
                    )
            field_feat = tf.math.l2_normalize(
                                    field_feat,
                                    axis=1
                                )

        # Stack both branches along the batch axis so that triplets are mined
        # across the two branches' embeddings.
        feat_concat = tf.concat([herbarium_feat, field_feat], 0)
        lbl_concat = tf.concat([y1, y2], 0)

        # ----- Get all variables ----- #
        self.variables_to_restore = tf.trainable_variables()
        self.variables_bn = [k for k in self.variables_to_restore if k.name.startswith('batch_normalization')]
        self.variables_herbarium = [k for k in self.variables_to_restore if k.name.startswith('herbarium')]
        self.variables_field = [k for k in self.variables_to_restore if k.name.startswith('field')]

        # ----- New variable list ----- #
        # "front" = everything except the last ten variables of each branch;
        # "last"  = those last ten plus the stand-alone batch-norm variables.
        self.var_list_front = self.variables_herbarium[0:-10] + self.variables_field[0:-10]
        self.var_list_last = self.variables_herbarium[-10:] + self.variables_field[-10:] + self.variables_bn
        self.var_list_train = self.var_list_front + self.var_list_last

        # ----- Network losses ----- #
        with tf.name_scope("loss_calculation"):
            with tf.name_scope("triplets_loss"):
                self.triplets_loss = tf.reduce_mean(
                        tf.contrib.losses.metric_learning.triplet_semihard_loss(
                                labels=lbl_concat, embeddings=feat_concat, margin=1.0))

            with tf.name_scope("L2_reg_loss"):
                self.regularization_loss = tf.add_n([ tf.nn.l2_loss(v) for v in self.var_list_train]) * 0.00004

            with tf.name_scope("total_loss"):
                self.totalloss = self.triplets_loss + self.regularization_loss

        # ----- Create update operation ----- #
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        self.vars_ckpt = slim.get_variables_to_restore()

        vars_ckpt_herbarium = [k for k in self.vars_ckpt if k.name.startswith('herbarium')]
        vars_ckpt_field = [k for k in self.vars_ckpt if k.name.startswith('field')]

        # The last two variables of each list are left out of the restore
        # (presumably the new projection layer's weights and biases, which the
        # source checkpoints would not contain -- TODO confirm).
        # ----- Restore model 1 ----- #
        restore_fn1 = slim.assign_from_checkpoint_fn(
            self.checkpoint_model1, vars_ckpt_herbarium[:-2])

        # ----- Restore model 2 ----- #
        restore_fn2 = slim.assign_from_checkpoint_fn(
            self.checkpoint_model2, vars_ckpt_field[:-2])

        # ----- Training scope ----- #
        with tf.name_scope("train"):
            # Running sum of the total loss over the current accumulation window.
            loss_accumulator = tf.Variable(0.0, trainable=False)
            self.collect_loss = loss_accumulator.assign_add(self.totalloss)
            self.average_loss = tf.cond(self.is_training,
                                        lambda: loss_accumulator / self.iterbatch,
                                        lambda: loss_accumulator / self.val_iter)
            self.zero_op_loss = tf.assign(loss_accumulator,0.0)
            # One non-trainable accumulator per trained variable, in the same
            # order as `self.var_list_train`.
            self.accum_train = [tf.Variable(tf.zeros_like(
                    tv.initialized_value()), trainable=False) for tv in self.var_list_train]
            self.zero_ops_train = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_train]

            # ----- Set up optimizer / Compute gradients ----- #
            with tf.control_dependencies(self.update_ops):
                optimizer = tf.train.AdamOptimizer(self.learning_rate * 0.1)
                optimizer_lastlayers = tf.train.AdamOptimizer(self.learning_rate)
                gradient1 = optimizer.compute_gradients(self.totalloss,self.var_list_front)
                gradient2 = optimizer_lastlayers.compute_gradients(self.totalloss,self.var_list_last)
                gradient = gradient1 + gradient2
                gradient_only = [gc[0] for gc in gradient]
                # Clipping is applied to each mini-batch's gradients before
                # they are added to the accumulators.
                gradient_only,_ = tf.clip_by_global_norm(gradient_only,1.25)

                self.accum_train_ops = [
                        acc.assign_add(gc)
                        for acc, gc in zip(self.accum_train, gradient_only)]

            # ----- Apply gradients ----- #
            # Each group must be applied by the optimizer that owns its
            # learning rate. Applying everything through `optimizer` would
            # train the last layers at the reduced 0.1x rate as well, because
            # the learning rate only takes effect in apply_gradients.
            num_front = len(gradient1)
            front_pairs = [
                    (acc, gc[1])
                    for acc, gc in zip(self.accum_train[:num_front], gradient1)]
            last_pairs = [
                    (acc, gc[1])
                    for acc, gc in zip(self.accum_train[num_front:], gradient2)]
            self.train_step = tf.group(
                    optimizer.apply_gradients(front_pairs),
                    optimizer_lastlayers.apply_gradients(last_pairs))

        # ----- Global variables ----- #
        # Checkpoints hold the trainable variables plus the batch-norm moving
        # statistics; optimizer slots and accumulators are not saved.
        var_list = tf.trainable_variables()
        g_list = tf.global_variables()
        bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
        bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
        var_list += bn_moving_vars

        # ----- Create saver ----- #
        saver = tf.train.Saver(var_list=var_list, max_to_keep=0)
        tf.summary.scalar('loss',self.average_loss)
        self.merged = tf.summary.merge([tf.get_collection(tf.GraphKeys.SUMMARIES,'loss')])

        # ----- Tensorboard writer--- #
        # NOTE(review): `tensorboard_dir` is not defined in this method; it has
        # to exist at module level for these two lines to work -- confirm.
        writer_train = tf.summary.FileWriter(tensorboard_dir+'/train')
        writer_test = tf.summary.FileWriter(tensorboard_dir+'/test')

        print('Commencing training...')
        try:
            # ----- Create session ----- #
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                writer_train.add_graph(sess.graph)
                writer_test.add_graph(sess.graph)

                # Restore after the initializer so the checkpoint values are
                # not overwritten by the initial values.
                restore_fn1(sess)
                restore_fn2(sess)

                for i in range(self.max_iter+1):
                    try:
                        # Start every iteration with empty accumulators.
                        sess.run(self.zero_ops_train)
                        sess.run([self.zero_op_loss])

                        # ----- Validation ----- #
                        if i % self.val_freq == 0:
                            print('Start:%f'%sess.run(loss_accumulator))
                            for j in range(self.val_iter):
                                img1,img2,lbl1, lbl2 = self.test_database.read_batch()
                                sess.run(
                                            self.collect_loss,
                                            feed_dict = {x1 : img1,
                                                         x2 : img2,
                                                         y1 : lbl1,
                                                         y2 : lbl2,
                                                         self.is_training : False,
                                                         is_train : False
                                            }
                                        )
                                print('[%i]:%f'%(j,sess.run(loss_accumulator)))
                            print('End:%f'%sess.run(loss_accumulator))
                            # is_training=False selects the val_iter divisor.
                            s,self.netLoss = sess.run(
                                    [self.merged,self.average_loss],
                                        feed_dict = {
                                                self.is_training : False
                                        }
                                    )
                            writer_test.add_summary(s, i)
                            print('[Valid] Epoch:%i Iter:%i Loss:%f'%(self.train_database.epoch,i,self.netLoss))

                            # Drop the validation sum before training adds to it.
                            sess.run([self.zero_op_loss])

                        # ----- Train ----- #
                        for j in range(self.iterbatch):
                            img1,img2,lbl1,lbl2 = self.train_database.read_batch()

                            sess.run(
                                        [self.collect_loss,self.accum_train_ops],
                                        feed_dict = {x1 : img1,
                                                     x2 : img2,
                                                     y1 : lbl1,
                                                     y2 : lbl2,
                                                     self.is_training : True,
                                                     is_train : True
                                        }
                                    )

                        # is_training=True selects the iterbatch divisor.
                        s,self.netLoss = sess.run(
                                [self.merged,self.average_loss],
                                    feed_dict = {
                                            self.is_training : True
                                    }
                                )
                        writer_train.add_summary(s, i)

                        # One optimizer step from the accumulated gradients.
                        sess.run([self.train_step])

                        print('[Train] Epoch:%i Iter:%i Loss:%f'%(self.train_database.epoch,i,self.netLoss))

                        if i % 5000 == 0:
                            saver.save(sess, os.path.join(self.save_dir,'%06i.ckpt'%i))

                    except KeyboardInterrupt:
                        print('Interrupt detected. Ending...')
                        break

                # ----- Save model --- #
                saver.save(sess, os.path.join(self.save_dir,'final.ckpt'))
                print('Model saved')
        finally:
            # Flush and release the event files even if training aborts.
            writer_train.close()
            writer_test.close()