예제 #1
0
def pack_model(inputs, train = True, back_network = 'resnet50',
            #num_classes=81,
            num_classes=91,
            base_anchors=9,
            weight_decay=0.00005,
            **kwargs
        ):
    """Assemble the full detection model: backbone, feature pyramid, and heads.

    Args:
        inputs: dict of input tensors; reads 'images', 'height', 'width',
            'num_objects' and 'bboxes'.
        train: whether the heads are built in training mode.
        back_network: backbone architecture name passed to network.get_network.
        num_classes: number of object classes for the heads.
        base_anchors: anchors per spatial location.
        weight_decay: L2 regularization strength for the backbone.

    Returns:
        A pair of dicts: ({'outputs', 'pyramid'}, {'network'}).
    """
    # Normalize the raw image tensor to NHWC float32 (batch size 1 supported).
    raw_image = inputs['images']
    height = inputs['height']
    width = inputs['width']
    raw_shape = tf.shape(raw_image)
    reshaped = tf.reshape(raw_image,
                          (raw_shape[0], raw_shape[1], raw_shape[2], 3))
    image = tf.cast(reshaped, tf.float32)

    num_instances = inputs['num_objects']
    gt_boxes = inputs['bboxes']

    # Backbone feature extractor.
    logits, end_points, pyramid_map = network.get_network(
        back_network, image, weight_decay=weight_decay)

    # Feature pyramid built on top of the backbone feature maps.
    pyramid = pyramid_network.build_pyramid(pyramid_map, end_points)

    # Detection heads (RPN + refinement) on the pyramid levels.
    outputs = pyramid_network.build_heads(
        pyramid, height, width, num_classes, base_anchors,
        is_training=train, gt_boxes=gt_boxes)

    return {'outputs': outputs, 'pyramid': pyramid}, {'network': back_network}
예제 #2
0
def train():
    """Run the main training loop for the pyramid (FPN/Mask-RCNN style) network.

    Builds the input shuffle queue, the backbone and pyramid heads, then
    iterates FLAGS.max_iters optimization steps, logging losses, drawing
    debug bounding boxes, writing summaries every 100 steps and checkpoints
    every 10000 steps.

    Raises:
        ValueError: if the total loss becomes NaN or Inf.
    """

    ## data
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # A shuffle queue decouples data loading from the train step; it is fed
    # by 4 enqueue threads registered as a QueueRunner below.
    data_queue = tf.RandomShuffleQueue(
        capacity=32,
        min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()
    im_shape = tf.shape(image)
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        image,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # Tensors fetched purely for visualization and debugging.
    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']
    # Get the computation result from output
    final_mask = outputs['mask']['mask']
    print('final_mask', final_mask.shape)
    gt = outputs['gt']

    #############################
    # Debug fetches: tmp_0..tmp_2 currently alias the loss list; tmp_3/tmp_4
    # are dedicated debug outputs of the pyramid network (class one-hots,
    # judging by the argmax prints below).
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 0.95
    sess = tf.Session(config=config)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # One optimization step. The unpacking order below must match the
        # fetch list exactly: update_op, scalar losses, img id, per-head
        # losses, gt boxes, batch info, then the visualization/debug tensors.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch, \
        input_imagenp, final_masknp, final_boxnp, final_clsnp, final_probnp, final_gt_clsnp, gtnp,gt_masksnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np= \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info +
                              [input_image] + [final_mask] + [final_box] +  [final_cls] + [final_prob] + [final_gt_cls] + [gt] + [gt_masks] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            # Visualize predicted boxes and ground truth side by side.
            draw_bbox(
                step,
                np.uint8((np.array(input_imagenp[0]) / 2.0 + 0.5) * 255.0),
                name='est',
                bbox=final_boxnp,
                label=final_clsnp,
                prob=final_probnp,
                gt_label=np.argmax(np.asarray(final_gt_clsnp), axis=1),
            )

            draw_bbox(
                step,
                np.uint8((np.array(input_imagenp[0]) / 2.0 + 0.5) * 255.0),
                name='gt',
                bbox=gtnp[:, 0:4],
                label=np.asarray(gtnp[:, 4], dtype=np.uint8),
            )

            print("labels")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.asarray(tmp_3np), axis=1)))[1:])
            print("classes")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.array(tmp_4np), axis=1))))

            # Abort training once the loss diverges. NOTE: the original code
            # used a bare `raise` with no active exception, which only
            # produced "RuntimeError: No active exception to re-raise".
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                raise ValueError('total loss diverged to NaN/Inf at step %d' %
                                 step)

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                FLAGS.train_dir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
예제 #3
0
def train():
    """Train an embedding network jointly with DPM (Bayesian GMM) clustering.

    Each period alternates two phases:
      * Forward: embed the whole dataset, cluster the embeddings with a
        Bayesian Gaussian mixture (optionally merging clusters with the SIGN
        algorithm), and evaluate NMI against the ground-truth labels.
      * Backward: fit the network for one epoch against the cluster
        assignments using a GMM loss plus a log-determinant penalty on the
        embedding covariance (weighted by Detcoef).

    All configuration comes from FLAGS; per-dataset overrides are set below.
    """
    ## set the parameters for different datasets
    if FLAGS.dataset == 'mnist_test':
        img_height = img_width = 28
        learning_rate = 0.001
        Detcoef = 50
        apply_network = 'lenet'
    elif FLAGS.dataset == 'usps':
        img_height = img_width = 16
        learning_rate = 0.0001
        Detcoef = 50
        apply_network = 'lenet0'
    elif FLAGS.dataset == 'frgc':
        img_height = img_width = 32
        learning_rate = 0.1
        Detcoef = 20
        apply_network = 'lenet'
    elif FLAGS.dataset == 'ytf':
        img_height = img_width = 55
        learning_rate = 0.1
        Detcoef = 20
        apply_network = 'lenet'
    elif FLAGS.dataset == 'umist':
        img_height = 112
        img_width = 92
        learning_rate = 0.0001
        Detcoef = 20
        apply_network = 'dlenet'
    else:
        # Unknown dataset: fall back to the values supplied on the command line.
        img_height = FLAGS.img_height
        img_width = FLAGS.img_width
        learning_rate = FLAGS.learning_rate
        Detcoef = FLAGS.Detcoef
        apply_network = FLAGS.network

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # tensor for input images (NHWC, 3 channels)
        if FLAGS.is_resize:
            imageip = tf.placeholder(
                tf.float32, [None, FLAGS.resize_height, FLAGS.resize_width, 3])
        else:
            imageip = tf.placeholder(tf.float32,
                                     [None, img_height, img_width, 3])

        # get the embedding data from the network (training graph)
        _, end_points = network.get_network(apply_network,
                                            imageip,
                                            FLAGS.max_k,
                                            weight_decay=FLAGS.weight_decay,
                                            is_training=True,
                                            reuse=False,
                                            spatial_squeeze=False)
        # fc3 is the name of our embedding layer
        end_net = end_points['fc3']

        # Normalize the embedding data.
        # NOTE(review): if FLAGS.normalize is neither 0 nor 1, end_data is
        # never bound and the statement after this block raises NameError —
        # confirm the flag is restricted to {0, 1}.
        if FLAGS.normalize == 0:  # standardize
            end_data = standardize(end_net)
        elif FLAGS.normalize == 1:  # batch normalize
            end_data = slim.batch_norm(end_net,
                                       activation_fn=None,
                                       scope='batchnorm',
                                       is_training=True)

        # calculate LD, the sample covariance matrix of the embedding data,
        # and penalize a small determinant (det_loss = -log det(cov))
        diff_data = end_data - tf.expand_dims(tf.reduce_mean(end_data, 0), 0)
        cov_data = 1. / (tf.cast(tf.shape(end_data)[0], tf.float32) -
                         1.) * tf.matmul(tf.transpose(diff_data), diff_data)
        det_loss = -logdet(cov_data)

        # Inference copy of the network (shared weights via reuse=True),
        # used to get numpy embeddings for clustering and evaluation.
        _, val_end_points = network.get_network(
            apply_network,
            imageip,
            FLAGS.max_k,
            weight_decay=FLAGS.weight_decay,
            is_training=False,
            reuse=True,
            spatial_squeeze=False)
        val_end_data = val_end_points['fc3']

        if FLAGS.normalize == 1:
            val_end_data = slim.batch_norm(val_end_data,
                                           activation_fn=None,
                                           scope='batchnorm',
                                           is_training=False,
                                           reuse=True)

        # clustering loss: per-sample Gaussian means/precisions are fed in
        # from the numpy clustering result each backward step
        cls_mus = tf.placeholder(tf.float32, [None, FLAGS.embed_dims])
        cls_Gammas = tf.placeholder(tf.float32,
                                    [None, FLAGS.embed_dims, FLAGS.embed_dims])
        cluster_loss = gmm_loss(end_data, cls_mus, cls_Gammas)

        # l2 regularization (only added if any regularizers were registered)
        penalty = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        # total loss
        total_loss = cluster_loss + Detcoef * det_loss
        if penalty:
            l2_penalty = tf.add_n(penalty)
            total_loss += l2_penalty

        global_step = slim.create_global_step()

        ## load the data from the dataset's HDF5 file
        df_path = '{}/{}.h5'.format(FLAGS.dataset_dir, FLAGS.dataset)
        f = h5py.File(df_path, 'r')
        ## Get the data
        data = list(f['data'])
        label = list(f['labels'])
        train_datum = load_train_data(data, label)
        train_datum.center_data()
        train_datum.shuffle(100)
        # Keep a copy of the (shuffled) data/labels for the forward phase.
        val_data, val_truth = np.copy(train_datum.data), np.copy(
            train_datum.label)

        ## set up mini-batch steps and optimizer
        batch_num = train_datum.data.shape[0] // FLAGS.batch_size

        # Inverse-time decay: lr / (1 + 0.0001 * batch_num * floor(step / batch_num))
        learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
                                                    batch_num,
                                                    0.0001 * batch_num, True)
        var_list = tf.trainable_variables()

        opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                         momentum=FLAGS.momentum)
        train_opt = slim.learning.create_train_op(total_loss,
                                                  opt,
                                                  global_step=global_step,
                                                  variables_to_train=var_list,
                                                  summarize_gradients=False)

        ## load session
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        ## log setting and results (one timestamped output dir per run)
        timestampLaunch = time.strftime("%d%m%Y") + '-' + time.strftime(
            "%H%M%S")
        # record config
        if not os.path.exists(FLAGS.out_dir):
            os.makedirs(FLAGS.out_dir)
        if not os.path.exists(os.path.join(FLAGS.out_dir, FLAGS.dataset)):
            os.makedirs(os.path.join(FLAGS.out_dir, FLAGS.dataset))
        outdir = os.path.join(FLAGS.out_dir, FLAGS.dataset, timestampLaunch)
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        if FLAGS.dataset == 'umist':
            max_periods = 2000
        else:
            max_periods = FLAGS.max_periods
        # load saver and restore session
        saver = tf.train.Saver(max_to_keep=3)
        if FLAGS.restore_previous_if_exists:
            restore(sess, 1)
        else:
            if FLAGS.if_initialize_from_pretrain:
                restore(sess, 0)

        # Loss histories; conv_cluster_l is seeded with +inf so the first
        # converged cluster loss always wins the min() comparison below.
        period_cluster_l, period_det_l, period_tot_l, conv_cluster_l = [], [], [], [sys.float_info.max]
        """ start the training """
        print('start training the dataset of {}'.format(FLAGS.dataset))
        for period in range(max_periods):
            real_period = period + FLAGS.checkpoint_periods
            '''Forward steps'''
            ## get the numpy array of embedding data for clustering,
            ## batched with hard-coded per-dataset chunk sizes
            val_embed = []
            if FLAGS.dataset == 'mnist_test':  #10000
                for s in range(10):
                    start = s * 1000
                    end = (s + 1) * 1000
                    val_embed_x = sess.run(
                        val_end_data, feed_dict={imageip: val_data[start:end]})
                    val_embed.append(val_embed_x)
            elif FLAGS.dataset == 'usps':  # 11000
                for s in range(11):
                    start = s * 1000
                    end = (s + 1) * 1000
                    val_embed_x = sess.run(
                        val_end_data, feed_dict={imageip: val_data[start:end]})
                    val_embed.append(val_embed_x)
            elif FLAGS.dataset == 'frgc':  # 2462
                for s in range(25):
                    start = s * 100
                    end = (s + 1) * 100
                    if s == 24:
                        # last chunk is short: 2462 = 24*100 + 62
                        end = end - 38
                    val_embed_x = sess.run(
                        val_end_data, feed_dict={imageip: val_data[start:end]})
                    val_embed.append(val_embed_x)
            elif FLAGS.dataset == 'ytf':  ##55x55; 10000
                for s in range(10):
                    start = s * 1000
                    end = (s + 1) * 1000
                    val_embed_x = sess.run(
                        val_end_data, feed_dict={imageip: val_data[start:end]})
                    val_embed.append(val_embed_x)
            elif FLAGS.dataset == 'umist':  # < 2000, fits in one run
                val_embed = sess.run(val_end_data,
                                     feed_dict={imageip: val_data})
            if FLAGS.dataset != 'umist':
                val_embed = np.concatenate(val_embed, axis=0)

            if FLAGS.normalize == 0:
                val_embed, val_mean, val_std = np_standardize(val_embed)
            ### use dpm to cluster the embedding data
            dpgmm = mixture.BayesianGaussianMixture(
                n_components=FLAGS.max_k,
                weight_concentration_prior=FLAGS.alpha / FLAGS.max_k,
                weight_concentration_prior_type='dirichlet_process',
                covariance_prior=FLAGS.embed_dims *
                np.identity(FLAGS.embed_dims),
                covariance_type='full').fit(val_embed)
            val_labels = dpgmm.predict(val_embed)

            if FLAGS.onsign:
                ### SIGN algorithm to merge clusters: summarize each cluster
                ### by its count, sum, and scatter matrix, then re-cluster
                ### the summaries
                ulabels = np.unique(val_labels).tolist()
                uln_l = []
                ulxtx_l = []
                ulxx_l = []
                for ul in ulabels:
                    ulx = val_embed[val_labels == ul, :]  #Nk x p
                    uln = np.sum(val_labels == ul)  #Nk
                    ulxtx = np.matmul(ulx.T, ulx)  #p x p
                    ulxx = np.sum(ulx, axis=0)  # p
                    uln_l.append(uln)
                    ulxtx_l.append(ulxtx)
                    ulxx_l.append(ulxx)
                uxx = np.stack(ulxx_l, axis=0)  #kxp
                un = np.array(uln_l)  # k
                uxtx = np.stack(ulxtx_l, axis=0).T  # p x p x k

                # MCMC is tractable for low-dimensional embeddings; fall back
                # to variational inference otherwise.
                if FLAGS.embed_dims < 50:
                    Rest = Gibbs_DPM_Gaussian_summary_input(uxtx, uxx,
                                                            un)  # mcmc
                else:
                    Rest = R_VI_PYMMG_CoC(uxtx, uxx,
                                          un)  # variational inference
                member, dp_Gammas, dp_mus = Rest['member_est'], Rest[
                    'Prec'], Rest['mu']

                # Remap each original cluster label to its merged cluster.
                val_labels_new = np.copy(val_labels)
                for u, ul in enumerate(ulabels):
                    val_labels_new[val_labels == ul] = int(
                        member[u])  # order the cluster value with index
                val_labels = np.copy(val_labels_new)

                # evaluate and save the results: group ground-truth labels by
                # estimated cluster
                val_count = np.bincount(val_labels)
                val_count2 = np.nonzero(val_count)
                est_cls = {}
                for v in val_count2[0].tolist():
                    est_cls[v] = []
                for vv, vl in enumerate(val_labels.tolist()):
                    est_cls[vl].append(val_truth[vv])

                ## sort the labels to be used for backward: compress merged
                ## labels into the contiguous range 0..K-1, then one-hot
                train_labels_new = np.copy(val_labels)
                member1 = np.array([int(m) for m in member])
                member2 = np.unique(member1)
                member2.sort()
                train_labels_new1 = np.copy(train_labels_new)

                for mbi, mb in enumerate(member2.tolist()):
                    train_labels_new1[train_labels_new == mb] = mbi
                train_labels_onehot = np.eye(
                    member2.shape[0])[train_labels_new1]
            else:
                # No merging: use the DPM components directly.
                dp_mus = dpgmm.means_
                dp_Gammas = dpgmm.precisions_.T
                train_labels_onehot = np.eye(FLAGS.max_k)[val_labels]

            nmi = normalized_mutual_info_score(val_labels, val_truth)
            if period > 0:
                print("NMI for period{} is {}".format(period, nmi))

            if period >= 100:
                ## check if the results need to be saved using det_loss and
                ## cluster_loss; period_det_l/period_cluster_loss come from
                ## the previous period's backward phase below
                dperiod_det_loss = np.abs(
                    (period_det_l[-1] - period_det_l[-2]) / period_det_l[-2])
                if dperiod_det_loss <= FLAGS.epsilon:
                    conv_cluster_l.append(period_cluster_loss)
                    if conv_cluster_l[-1] < min(conv_cluster_l[:-1]):
                        best_nmi, best_period = nmi, real_period
                        saver.save(sess, os.path.join(outdir, 'ckpt'),
                                   real_period)
                        # save truth and labels
                        # NOTE(review): val_mean/val_std are only bound when
                        # FLAGS.normalize == 0; with normalize == 1 this call
                        # raises NameError — confirm intended flag values.
                        np.savez(os.path.join(
                            outdir, 'labels_{}.npy'.format(real_period)),
                                 val_labels=val_labels,
                                 val_truth=val_truth,
                                 val_mean=val_mean,
                                 val_std=val_std)
                        # save dpm model
                        with open(
                                os.path.join(
                                    outdir,
                                    'model_{}.pkl'.format(real_period)),
                                'wb') as pf:
                            pickle.dump(dpgmm, pf)

            if period < max_periods - 1:
                ''' Backward steps'''
                # require: train_labels_onehot:NxK; dp_mus: KxD; dp_Gammas: DxDxK
                train_datum.reset(
                )  # reset data from the original order to match predicted label
                period_cluster_loss, period_det_loss = 0., 0.
                for step in range(batch_num):
                    real_step = step + real_period * batch_num
                    train_x, train_y = train_datum.nextBatch(FLAGS.batch_size)
                    start, end = step * FLAGS.batch_size, (
                        step + 1) * FLAGS.batch_size
                    step_labels_onehot = train_labels_onehot[start:end]
                    # Select each sample's cluster mean/precision via one-hot.
                    cls_mu = np.matmul(step_labels_onehot,
                                       dp_mus)  # NxK x KxD=> NxD
                    cls_Gamma = np.matmul(
                        dp_Gammas,
                        step_labels_onehot.T).T  # DxDxK KxN => DxDxN => NxDxD
                    _, dlossv, dtlossv = sess.run(
                        [train_opt, cluster_loss, det_loss],
                        feed_dict={
                            imageip: train_x,
                            cls_mus: cls_mu,
                            cls_Gammas: cls_Gamma
                        })

                    # accumulate per-period average losses
                    period_cluster_loss += dlossv / batch_num
                    period_det_loss += dtlossv / batch_num
                ## shuffle train data for next batch
                train_datum.shuffle(period)
                val_data, val_truth = np.copy(train_datum.data), np.copy(
                    train_datum.label)
                ## record the period loss
                period_tot_loss = period_cluster_loss + Detcoef * period_det_loss
                period_det_l.append(period_det_loss)
                period_cluster_l.append(period_cluster_loss)
                period_tot_l.append(period_tot_loss)
def test():
    """Run inference over the test split and collect detection/mask results.

    Builds the network in inference mode (no losses), restores a checkpoint,
    then iterates over the test images, drawing predicted and ground-truth
    boxes/masks and accumulating results via _collectData.
    """

    ## data
    image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id = \
        datasets.get_dataset(FLAGS.dataset_name, 
                             FLAGS.dataset_split_name_test, 
                             FLAGS.dataset_dir, 
                             FLAGS.im_batch,
                             is_training=False)

    im_shape = tf.shape(image)
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    # NOTE(review): is_training=True for the backbone during testing looks
    # inconsistent with is_training=False for the heads below — confirm
    # whether this is intentional (e.g. for batch-norm behavior).
    logits, end_points, pyramid_map = network.get_network(FLAGS.network, image,
            weight_decay=0.0, batch_norm_decay=0.0, is_training=True)
    outputs = pyramid_network.build(end_points, im_shape[1], im_shape[2], pyramid_map,
            num_classes=81,
            base_anchors=3,
            is_training=False,
            gt_boxes=None, gt_masks=None, loss_weights=[0.0, 0.0, 0.0, 0.0, 0.0])

    input_image = end_points['input']

    # Inference outputs: ROIs, per-ROI masks, class ids, and class scores.
    testing_mask_rois = outputs['mask_ordered_rois']
    testing_mask_final_mask = outputs['mask_final_mask']
    testing_mask_final_clses = outputs['mask_final_clses']
    testing_mask_final_scores = outputs['mask_final_scores']

    ## solvers
    global_step = slim.create_global_step()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    # Variables are restored from checkpoint below, so no initializer is run.

    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)
    tf.train.start_queue_runners(sess=sess)

    ## main loop over the test set
    # NOTE(review): 82783 is the COCO train2014 image count; the commented
    # 40503 is val2014 — confirm the count matches the split being tested.
    for step in range(82783):#range(40503):
        
        start_time = time.time()

        # The unpacking order must match the fetch list below exactly.
        image_id_str, original_image_heightnp, original_image_widthnp, image_heightnp, image_widthnp, \
        gt_boxesnp, gt_masksnp,\
        input_imagenp,\
        testing_mask_roisnp, testing_mask_final_masknp, testing_mask_final_clsesnp, testing_mask_final_scoresnp = \
                     sess.run([image_id] + [original_image_height] + [original_image_width] + [image_height] + [image_width] +\
                              [gt_boxes] + [gt_masks] +\
                              [input_image] + \
                              [testing_mask_rois] + [testing_mask_final_mask] + [testing_mask_final_clses] + [testing_mask_final_scores])

        duration_time = time.time() - start_time
        if step % 1 == 0: 
            print ( """iter %d: image-id:%07d, time:%.3f(sec), """
                    """instances: %d, """
                    
                   % (step, image_id_str, duration_time, 
                      gt_boxesnp.shape[0]))

        if step % 1 == 0: 
            # Visualize predicted and ground-truth boxes/masks side by side.
            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='test_est', 
                      bbox=testing_mask_roisnp, 
                      label=testing_mask_final_clsesnp, 
                      prob=testing_mask_final_scoresnp,
                      mask=testing_mask_final_masknp,
                      vis_th=0.5)

            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='test_gt', 
                      bbox=gt_boxesnp[:,0:4], 
                      label=gt_boxesnp[:,4].astype(np.int32), 
                      prob=np.ones((gt_boxesnp.shape[0],81), dtype=np.float32),)

            print ("predict")
            print (cat_id_to_cls_name(testing_mask_final_clsesnp))
            print (np.max(np.array(testing_mask_final_scoresnp),axis=1))

        # Accumulate this image's detections for the final result dump.
        _collectData(image_id_str, testing_mask_final_clsesnp, testing_mask_roisnp, testing_mask_final_scoresnp, original_image_heightnp, original_image_widthnp, image_heightnp, image_widthnp, testing_mask_final_masknp)
예제 #5
0
def train():
    """Build the dataset pipeline, FPN network and solver, then train.

    Runs FLAGS.max_iters optimization steps, printing losses every step,
    writing summaries every 100 steps and saving a checkpoint every 1000
    steps (and on the final step).
    """

    ## data: one example's tensors from the dataset reader
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    ## network: backbone + feature-pyramid heads with training losses
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network, image, weight_decay=FLAGS.weight_decay)
    outputs = pyramid_network.build(end_points,
                                    ih,
                                    iw,
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.1, 0.2, 1.0, 0.1, 0.5])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory per run.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained/checkpointed weights
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    # BUG FIX: keep the queue-runner threads. The original discarded the
    # return value of start_queue_runners, so `coord.join(threads)` at the
    # bottom of the loop raised a NameError.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # One optimization step; also fetch losses/batch stats for logging.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info)

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            # Abort on divergence. NOTE(review): a bare `raise` with no
            # active exception surfaces as a RuntimeError; it does stop
            # training, but a `raise RuntimeError(...)` would be clearer.
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                raise

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        if (step % 1000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.dataset_name + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Example #6
0
def train():
    """Build the shuffled data queue, FPN network and solver, then train.

    Reads images and ground truth from the configured dataset, shuffles
    them through a RandomShuffleQueue, builds the pyramid network with
    its losses, and loops for FLAGS.max_iters steps: logs losses every
    step, writes summaries every 100 steps and checkpoints every 10000
    steps (and on the final step).
    """

    ## data: one example's tensors from the dataset reader
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Shuffle examples through a queue fed by 4 enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=32,
        min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()
    im_shape = tf.shape(image)
    #image = tf.Print(image, [im_shape], message = 'shape', summarize = 4)
    # Restore the static channel dimension (3) lost by the queue round-trip.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network: backbone + feature-pyramid heads with training losses
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network, image, weight_decay=FLAGS.weight_decay)
    outputs = pyramid_network.build(end_points,
                                    ih,
                                    iw,
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    # Debug tensors registered by the ROI-cropping code (fetched from the
    # collections here but not used below).
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory per run.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained/checkpointed weights
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    # print (tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS))
    # Start the data-queue runner threads registered above; keep the thread
    # handles so the coordinator can join them on shutdown.
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # One optimization step; also fetch losses/batch stats for logging.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info )

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            # Abort on divergence; bare `raise` with no active exception
            # surfaces as a RuntimeError and stops training.
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                raise

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                FLAGS.train_dir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Example #7
0
def test():
    """Run mask inference on the test split and visualize the detections.

    Builds the network with is_training=False and zeroed loss weights,
    restores weights, then for FLAGS.max_iters steps fetches the final
    mask-head outputs for one image and draws the predicted boxes/masks
    with draw_bbox.
    """

    ## data: one test example at a time (no shuffling queue)
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=False)

    im_shape = tf.shape(image)
    # Restore the static channel dimension (3).
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network: inference mode, losses disabled via zero weights
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        image,
        weight_decay=FLAGS.weight_decay,
        is_training=False)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=15,
                                    is_training=False,
                                    gt_boxes=None,
                                    gt_masks=None,
                                    loss_weights=[0.0, 0.0, 0.0, 0.0, 0.0])

    input_image = end_points['input']

    # Final mask-head outputs used for visualization below.
    testing_mask_rois = outputs['mask_ordered_rois']
    testing_mask_final_mask = outputs['mask_final_mask']
    testing_mask_final_clses = outputs['mask_final_clses']
    testing_mask_final_scores = outputs['mask_final_scores']

    #############################
    # Debug tensors exposed by the network; fetched in the loop but their
    # numpy values are not otherwise used here.
    tmp_0 = outputs['tmp_0']
    tmp_1 = outputs['tmp_1']
    tmp_2 = outputs['tmp_2']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    tmp_5 = outputs['tmp_5']
    ############################

    ## solvers (no update op: inference only)
    global_step = slim.create_global_step()
    #update_op = solve(global_step)

    # Debug tensors registered by the ROI-cropping code (unused below).
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory per run.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore trained weights
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    # print (tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS))
    # Start the dataset reader's queue-runner threads.
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        # Forward pass: fetch image id, ground truth, the input image and
        # all mask-head / debug tensors for this example.
        img_id_str, \
        gt_boxesnp, \
        input_imagenp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np, tmp_5np, \
        testing_mask_roisnp, testing_mask_final_masknp, testing_mask_final_clsesnp, testing_mask_final_scoresnp = \
                     sess.run([img_id] + \
                              [gt_boxes] + \
                              [input_image] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4] + [tmp_5] + \
                              [testing_mask_rois] + [testing_mask_final_mask] + [testing_mask_final_clses] + [testing_mask_final_scores])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print("""iter %d: image-id:%07d, time:%.3f(sec), """
                  """instances: %d, """ %
                  (step, img_id_str, duration_time, gt_boxesnp.shape[0]))

        if step % 1 == 0:
            # Undo the [-1, 1] input normalization before drawing.
            draw_bbox(
                step,
                np.uint8((np.array(input_imagenp[0]) / 2.0 + 0.5) * 255.0),
                name='test_est',
                bbox=testing_mask_roisnp,
                label=testing_mask_final_clsesnp,
                prob=testing_mask_final_scoresnp,
                mask=testing_mask_final_masknp,
            )
Example #8
0
def train():
    """Train the pyramid network on climate-data TFRecords.

    Reads examples with `read()` from pre-built TFRecord files, shuffles
    them through a RandomShuffleQueue, builds the network with 3 classes
    on single-channel input, and runs the optimization loop: per-step
    logging, summaries every 100 steps, checkpoints every 10000 steps.

    CLEANUP: removed the `import IPython; IPython.embed()` debugger stops
    (one after pyramid_network.build, one inside the training loop) that
    halted every iteration in an interactive shell, plus dead code (the
    immediately-overwritten tmp_3/tmp_4 assignments, a discarded
    `tf.Print`, and large commented-out data-loading experiments).
    """

    ## data: examples come from pre-built TFRecords on local disk
    # (earlier experiments loaded .npy boxes/masks directly; see history)
    tfrecords_filename = glob.glob(
        "/home/mudigonda/files_for_first_maskrcnn_test/records/*")
    print(tfrecords_filename)
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        read(tfrecords_filename)

    print(image.shape)
    print(gt_boxes.shape)
    print(gt_masks.shape)

    # Shuffle examples through a small queue fed by 4 enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=5,
        min_after_dequeue=2,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))

    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)

    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()
    im_shape = tf.shape(image)
    # Single-channel input (climate fields), hence the trailing 1.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 1))

    ## network: backbone + feature-pyramid heads, 3 classes
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        image,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=3,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    print('Pyramid_network successfully built')
    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']
    gt = outputs['gt']

    #############################
    # Debug fetches: tmp_0..tmp_2 mirror the loss list as placeholders;
    # tmp_3/tmp_4 carry the network's debug tensors used for logging below.
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    # Debug tensors registered by the ROI-cropping code (unused below).
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory per run.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore (intentionally disabled: training from scratch)
    #restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    # Start the data-queue runner threads registered above.
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)
    print("Right before started training in the main loop")

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        print("B4 sess.run in main train loop")
        # One optimization step; also fetch losses, final boxes and the
        # debug tensors for logging.
        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch, \
        input_imagenp, final_boxnp, final_clsnp, final_probnp, final_gt_clsnp, gtnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np= \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info +
                              [input_image] + [final_box] + [final_cls] + [final_prob] + [final_gt_cls] + [gt] + [tmp_0] + [tmp_1] + [tmp_2] + [tmp_3] + [tmp_4])

        print("After sess.run in main train loop")
        duration_time = time.time() - start_time
        if step % 1 == 0:
            print(
                """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                """instances: %d, """
                """batch:(%d|%d, %d|%d, %d|%d)""" %
                (step, img_id_str, duration_time, reg_lossnp, tot_loss,
                 rpn_box_loss, rpn_cls_loss, refined_box_loss,
                 refined_cls_loss, mask_loss, gt_boxesnp.shape[0],
                 rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch,
                 mask_batch_pos, mask_batch))

            # draw_bbox(step,
            #           np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0),
            #           name='est',
            #           bbox=final_boxnp,
            #           label=final_clsnp,
            #           prob=final_probnp,
            #           gt_label=np.argmax(np.asarray(final_gt_clsnp),axis=1),
            #           )

            # draw_bbox(step,
            #           np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0),
            #           name='gt',
            #           bbox=gtnp[:,0:4],
            #           label=np.asarray(gtnp[:,4], dtype=np.uint8),
            #           )

            # Log class names decoded from the debug tensors.
            print("labels")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.asarray(tmp_3np), axis=1)))[1:])
            print("classes")
            print(
                cat_id_to_cls_name(
                    np.unique(np.argmax(np.array(tmp_4np), axis=1))))

            # Abort on divergence; bare `raise` with no active exception
            # surfaces as a RuntimeError and stops training.
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                raise

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(
                FLAGS.train_dir,
                FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
Example #9
0
def train():
    """Train the FPN/Mask-RCNN-style network with per-level debug logging.

    Builds a shuffling input queue, the backbone + pyramid heads and the
    solver, finalizes the graph (to catch accidental op creation in the
    loop), then trains for FLAGS.max_iters steps: per-step loss/class
    logging, bbox visualizations every 50 steps, summaries every 100
    steps and checkpoints every 500 steps.
    """
    ## data: one example's tensors (original + resized dims) per read
    image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id = \
        datasets.get_dataset(FLAGS.dataset_name, 
                             FLAGS.dataset_split_name, 
                             FLAGS.dataset_dir, 
                             FLAGS.im_batch,
                             is_training=True)

    ## queuing data: shuffle examples through a queue fed by 4 threads
    data_queue = tf.RandomShuffleQueue(capacity=32, min_after_dequeue=16,
            dtypes=(
                image.dtype, original_image_height.dtype, original_image_width.dtype, image_height.dtype, image_width.dtype,
                gt_boxes.dtype, gt_masks.dtype, 
                num_instances.dtype, image_id.dtype)) 
    enqueue_op = data_queue.enqueue((image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, original_image_height, original_image_width, image_height, image_width, gt_boxes, gt_masks, num_instances, image_id) =  data_queue.dequeue()

    im_shape = tf.shape(image)
    # Restore the static channel dimension (3) lost by the queue round-trip.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network: backbone + feature-pyramid heads with training losses
    logits, end_points, pyramid_map = network.get_network(FLAGS.network, image,
            weight_decay=FLAGS.weight_decay, batch_norm_decay=FLAGS.batch_norm_decay, is_training=True)
    outputs = pyramid_network.build(end_points, image_height, image_width, pyramid_map,
            num_classes=81,
            base_anchors=3,#9#15
            is_training=True,
            gt_boxes=gt_boxes, gt_masks=gt_masks,
            loss_weights=[1.0, 1.0, 10.0, 1.0, 10.0])
            # loss_weights=[10.0, 1.0, 0.0, 0.0, 0.0])
            # loss_weights=[100.0, 100.0, 1000.0, 10.0, 100.0])
            # loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])
            # loss_weights=[0.1, 0.01, 10.0, 0.1, 1.0])

    total_loss = outputs['total_loss']
    losses  = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    input_image = end_points['input']

    # Intermediate RCNN / mask-head tensors fetched each step for logging
    # and visualization.
    training_rcnn_rois                  = outputs['training_rcnn_rois']
    training_rcnn_clses                 = outputs['training_rcnn_clses']
    training_rcnn_clses_target          = outputs['training_rcnn_clses_target'] 
    training_rcnn_scores                = outputs['training_rcnn_scores']
    training_mask_rois                  = outputs['training_mask_rois']
    training_mask_clses_target          = outputs['training_mask_clses_target']
    training_mask_final_mask            = outputs['training_mask_final_mask']
    training_mask_final_mask_target     = outputs['training_mask_final_mask_target']
    # RPN shape tensors per pyramid level P2..P5 (debug logging only).
    tmp_0 = outputs['rpn']['P2']['shape']
    tmp_1 = outputs['rpn']['P3']['shape']
    tmp_2 = outputs['rpn']['P4']['shape']
    tmp_3 = outputs['rpn']['P5']['shape']

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    # Debug tensors registered by the ROI-cropping code (unused below).
    cropped_rois = tf.get_collection('__CROPPED__')[0]
    transposed = tf.get_collection('__TRANSPOSED__')[0]
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    #gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    #sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    # Time-stamped log directory per run.
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore pretrained/checkpointed weights
    restore(sess)

    ## coord settings: start the data-queue runner threads
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))
    tf.train.start_queue_runners(sess=sess, coord=coord)

    ## saver init
    saver = tf.train.Saver(max_to_keep=20)

    ## finalize the graph for checking memory leak
    sess.graph.finalize()

    ## main loop
    for step in range(FLAGS.max_iters):
        
        start_time = time.time()

        # One optimization step; also fetch losses, per-level shapes and
        # visualization tensors in a single run call.
        s_, tot_loss, reg_lossnp, image_id_str, \
        rpn_box_loss, rpn_cls_loss, rcnn_box_loss, rcnn_cls_loss, mask_loss, \
        gt_boxesnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, \
        rpn_batch_pos, rpn_batch, rcnn_batch_pos, rcnn_batch, mask_batch_pos, mask_batch, \
        input_imagenp, \
        training_rcnn_roisnp, training_rcnn_clsesnp, training_rcnn_clses_targetnp, training_rcnn_scoresnp, training_mask_roisnp, training_mask_clses_targetnp, training_mask_final_masknp, training_mask_final_mask_targetnp  = \
                     sess.run([update_op, total_loss, regular_loss, image_id] + 
                              losses + 
                              [gt_boxes] + [tmp_0] + [tmp_1] + [tmp_2] +[tmp_3] +
                              batch_info + 
                              [input_image] + 
                              [training_rcnn_rois] + [training_rcnn_clses] + [training_rcnn_clses_target] + [training_rcnn_scores] + [training_mask_rois] + [training_mask_clses_target] + [training_mask_final_mask] + [training_mask_final_mask_target])

        duration_time = time.time() - start_time
        if step % 1 == 0: 
            LOG ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                    """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                    """instances: %d, """
                    """batch:(%d|%d, %d|%d, %d|%d)""" 
                   % (step, image_id_str, duration_time, reg_lossnp, 
                      tot_loss, rpn_box_loss, rpn_cls_loss, rcnn_box_loss, rcnn_cls_loss, mask_loss,
                      gt_boxesnp.shape[0], 
                      rpn_batch_pos, rpn_batch, rcnn_batch_pos, rcnn_batch, mask_batch_pos, mask_batch))

            # Log target vs. predicted class names and the per-level
            # RPN feature-map shapes.
            LOG ("target")
            LOG (cat_id_to_cls_name(np.unique(np.argmax(np.asarray(training_rcnn_clses_targetnp),axis=1))))
            LOG ("predict")
            LOG (cat_id_to_cls_name(np.unique(np.argmax(np.array(training_rcnn_clsesnp),axis=1))))
            LOG (tmp_0np)
            LOG (tmp_1np)
            LOG (tmp_2np)
            LOG (tmp_3np)

        if step % 50 == 0: 
            # Visualize estimated proposals (labels from predicted scores)
            # and the same ROIs labelled with their targets. The [-1, 1]
            # input normalization is undone before drawing.
            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='train_est', 
                      bbox=training_rcnn_roisnp, 
                      label=np.argmax(np.array(training_rcnn_scoresnp),axis=1), 
                      prob=training_rcnn_scoresnp,
                      # bbox=training_mask_roisnp, 
                      # label=training_mask_clses_targetnp, 
                      # prob=np.zeros((training_mask_final_masknp.shape[0],81), dtype=np.float32)+1.0,
                      # mask=training_mask_final_masknp,
                      vis_all=True)

            draw_bbox(step, 
                      np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0), 
                      name='train_gt', 
                      bbox=training_rcnn_roisnp, 
                      label=np.argmax(np.array(training_rcnn_clses_targetnp),axis=1), 
                      prob=np.zeros((training_rcnn_clsesnp.shape[0],81), dtype=np.float32)+1.0,
                      # bbox=training_mask_roisnp, 
                      # label=training_mask_clses_targetnp, 
                      # prob=np.zeros((training_mask_final_masknp.shape[0],81), dtype=np.float32)+1.0,
                      # mask=training_mask_final_mask_targetnp,
                      vis_all=True)
            
            # Abort on divergence; bare `raise` with no active exception
            # surfaces as a RuntimeError and stops training.
            if np.isnan(tot_loss) or np.isinf(tot_loss):
                LOG (gt_boxesnp)
                raise
          
        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 500 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir, 
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
        # Per-step garbage collection to cap Python-side memory growth
        # (the graph itself is finalized above).
        gc.collect()
def forward_test_single_image():
    """Run the trained detector over every test image in `testdata_base_dir`.

    Builds the network once against a float placeholder input, restores the
    trained checkpoint, then loops over all files matching
    ``*.<file_pattern>`` and saves box / mask visualisations into
    `save_dir_bbox` and `save_dir_mask`.
    """
    if not os.path.exists(save_dir_bbox):
        os.makedirs(save_dir_bbox)
    if not os.path.exists(save_dir_mask):
        os.makedirs(save_dir_mask)

    file_pathname = testdata_base_dir + '*.' + file_pattern
    image_paths = glob.glob(file_pathname)  # full paths, with extension
    # Derive the bare names from the paths we already collected instead of
    # calling glob.glob() a second time (a second directory scan could
    # return a different file set).  [len(prefix):-4] drops the directory
    # prefix and the 4-character extension ('.jpg' / '.png').
    image_names = [p[len(testdata_base_dir):-4] for p in image_paths]

    print(image_paths)
    print(image_names)

    # Batch of one RGB image with dynamic height/width.
    TEST_image = tf.placeholder(tf.float32, shape=[1, None, None, 3])
    im_shape = tf.shape(TEST_image)

    ## network
    # NOTE(review): is_training=True at inference time — confirm this is
    # intended (it changes batch-norm / dropout behaviour).
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        TEST_image,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=81,
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=None,
                                    gt_masks=None,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    input_image = end_points['input']
    print("input_image.shape", input_image.shape)
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_mask = outputs['mask']['mask']

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)

    ## restore trained model
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(sess, coord=coord, daemon=True, start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)

    for image_path, image_name in zip(image_paths, image_names):
        ## read test image
        test_img = Image.open(image_path)
        test_img = test_img.convert("RGB")
        test_img_i = np.array(test_img, dtype=np.uint8)  # kept for drawing
        test_img = np.array(test_img, dtype=np.float32)
        test_img = test_img[np.newaxis, ...]             # add batch dim
        print("test_img.shape", test_img.shape)
        print("test_img_i.shape", test_img_i.shape)

        # start_time = time.time()

        input_imagenp, final_boxnp, final_clsnp, final_probnp, final_masknp = \
            sess.run([input_image, final_box, final_cls, final_prob,
                      final_mask],
                     feed_dict={TEST_image: test_img})

        # duration_time = time.time() - start_time

        draw_bbox(
            test_img_i,
            type='est',
            bbox=final_boxnp,
            label=final_clsnp,
            prob=final_probnp,
            gt_label=None,
            save_dir=save_dir_bbox,
            save_name=image_name,
        )

        print("final_masknp.shape\n", final_masknp.shape)

        draw_mask(
            test_img_i,
            type='est',
            bbox=final_boxnp,
            mask=final_masknp,
            label=final_clsnp,
            prob=final_probnp,
            gt_label=None,
            save_dir=save_dir_mask,
            save_name=image_name,
        )

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
예제 #11
0
def train():
    """The main function that runs training.

    Reads examples from tfrecords, feeds them through a shuffle queue into
    the FPN network, and runs the optimisation loop with periodic console
    logging, optional drawing, summaries, and checkpointing.
    """
    ## data
    # Tensors read from the tfrecords dataset.
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Shuffle single examples through a queue; four enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=32, min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype,
                gt_boxes.dtype, gt_masks.dtype,
                num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id) = \
        data_queue.dequeue()
    im_shape = tf.shape(image)
    # Restore the static channel dimension lost through the queue.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network, image,
        weight_decay=FLAGS.weight_decay, is_training=True)
    outputs = pyramid_network.build(
        end_points, im_shape[1], im_shape[2], pyramid_map,
        num_classes=2, base_anchors=9, is_training=True,
        gt_boxes=gt_boxes, gt_masks=gt_masks,
        loss_weights=[1.0, 1.0, 1.0, 1.0, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']

    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']

    # INCLUDE_MASK lets the RPN / box head be trained first without the mask
    # branch so the network proposes better boxes before the mask branch is
    # attached.  With too many proposals (~120) the mask tensor, shaped
    # [120, 112, 112, 7], can otherwise cause an out-of-memory error.
    print("FLAGS INCLUDE MASK IS ", FLAGS.INCLUDE_MASK)
    if FLAGS.INCLUDE_MASK:
        final_mask = outputs['mask']['final_mask_for_drawing']
    gt = outputs['gt']

    #############################
    # Debug fetches; tmp_0..tmp_2 deliberately alias outputs['losses'].
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()
        if FLAGS.INCLUDE_MASK:
            s_, tot_loss, reg_lossnp, img_id_str, \
            rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
            gt_boxesnp, input_imagenp, final_boxnp, final_clsnp, final_probnp, \
            final_gt_clsnp, gtnp, tmp_0np, tmp_1np, tmp_2np, tmp_3np, tmp_4np, \
            final_masknp, gt_masksnp = \
                sess.run([update_op, total_loss, regular_loss, img_id] +
                         losses +
                         [gt_boxes, input_image, final_box, final_cls,
                          final_prob, final_gt_cls, gt, tmp_0, tmp_1, tmp_2,
                          tmp_3, tmp_4, final_mask, gt_masks])
        else:
            s_, tot_loss, reg_lossnp, img_id_str,\
            rpn_box_loss, rpn_cls_loss,refined_box_loss, refined_cls_loss,\
            gt_boxesnp, input_imagenp, final_boxnp,\
            final_clsnp, final_probnp, final_gt_clsnp, gtnp=\
                sess.run([update_op, total_loss, regular_loss, img_id] +
                         losses +
                         [gt_boxes, input_image, final_box,
                          final_cls, final_prob, final_gt_cls, gt])

        duration_time = time.time() - start_time
        if step % 1 == 0:
            if FLAGS.INCLUDE_MASK:
                print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.9f, """
                        """total-loss %.10f(%.4f, %.4f, %.6f, %.4f,%.5f), """ #%.4f
                        """instances: %d, proposals: %d """
                       % (step, img_id_str, duration_time, reg_lossnp,
                          tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss,
                          gt_boxesnp.shape[0],len(final_boxnp)))
            else:
                print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.9f, """
                        """total-loss %.4f(%.4f, %.4f, %.6f, %.4f), """ #%.4f
                        """instances: %d, proposals: %d """
                       % (step, img_id_str, duration_time, reg_lossnp,
                          tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, #mask_loss,
                          gt_boxesnp.shape[0],len(final_boxnp)))

            # Only draw when the '--draw' flag was passed; guard the index
            # so running the script without extra arguments does not raise
            # IndexError.
            if len(sys.argv) > 1 and sys.argv[1] == '--draw':
                if FLAGS.INCLUDE_MASK:
                    input_imagenp = np.uint8((np.array(input_imagenp[0])/2.0+0.5)*255.0)
                    final_gt_clsnp = np.argmax(np.asarray(final_gt_clsnp),axis=1)
                    draw_human_body_parts(step, input_imagenp,  bbox=final_boxnp, label=final_clsnp, gt_label=final_gt_clsnp, prob=final_probnp,final_mask=final_masknp)

                else:
                    save(step,input_imagenp,final_boxnp,gt_boxesnp,final_clsnp,final_probnp,final_gt_clsnp,None,None)

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                # A bare `raise` has no active exception here; raise an
                # explicit, self-describing error instead.
                raise FloatingPointError(
                    'NaN/Inf total loss at step %d' % step)

        if step % 1000 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()

        if (step % 1000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
예제 #12
0
# Queue-based input pipeline: shuffle examples read from the dataset
# before feeding them to the network.
data_queue = tf.RandomShuffleQueue(capacity=32,
                                   min_after_dequeue=16,
                                   dtypes=(image.dtype, ih.dtype, iw.dtype,
                                           gt_boxes.dtype, gt_masks.dtype,
                                           num_instances.dtype, img_id.dtype))
enqueue_op = data_queue.enqueue(
    (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
# Four enqueue threads keep the queue filled.
data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
(image, ih, iw, gt_boxes, gt_masks, num_instances,
 img_id) = data_queue.dequeue()
im_shape = tf.shape(image)
# Restore the static channel dimension (3) lost through the queue.
image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

# network
logits, end_points, pyramid_map = network.get_network(
    FLAGS.network, image, weight_decay=FLAGS.weight_decay, is_training=True)
# num_classes=81: presumably COCO's 80 classes + background — confirm.
outputs = pyramid_network.build(end_points,
                                im_shape[1],
                                im_shape[2],
                                pyramid_map,
                                num_classes=81,
                                base_anchors=9,
                                is_training=True,
                                gt_boxes=gt_boxes,
                                gt_masks=gt_masks,
                                loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

total_loss = outputs['total_loss']
losses = outputs['losses']
batch_info = outputs['batch_info']
# Sum of all weight-decay terms registered by the network layers.
regular_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
예제 #13
0
def train():
    """The main function that runs training.

    Reads examples from tfrecords, feeds them through a shuffle queue into
    the FPN network, and runs the optimisation loop with per-step console
    logging, periodic summaries, and checkpointing.
    """

    ## data
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Shuffle single examples through a queue; four enqueue threads.
    data_queue = tf.RandomShuffleQueue(
        capacity=32, min_after_dequeue=16,
        dtypes=(image.dtype, ih.dtype, iw.dtype,
                gt_boxes.dtype, gt_masks.dtype,
                num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 4)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id) = \
        data_queue.dequeue()
    im_shape = tf.shape(image)
    # Restore the static channel dimension lost through the queue.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network, image,
        weight_decay=FLAGS.weight_decay, is_training=True)
    outputs = pyramid_network.build(
        end_points, im_shape[1], im_shape[2], pyramid_map,
        num_classes=81,
        base_anchors=9,
        is_training=True,
        gt_boxes=gt_boxes, gt_masks=gt_masks,
        loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    ## solvers
    global_step = slim.create_global_step()
    update_op = solve(global_step)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
            )
    sess.run(init_op)

    summary_op = tf.summary.merge_all()
    logdir = os.path.join(FLAGS.train_dir, strftime('%Y%m%d%H%M%S', gmtime()))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    summary_writer = tf.summary.FileWriter(logdir, graph=sess.graph)

    ## restore
    restore(sess)

    ## main loop
    coord = tf.train.Coordinator()
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

    tf.train.start_queue_runners(sess=sess, coord=coord)
    saver = tf.train.Saver(max_to_keep=20)

    for step in range(FLAGS.max_iters):

        start_time = time.time()

        s_, tot_loss, reg_lossnp, img_id_str, \
        rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss, \
        gt_boxesnp, \
        rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch = \
                     sess.run([update_op, total_loss, regular_loss, img_id] +
                              losses +
                              [gt_boxes] +
                              batch_info)

        duration_time = time.time() - start_time
        if step % 1 == 0:
            print ( """iter %d: image-id:%07d, time:%.3f(sec), regular_loss: %.6f, """
                    """total-loss %.4f(%.4f, %.4f, %.6f, %.4f, %.4f), """
                    """instances: %d, """
                    """batch:(%d|%d, %d|%d, %d|%d)"""
                   % (step, img_id_str, duration_time, reg_lossnp,
                      tot_loss, rpn_box_loss, rpn_cls_loss, refined_box_loss, refined_cls_loss, mask_loss,
                      gt_boxesnp.shape[0],
                      rpn_batch_pos, rpn_batch, refine_batch_pos, refine_batch, mask_batch_pos, mask_batch))

            if np.isnan(tot_loss) or np.isinf(tot_loss):
                print(gt_boxesnp)
                # A bare `raise` has no active exception here; raise an
                # explicit, self-describing error instead.
                raise FloatingPointError(
                    'NaN/Inf total loss at step %d' % step)

        if step % 100 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
            # Flush so summaries appear promptly in TensorBoard.
            summary_writer.flush()

        if (step % 10000 == 0 or step + 1 == FLAGS.max_iters) and step != 0:
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.dataset_name + '_' + FLAGS.network + '_model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

        if coord.should_stop():
            coord.request_stop()
            coord.join(threads)
예제 #14
0
def build_model():
    """Build the data pipeline, the FPN network, and its losses.

    Returns the network outputs plus the individual tensors the training
    loop fetches (losses, drawing tensors, and the debug `tmp_*` fetches).
    """

    ## data
    image, ih, iw, gt_boxes, gt_masks, num_instances, img_id = \
        datasets.get_dataset(FLAGS.dataset_name,
                             FLAGS.dataset_split_name,
                             FLAGS.dataset_dir,
                             FLAGS.im_batch,
                             is_training=True)

    # Tiny shuffle queue (capacity 2) with a single enqueue thread.
    data_queue = tf.RandomShuffleQueue(
        capacity=2,
        min_after_dequeue=1,
        dtypes=(image.dtype, ih.dtype, iw.dtype, gt_boxes.dtype,
                gt_masks.dtype, num_instances.dtype, img_id.dtype))
    enqueue_op = data_queue.enqueue(
        (image, ih, iw, gt_boxes, gt_masks, num_instances, img_id))
    data_queue_runner = tf.train.QueueRunner(data_queue, [enqueue_op] * 1)
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, data_queue_runner)
    (image, ih, iw, gt_boxes, gt_masks, num_instances,
     img_id) = data_queue.dequeue()

    im_shape = tf.shape(image)
    # Restore the static channel dimension lost through the queue.
    image = tf.reshape(image, (im_shape[0], im_shape[1], im_shape[2], 3))

    ## network
    logits, end_points, pyramid_map = network.get_network(
        FLAGS.network,
        image,
        weight_decay=FLAGS.weight_decay,
        is_training=True)
    outputs = pyramid_network.build(end_points,
                                    im_shape[1],
                                    im_shape[2],
                                    pyramid_map,
                                    num_classes=len(cls_name),
                                    base_anchors=9,
                                    is_training=True,
                                    gt_boxes=gt_boxes,
                                    gt_masks=gt_masks,
                                    loss_weights=[0.2, 0.2, 1.0, 0.2, 1.0])

    total_loss = outputs['total_loss']
    losses = outputs['losses']
    batch_info = outputs['batch_info']
    # Sum of all weight-decay terms registered by the network layers.
    regular_loss = tf.add_n(
        tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    input_image = end_points['input']
    final_box = outputs['final_boxes']['box']
    final_cls = outputs['final_boxes']['cls']
    final_prob = outputs['final_boxes']['prob']
    final_gt_cls = outputs['final_boxes']['gt_cls']
    gt = outputs['gt']

    #############################
    # Debug fetches; tmp_0..tmp_2 deliberately alias outputs['losses'],
    # tmp_3/tmp_4 are the network's own debug tensors.
    tmp_0 = outputs['losses']
    tmp_1 = outputs['losses']
    tmp_2 = outputs['losses']
    tmp_3 = outputs['tmp_3']
    tmp_4 = outputs['tmp_4']
    ############################

    return outputs, gt_masks, total_loss, regular_loss, img_id, losses, gt_boxes, batch_info, input_image, \
            final_box, final_cls, final_prob, final_gt_cls, gt, tmp_0, tmp_1, tmp_2, tmp_3, tmp_4