def train(max_iter, snapshot, dataset, setname, mu, lr, bs, tfmodel_folder,
          conv5, model_name, stop_iter, pre_emb=False):
    """Train a segmentation model initialised from DeepLab-ResNet weights.

    Each iteration assembles ``bs`` samples from a ``data_reader.DataReader``
    over ``./<dataset>/<setname>_batch/``, runs one optimisation step and
    tracks exponential moving averages (decay 0.99) of loss and accuracy.
    Checkpoints are written to ``<tfmodel_folder>/<dataset>_iter_<n>.tfmodel``.

    Args:
        max_iter: Maximum number of training iterations.
        snapshot: Save a checkpoint every ``snapshot`` iterations.
        dataset: Dataset name; also selects vocabulary size / embedding name.
        setname: Split name used to locate the batch folder.
        mu: Value subtracted from every image after its channel order is
            reversed.
        lr: Starting learning rate passed to the model.
        bs: Number of samples per training batch.
        tfmodel_folder: Checkpoint directory (created if missing).
        conv5: Forwarded to ``get_segmentation_model``.
        model_name: Model identifier forwarded to ``get_segmentation_model``.
        stop_iter: Stop (and save a checkpoint) once this many iterations
            have completed.
        pre_emb: If True, build the model with pretrained word embeddings.
    """
    iters_per_log = 100
    data_folder = './' + dataset + '/' + setname + '_batch/'
    data_prefix = dataset + '_' + setname
    snapshot_file = os.path.join(tfmodel_folder, dataset + '_iter_%d.tfmodel')
    if not os.path.isdir(tfmodel_folder):
        os.makedirs(tfmodel_folder)

    cls_loss_avg = 0
    avg_accuracy_all, avg_accuracy_pos, avg_accuracy_neg = 0, 0, 0
    decay = 0.99
    vocab_size = 8803 if dataset == 'referit' else 12112
    emb_name = 'referit' if dataset == 'referit' else 'Gref'

    model_kwargs = dict(mode='train', vocab_size=vocab_size, start_lr=lr,
                        batch_size=bs, conv5=conv5)
    if pre_emb:
        print("Use pretrained Embeddings.")
        model_kwargs['emb_name'] = emb_name
    model = get_segmentation_model(model_name, **model_kwargs)

    weights = './data/weights/deeplab_resnet_init.ckpt'
    print("Loading pretrained weights from {}".format(weights))
    # Only variables whose names start with res/bn/conv1 are restored from the
    # init checkpoint (presumably the ResNet backbone); the rest keep their
    # freshly initialised values.
    load_var = {var.op.name: var for var in tf.global_variables()
                if var.name.startswith(('res', 'bn', 'conv1'))}

    snapshot_loader = tf.train.Saver(load_var)
    snapshot_saver = tf.train.Saver(max_to_keep=4)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        sess.run(tf.global_variables_initializer())
        snapshot_loader.restore(sess, weights)

        im_h, im_w, num_steps = model.H, model.W, model.num_steps
        text_batch = np.zeros((bs, num_steps), dtype=np.float32)
        image_batch = np.zeros((bs, im_h, im_w, 3), dtype=np.float32)
        mask_batch = np.zeros((bs, im_h, im_w, 1), dtype=np.float32)
        valid_idx_batch = np.zeros((bs, 1), dtype=np.int32)

        reader = data_reader.DataReader(data_folder, data_prefix)

        last_time = time.time()
        time_avg = MovingAverage()
        for n_iter in range(max_iter):

            for n_batch in range(bs):
                batch = reader.read_batch(
                    is_log=(n_batch == 0 and n_iter % iters_per_log == 0))
                text = batch['text_batch']
                im = batch['im_batch'].astype(np.float32)
                mask = np.expand_dims(batch['mask_batch'].astype(np.float32), axis=2)

                im = im[:, :, ::-1]
                im -= mu

                text_batch[n_batch, ...] = text
                image_batch[n_batch, ...] = im
                mask_batch[n_batch, ...] = mask

                # Position of the first non-zero token. Fall back to 0 when the
                # sentence is all zeros instead of keeping a stale index from a
                # previous batch.
                valid_idx_batch[n_batch, :] = next(
                    (idx for idx in range(text.shape[0]) if text[idx] != 0), 0)

            _, cls_loss_val, lr_val, scores_val, label_val = sess.run(
                [model.train_step, model.cls_loss, model.learning_rate,
                 model.pred, model.target],
                feed_dict={
                    model.words: text_batch,
                    model.im: image_batch,
                    model.target_fine: mask_batch,
                    model.valid_idx: valid_idx_batch,
                })
            cls_loss_avg = decay * cls_loss_avg + (1 - decay) * cls_loss_val

            # Accuracy (exponential moving averages)
            accuracy_all, accuracy_pos, accuracy_neg = compute_accuracy(scores_val, label_val)
            avg_accuracy_all = decay * avg_accuracy_all + (1 - decay) * accuracy_all
            avg_accuracy_pos = decay * avg_accuracy_pos + (1 - decay) * accuracy_pos
            avg_accuracy_neg = decay * avg_accuracy_neg + (1 - decay) * accuracy_neg

            # timing
            cur_time = time.time()
            elapsed = cur_time - last_time
            last_time = cur_time

            if n_iter % iters_per_log == 0:
                print('iter = %d, loss (cur) = %f, loss (avg) = %f, lr = %f'
                      % (n_iter, cls_loss_val, cls_loss_avg, lr_val))
                print('iter = %d, accuracy (cur) = %f (all), %f (pos), %f (neg)'
                      % (n_iter, accuracy_all, accuracy_pos, accuracy_neg))
                print('iter = %d, accuracy (avg) = %f (all), %f (pos), %f (neg)'
                      % (n_iter, avg_accuracy_all, avg_accuracy_pos, avg_accuracy_neg))
                # Only logged iterations contribute to the average time.
                time_avg.add(elapsed)
                print('iter = %d, cur time = %.5f, avg time = %.5f, model_name: %s'
                      % (n_iter, elapsed, time_avg.get_avg(), model_name))

            done = (n_iter + 1) >= max_iter or (n_iter + 1) >= stop_iter
            # Save snapshot periodically, and always on the last iteration run
            # (including an early stop) so no training progress is lost.
            if (n_iter + 1) % snapshot == 0 or done:
                snapshot_saver.save(sess, snapshot_file % (n_iter + 1))
                print('snapshot saved to ' + snapshot_file % (n_iter + 1))
            if (n_iter + 1) >= stop_iter:
                print('stop training at iter ' + str(stop_iter))
                break
    finally:
        sess.close()

    print('Optimization done.')
def test(iter, dataset, visualize, setname, dcrf, mu, tfmodel_path, model_name, pre_emb=False):
    """Save three-panel heat-map figures for Ref-VOS referring expressions.

    For each video id in a fixed list and each expression of that video
    (read from the ``args.meta`` JSON), every 20th frame is run through the
    model.  A figure with the frame, ``sigmoid(model.up_c4)`` and
    ``sigmoid(model.sigm)`` (both resized/cropped back to the frame size) is
    saved to ``args.visdir/<vid>_<eid>_<fid>.png``.

    Args:
        iter: Tag used only to name ``./<dataset>/visualization/<iter>/``.
        dataset: Dataset name; selects vocabulary size and embedding name.
        visualize: If True, create ``./<dataset>/visualization/<iter>/``.
            Figures are written to ``args.visdir`` regardless.
        setname: Unused by this variant (kept for interface compatibility).
        dcrf: Unused by this variant; DenseCRF post-processing is disabled.
        mu: Value subtracted from each frame after its channel order is
            reversed.
        tfmodel_path: Checkpoint to restore.
        model_name: Model identifier forwarded to ``get_segmentation_model``.
        pre_emb: If True, build the model with pretrained embeddings from
            ``args.embdir``.
    """
    if visualize:
        save_dir = './' + dataset + '/visualization/' + str(iter) + '/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    weights = os.path.join(tfmodel_path)
    print("Loading trained weights from {}".format(weights))

    T = 20  # truncated long sentence
    H, W = 320, 320
    vocab_size = 8803 if dataset == 'referit' else 12112
    emb_name = 'referit' if dataset == 'referit' else 'refvos'
    vocab_file = './data/vocabulary_refvos.txt'
    vocab_dict = text_processing.load_vocab_dict_from_file(vocab_file)

    if pre_emb:
        # use pretrained embedding
        print("Use pretrained Embeddings.")
        model = get_segmentation_model(model_name, H=H, W=W,
                                       mode='eval',
                                       vocab_size=vocab_size,
                                       emb_name=emb_name,
                                       emb_dir=args.embdir)
    else:
        model = get_segmentation_model(model_name, H=H, W=W,
                                       mode='eval', vocab_size=vocab_size)

    # Load trained model
    snapshot_restorer = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    snapshot_restorer.restore(sess, weights)

    with open(args.meta) as meta_file:
        videos = json.load(meta_file)['videos']
    sorted_video_key = ['a9f23c9150', '6cc8bce61a', '03fe6115d4', 'a46012c642', 'c42fdedcdd', 'ee9415c553', '7daa6343e6', '4fe6619a47', '0e8a6b63bb', '65e0640a2a', '8939473ea7', 'b05faf54f7', '5d2020eff8', 'a00c3fa88e', '44e5d1a969', 'deed0ab4fc', 'b205d868e6', '48d2909d9e', 'c9ef04fe59', '1e20ceafae', '0f3f8b2b2f', 'b83923fd72', 'cb06f84b6e', '17cba76927', '35d5e5149d', '62bf7630b3', '0390fabe58', 'bf2d38aefe', '8b7b57b94d', '8d803e87f7', 'c16d9a4ade', '1a1dbe153e', 'd975e5f4a9', '226f1e10f7', '6cb5b08d93', '77df215672', '466734bc5c', '94fa9bd3b5', 'f2a45acf1c', 'ba8823f2d2', '06cd94d38d', 'b772ac822a', '246e38963b', 'b5514f75d8', '188cb4e03d', '3dd327ab4e', '8e2e5af6a8', '450bd2e238', '369919ef49', 'a4bce691c6', '64c6f2ed76', '0782a6df7e', '0062f687f1', 'c74fc37224', 'f7255a57d0', '4f5b3310e3', 'e027ebc228', '30fe0ed0ce', '6a75316e99', 'a2948d4116', '8273b59141', 'abae1ce57d', '621487be65', '45dc90f558', '9787f452bf', 'cdcfd9f93a', '4f6662e4e0', '853ca85618', '13ca7bbcfd', 'f143fede6f', '92fde455eb', '0b0c90e21a', '5460cc540a', '182dbfd6ba', '85968ae408', '541ccb0844', '43115c42b2', '65350fd60a', 'eb49ce8027', 'e11254d3b9', '20a93b4c54', 'a0fc95d8fc', '696e01387c', 'fef7e84268', '72d613f21a', '8c60938d92', '975be70866', '13c3cea202', '4ee0105885', '01c88b5b60', '33e8066265', '8dea7458de', 'c280d21988', 'fd8cf868b2', '35948a7fca', 'e10236eb37', 'a1251195e7', 'b2256e265c', '2b904b76c9', '1ab5f4bbc5', '47d01d34c8', 'd7a38bf258', '1a609fa7ee', '218ac81c2d', '9f16d17e42', 'fb104c286f', 'eb263ef128', '37b4ec2e1a', '0daaddc9da', 'cd69993923', '31d3a7d2ee', '60362df585', 'd7ff44ea97', '623d24ce2b', '6031809500', '54526e3c66', '0788b4033d', '3f4bacb16a', '06a5dfb511', '9f21474aca', '7a19a80b19', '9a38b8e463', '822c31928a', 'd1ac0d8b81', 'eea1a45e49', '9f429af409', '33c8dcbe09', '9da2156a73', '3be852ed44', '3674b2c70a', '547416bda1', '4037d8305d', '29c06df0f2', '1335b16cf9', 'b7b7e52e02', 'bc9ba8917e', 'dab44991de', '9fd2d2782b', 'f054e28786', 'b00ff71889', 'eeb18f9d47', 
        '559a611d86', 'dea0160a12', '257f7fd5b8', 'dc197289ef', 'c2bbd6d121', 'f3678388a7', '332dabe378', '63883da4f5', 'b90f8c11db', 'dce363032d', '411774e9ff', '335fc10235', '7775043b5e', '3e03f623bb', '19cde15c4b', 'bf4cc89b18', '1a894a8f98', 'f7d7fb16d0', '61fca8cbf1', 'd69812339e', 'ab9a7583f1', 'e633eec195', '0a598e18a8', 'b3b92781d9', 'cd896a9bee', 'b7928ea5c0', '69c0f7494e', 'cc1a82ac2a', '39b7491321', '352ad66724', '749f1abdf9', '7f26b553ae', '0c04834d61', 'd1dd586cfd', '3b72dc1941', '39bce09d8d', 'cbea8f6bea', 'cc7c3138ff', 'd59c093632', '68dab8f80c', '1e0257109e', '4307020e0f', '4b783f1fc5', 'ebe7138e58', '1f390d22ea', '7a72130f21', 'aceb34fcbe', '9c0b55cae5', 'b58a97176b', '152fe4902a', 'a806e58451', '9ce299a510', '97b38cabcc', 'f39c805b54', '0620b43a31', '0723d7d4fe', '7741a0fbce', '7836afc0c2', 'a7462d6aaf', '34564d26d8', '31e0beaf99']

    vis_dir = args.visdir
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)

    plt.figure(figsize=[15, 4])
    try:
        for vid_ind, vid in enumerate(sorted_video_key):
            # Progress is reported against the list actually being iterated.
            print("Running on video {}/{}".format(vid_ind + 1, len(sorted_video_key)))
            expressions = videos[vid]['expressions']
            frame_ids = videos[vid]['frames']
            for eid in expressions:
                exp = expressions[eid]['exp']
                # Process text
                text = np.array(text_processing.preprocess_sentence(exp, vocab_dict, T))
                # Index of the first non-zero token (0 if the sentence is all zeros).
                valid_idx = np.zeros([1], dtype=np.int32)
                for idx in range(text.shape[0]):
                    if text[idx] != 0:
                        valid_idx[0] = idx
                        break
                for fid in frame_ids:
                    # Only every 20th frame is visualised.
                    if int(fid) % 20 != 0:
                        continue
                    frame = load_frame_from_id(vid, fid)
                    if frame is None:
                        continue
                    vis_path = os.path.join(vis_dir, str('{}_{}_{}.png'.format(vid, eid, fid)))

                    proc_im = skimage.img_as_ubyte(im_processing.resize_and_pad(frame, H, W))
                    proc_im_ = proc_im.astype(np.float32)
                    proc_im_ = proc_im_[:, :, ::-1]
                    proc_im_ -= mu
                    sigm_val, up_c4 = sess.run(
                        [model.sigm, model.up_c4],
                        feed_dict={
                            model.words: np.expand_dims(text, axis=0),
                            model.im: np.expand_dims(proc_im_, axis=0),
                            model.valid_idx: np.expand_dims(valid_idx, axis=0),
                        })
                    up_c4 = im_processing.resize_and_crop(
                        sigmoid(np.squeeze(up_c4)), frame.shape[0], frame.shape[1])
                    sigm_val = im_processing.resize_and_crop(
                        sigmoid(np.squeeze(sigm_val)), frame.shape[0], frame.shape[1])

                    plt.clf()
                    plt.subplot(1, 3, 1)
                    plt.imshow(frame)
                    # The consistency score this title used to include is never
                    # computed here (it raised NameError); show the expression only.
                    plt.text(-0.7, -0.7, exp)
                    plt.subplot(1, 3, 2)
                    plt.imshow(up_c4)
                    plt.subplot(1, 3, 3)
                    plt.imshow(sigm_val)
                    plt.savefig(vis_path)
    finally:
        plt.close()
        sess.close()
# Example no. 3
# 0
def test(iter,
         dataset,
         visualize,
         setname,
         dcrf,
         mu,
         tfmodel_folder,
         model_name,
         pre_emb=False):
    """Evaluate a trained snapshot on a pre-batched split and print IoU metrics.

    Restores ``<tfmodel_folder>/<dataset>_iter_<iter>.tfmodel``, reads every
    batch from ``<dataset>/<setname>_batch/`` in order, thresholds the
    ``model.up`` output at ``score_thresh``, and prints precision at several
    IoU thresholds plus overall (cumulative) and mean IoU.  With ``dcrf`` the
    same metrics are also reported after DenseCRF post-processing.

    Args:
        iter: Snapshot iteration to load.
        dataset: Dataset name; selects data folder, vocab size, embedding name.
        visualize: If True, call ``visualize_seg`` for each sample.
        setname: Split name used to locate the batch folder.
        dcrf: If True, additionally evaluate DenseCRF-refined predictions.
        mu: Value subtracted from each image after its channel order is
            reversed.
        tfmodel_folder: Directory that holds the snapshots.
        model_name: Model identifier forwarded to ``get_segmentation_model``.
        pre_emb: If True, build the model with pretrained embeddings.
    """
    data_folder = dataset + '/' + setname + '_batch/'
    data_prefix = dataset + '_' + setname
    if visualize:
        save_dir = './' + dataset + '/visualization/' + str(iter) + '/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    weights = os.path.join(tfmodel_folder,
                           dataset + '_iter_' + str(iter) + '.tfmodel')

    score_thresh = 1e-9
    eval_seg_iou_list = [.5, .55, .6, .65, .7, .75, .8, .85, .9, .95]
    cum_I, cum_U = 0, 0
    mean_IoU, mean_dcrf_IoU = 0, 0
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    if dcrf:
        cum_I_dcrf, cum_U_dcrf = 0, 0
        seg_correct_dcrf = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0.
    H, W = 320, 320
    vocab_size = 8803 if dataset == 'referit' else 12112
    emb_name = 'referit' if dataset == 'referit' else 'Gref'
    # Probabilities are clipped into [eps, 1 - eps] before taking logs so the
    # DenseCRF unary energies stay finite.
    prob_eps = 1e-7

    if pre_emb:
        print("Use pretrained embeddings.")
        model = get_segmentation_model(model_name,
                                       H=H,
                                       W=W,
                                       mode='eval',
                                       vocab_size=vocab_size,
                                       emb_name=emb_name)
    else:
        model = get_segmentation_model(model_name,
                                       H=H,
                                       W=W,
                                       mode='eval',
                                       vocab_size=vocab_size)

    # Load trained model
    snapshot_restorer = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        sess.run(tf.global_variables_initializer())
        snapshot_restorer.restore(sess, weights)
        print("loading trained weights from {}".format(weights))
        reader = data_reader.DataReader(data_folder, data_prefix, shuffle=False)

        num_batch = reader.num_batch
        # Roughly 50 progress ticks; at least 1 so small splits don't divide by 0.
        progress_step = max(num_batch // 50, 1)
        for n_iter in range(num_batch):

            if n_iter % progress_step == 0:
                tick = n_iter // progress_step
                # Every 5th tick prints a digit, the others a dot.
                sys.stdout.write(str(tick // 5) if tick % 5 == 0 else '.')
                sys.stdout.flush()

            batch = reader.read_batch(is_log=False)
            text = batch['text_batch']
            im = batch['im_batch']
            mask = batch['mask_batch'].astype(np.float32)
            frames = batch['frames']
            # Index of the first non-zero token (0 if the sentence is all zeros).
            valid_idx = np.zeros([1], dtype=np.int32)
            for idx in range(text.shape[0]):
                if text[idx] != 0:
                    valid_idx[0] = idx
                    break

            proc_im = skimage.img_as_ubyte(im_processing.resize_and_pad(im, H, W))
            proc_im_ = proc_im.astype(np.float32)
            proc_im_ = proc_im_[:, :, ::-1]
            proc_im_ -= mu

            proc_frames = list()
            for i in range(frames.shape[0]):
                proc_frame = skimage.img_as_ubyte(
                    im_processing.resize_and_pad(frames[i, :, :, :], H, W))
                proc_frame = proc_frame.astype(np.float32)
                proc_frame = proc_frame[:, :, ::-1]
                proc_frame -= mu
                proc_frames.append(proc_frame)
            proc_frames = np.array(proc_frames, dtype=np.float32)

            up_val, sigm_val = sess.run(
                [model.up, model.sigm],
                feed_dict={
                    model.words: np.expand_dims(text, axis=0),
                    model.im: np.expand_dims(proc_im_, axis=0),
                    model.valid_idx: np.expand_dims(valid_idx, axis=0),
                    model.clip: np.expand_dims(proc_frames, axis=0)
                })

            up_val = np.squeeze(up_val)
            pred_raw = (up_val >= score_thresh).astype(np.float32)
            predicts = im_processing.resize_and_crop(pred_raw, mask.shape[0],
                                                     mask.shape[1])
            if dcrf:
                # Dense CRF post-processing
                prob = np.clip(np.squeeze(sigm_val), prob_eps, 1 - prob_eps)
                d = densecrf.DenseCRF2D(W, H, 2)
                U = np.expand_dims(-np.log(prob), axis=0)
                U_ = np.expand_dims(-np.log(1 - prob), axis=0)
                unary = np.concatenate((U_, U), axis=0)
                unary = unary.reshape((2, -1))
                d.setUnaryEnergy(unary)
                d.addPairwiseGaussian(sxy=3, compat=3)
                d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=proc_im, compat=10)
                Q = d.inference(5)
                pred_raw_dcrf = np.argmax(Q, axis=0).reshape(
                    (H, W)).astype(np.float32)
                predicts_dcrf = im_processing.resize_and_crop(
                    pred_raw_dcrf, mask.shape[0], mask.shape[1])

            if visualize:
                sent = batch['sent_batch'][0]
                visualize_seg(im, mask, predicts, sent)
                if dcrf:
                    visualize_seg(im, mask, predicts_dcrf, sent)

            I, U = eval_tools.compute_mask_IU(predicts, mask)

            # deal with empty gt mask
            # NOTE(review): assumes compute_mask_IU reports an empty union as
            # exactly eps — confirm against eval_tools.
            eps = 1e-7
            if U == eps:
                print("empty gt mask in testing")
                continue

            mean_IoU += float(I) / U
            cum_I += I
            cum_U += U
            for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                seg_correct[n_eval_iou] += (I / U >= eval_seg_iou)
            if dcrf:
                I_dcrf, U_dcrf = eval_tools.compute_mask_IU(predicts_dcrf, mask)
                mean_dcrf_IoU += float(I_dcrf) / U_dcrf
                cum_I_dcrf += I_dcrf
                cum_U_dcrf += U_dcrf
                for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                    seg_correct_dcrf[n_eval_iou] += (I_dcrf / U_dcrf >= eval_seg_iou)
            seg_total += 1
    finally:
        sess.close()

    if seg_total == 0:
        # Nothing was scored (no batches, or every ground-truth mask empty).
        print('No samples with a non-empty ground-truth mask were evaluated.')
        return

    # Print results
    print('Segmentation evaluation (without DenseCRF):')
    result_str = ''
    for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
        result_str += 'precision@%s = %f\n' % \
                      (str(eval_seg_iou), seg_correct[n_eval_iou] / seg_total)
    result_str += 'overall IoU = %f; mean IoU = %f\n' % (cum_I / cum_U,
                                                         mean_IoU / seg_total)
    print(result_str)
    if dcrf:
        print('Segmentation evaluation (with DenseCRF):')
        result_str = ''
        for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
            result_str += 'precision@%s = %f\n' % \
                          (str(eval_seg_iou), seg_correct_dcrf[n_eval_iou] / seg_total)
        result_str += 'overall IoU = %f; mean IoU = %f\n' % (
            cum_I_dcrf / cum_U_dcrf, mean_dcrf_IoU / seg_total)
        print(result_str)
# Example no. 4
# 0
def test(iter,
         dataset,
         visualize,
         setname,
         dcrf,
         mu,
         tfmodel_folder,
         model_name,
         pre_emb=False):
    """Run a trained snapshot over Ref-VOS expressions and dump per-frame output.

    Restores ``<tfmodel_folder>/<dataset>_iter_<iter>.tfmodel`` and walks the
    videos of the ``args.meta`` JSON in reverse order.  For each expression and
    frame, when ``visualize`` is set, it writes either a binary DenseCRF mask
    PNG (``dcrf=True``) to ``args.visdir/<vid>/<eid>/<fid>.png`` or the raw
    ``model.sigm`` output (``dcrf=False``) to
    ``args.maskdir/<vid>/<eid>/<fid>.npy``.  Frames whose output file for the
    active mode already exists are skipped, so an interrupted run can resume.

    Args:
        iter: Snapshot iteration to load.
        dataset: Dataset name; selects vocabulary size and embedding name.
        visualize: If False, predictions are computed but nothing is written.
        setname: Unused by this variant (kept for interface compatibility).
        dcrf: Write DenseCRF-refined masks instead of raw sigmoid maps.
        mu: Value subtracted from each resized frame.
        tfmodel_folder: Directory that holds the snapshots.
        model_name: Model identifier forwarded to ``get_segmentation_model``.
        pre_emb: If True, build the model with pretrained embeddings from
            ``args.embdir``.
    """
    if visualize:
        save_dir = './' + dataset + '/visualization/' + str(iter) + '/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    weights = os.path.join(tfmodel_folder,
                           dataset + '_iter_' + str(iter) + '.tfmodel')
    print("Loading trained weights from {}".format(weights))

    T = 20  # truncated long sentence
    H, W = 320, 320
    vocab_size = 8803 if dataset == 'referit' else 12112
    emb_name = 'referit' if dataset == 'referit' else 'Gref'
    vocab_file = './data/vocabulary_Gref.txt'
    vocab_dict = text_processing.load_vocab_dict_from_file(vocab_file)
    # Probabilities are clipped into [eps, 1 - eps] before taking logs; merely
    # adding eps would make log(1 - p) NaN wherever p is close to 1.
    prob_eps = 1e-7

    if pre_emb:
        # use pretrained embedding
        print("Use pretrained Embeddings.")
        model = get_segmentation_model(model_name,
                                       H=H,
                                       W=W,
                                       mode='eval',
                                       vocab_size=vocab_size,
                                       emb_name=emb_name,
                                       emb_dir=args.embdir)
    else:
        model = get_segmentation_model(model_name,
                                       H=H,
                                       W=W,
                                       mode='eval',
                                       vocab_size=vocab_size)

    # Load trained model
    snapshot_restorer = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        sess.run(tf.global_variables_initializer())
        snapshot_restorer.restore(sess, weights)

        with open(args.meta) as meta_file:
            videos = json.load(meta_file)['videos']
        for vid_ind, vid in reversed(list(enumerate(videos.keys()))):
            print("Running on video {}/{}".format(vid_ind + 1, len(videos.keys())))
            expressions = videos[vid]['expressions']
            frame_ids = videos[vid]['frames']
            for eid in expressions:
                exp = expressions[eid]['exp']
                index = int(eid)
                vis_dir = os.path.join(args.visdir,
                                       str('{}/{}/'.format(vid, index)))
                mask_dir = os.path.join(args.maskdir,
                                        str('{}/{}/'.format(vid, index)))
                if not os.path.exists(vis_dir):
                    os.makedirs(vis_dir)
                if not os.path.exists(mask_dir):
                    os.makedirs(mask_dir)
                # Process text
                text = np.array(
                    text_processing.preprocess_sentence(exp, vocab_dict, T))
                # Index of the first non-zero token (0 if the sentence is all zeros).
                valid_idx = np.zeros([1], dtype=np.int32)
                for idx in range(text.shape[0]):
                    if text[idx] != 0:
                        valid_idx[0] = idx
                        break
                for fid in frame_ids:
                    vis_path = os.path.join(vis_dir, str('{}.png'.format(fid)))
                    mask_path = os.path.join(mask_dir, str('{}.npy'.format(fid)))
                    # Resume support: skip frames whose output for the active
                    # mode exists (PNG in DenseCRF mode, .npy otherwise).
                    done_path = vis_path if dcrf else mask_path
                    if os.path.exists(done_path):
                        continue
                    frame = load_frame_from_id(vid, fid)
                    if frame is None:
                        continue

                    proc_im = skimage.img_as_ubyte(
                        im_processing.resize_and_pad(frame, H, W))
                    proc_im_ = proc_im.astype(np.float32)
                    # NOTE(review): the [:, :, ::-1] channel reversal used by the
                    # other variants is disabled here — confirm this is intended.
                    proc_im_ -= mu
                    sigm_val = sess.run(
                        model.sigm,
                        feed_dict={
                            model.words: np.expand_dims(text, axis=0),
                            model.im: np.expand_dims(proc_im_, axis=0),
                            model.valid_idx: np.expand_dims(valid_idx, axis=0)
                        })
                    if dcrf:
                        # Dense CRF post-processing
                        prob = np.clip(np.squeeze(sigm_val), prob_eps, 1 - prob_eps)
                        d = densecrf.DenseCRF2D(W, H, 2)
                        U = np.expand_dims(-np.log(prob), axis=0)
                        U_ = np.expand_dims(-np.log(1 - prob), axis=0)
                        unary = np.concatenate((U_, U), axis=0)
                        unary = unary.reshape((2, -1))
                        d.setUnaryEnergy(unary)
                        d.addPairwiseGaussian(sxy=3, compat=3)
                        d.addPairwiseBilateral(sxy=20,
                                               srgb=3,
                                               rgbim=proc_im,
                                               compat=10)
                        Q = d.inference(5)
                        # Binary mask scaled to {0, 255} for PNG output.
                        pred_raw_dcrf = np.argmax(Q, axis=0).reshape(
                            (H, W)).astype('uint8') * 255
                    if visualize:
                        if dcrf:
                            cv2.imwrite(vis_path, pred_raw_dcrf)
                        else:
                            np.save(mask_path, np.array(sigm_val))
    finally:
        sess.close()
# Example no. 5
# 0
def train(max_iter,
          snapshot,
          dataset,
          data_dir,
          setname,
          mu,
          lr,
          bs,
          tfmodel_folder,
          conv5,
          model_name,
          stop_iter,
          last_iter,
          pre_emb=False,
          finetune=False,
          pretrain_path='',
          emb_dir=''):
    """Train a segmentation model and periodically save checkpoints.

    Some settings are read from the module-level ``args`` namespace rather
    than from parameters: ``log_dir``, ``im_dir``, ``mask_dir``, ``meta``,
    and, when ``pre_emb`` is True, ``freeze_bn`` and ``is_aug``.

    Args:
        max_iter: Exclusive upper bound on the iteration index.
        snapshot: Save a checkpoint every ``snapshot`` iterations.
        dataset: Dataset name. Only ``'refvos'`` has a data reader.
        data_dir: Kept for interface compatibility; currently unused.
        setname: Kept for interface compatibility; currently unused.
        mu: Value subtracted from each image after its last-axis
            (channel) order is reversed.
        lr: Starting learning rate forwarded to the model.
        bs: Number of samples read per training step.
        tfmodel_folder: Checkpoint directory (created if missing).
        conv5: Forwarded to ``get_segmentation_model``.
        model_name: Model identifier forwarded to ``get_segmentation_model``.
        stop_iter: Training stops once ``n_iter + 1`` reaches this value.
        last_iter: Iteration to resume after; the loop starts at
            ``last_iter + 1``.
        pre_emb: Build the model with pretrained word embeddings.
        finetune: Restore every variable from ``pretrain_path`` instead of
            the DeepLab ResNet initialization checkpoint.
        pretrain_path: Checkpoint restored when ``finetune`` is True.
        emb_dir: Embedding directory forwarded when ``pre_emb`` is True.

    Raises:
        ValueError: If ``dataset`` has no data reader.
    """
    # Only refvos has a reader. Fail before building the graph and loading
    # weights, instead of hitting an unbound ``reader`` inside the loop.
    if dataset != 'refvos':
        raise ValueError(
            "Unsupported dataset {!r}: only 'refvos' has a data reader"
            .format(dataset))

    iters_per_log = 100
    snapshot_file = os.path.join(tfmodel_folder, dataset + '_finetune')
    if not os.path.isdir(tfmodel_folder):
        os.makedirs(tfmodel_folder)

    vocab_size = 8803 if dataset == 'referit' else 1917498

    model_kwargs = dict(mode='train',
                        vocab_size=vocab_size,
                        start_lr=lr,
                        batch_size=bs,
                        conv5=conv5)
    if pre_emb:
        print("Use pretrained Embeddings.")
        model_kwargs.update(emb_name=dataset,
                            emb_dir=emb_dir,
                            freeze_bn=args.freeze_bn,
                            is_aug=args.is_aug)
    model = get_segmentation_model(model_name, **model_kwargs)

    if finetune:
        # Restore the whole graph from a previous training run.
        weights = pretrain_path
        snapshot_loader = tf.train.Saver()
    else:
        # Restore only the backbone variables (matched by name prefix)
        # from the DeepLab ResNet initialization checkpoint.
        weights = './data/weights/deeplab_resnet_init.ckpt'
        load_var = {
            var.op.name: var
            for var in tf.global_variables()
            if var.name.startswith(('res', 'bn', 'conv1', 'Adam'))
        }
        snapshot_loader = tf.train.Saver(load_var)
    print("Loading pretrained weights from {}".format(weights))

    snapshot_saver = tf.train.Saver(max_to_keep=4)

    # Build the reader before opening the session so that a failure here
    # does not leave a session open.
    reader = data_reader_refvos.DataReader(im_dir=args.im_dir,
                                           mask_dir=args.mask_dir,
                                           train_metadata=args.meta)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    snapshot_loader.restore(sess, weights)
    # Log tensorboard
    train_writer = tf.summary.FileWriter(args.log_dir + '/train', sess.graph)

    im_h, im_w, num_steps = model.H, model.W, model.num_steps
    text_batch = np.zeros((bs, num_steps), dtype=np.float32)
    image_batch = np.zeros((bs, im_h, im_w, 3), dtype=np.float32)
    mask_batch = np.zeros((bs, im_h, im_w, 1), dtype=np.float32)
    seq_len_batch = np.zeros(bs, dtype=np.int32)

    last_epoch = (last_iter * bs) // reader.num_batch
    try:
        for n_iter in range(last_iter + 1, max_iter):
            log_this_iter = n_iter % iters_per_log == 0
            for n_batch in range(bs):
                batch = reader.read_batch(
                    is_log=(n_batch == 0 and log_this_iter))
                # Reverse the last (channel) axis, then subtract the mean.
                # The slice is a view of the fresh astype() copy, so the
                # in-place subtraction never touches the reader's data.
                im = batch['im_batch'].astype(np.float32)[:, :, ::-1]
                im -= mu

                text_batch[n_batch, ...] = batch['text_batch']
                image_batch[n_batch, ...] = im
                mask_batch[n_batch, ...] = np.expand_dims(
                    batch['mask_batch'].astype(np.float32), axis=2)
                seq_len_batch[n_batch] = batch['seq_length']

            _, train_step, summary = sess.run(
                [model.train, model.train_step, model.merged],
                feed_dict={
                    model.words: text_batch,
                    model.im: image_batch,
                    model.target_fine: mask_batch,
                    model.seq_len: seq_len_batch,
                })
            train_writer.add_summary(summary, train_step)

            # Save when a new epoch starts, every `snapshot` iterations, and
            # at the final iteration. Coincident triggers produce one save.
            cur_epoch = n_iter * bs // reader.num_batch
            new_epoch = cur_epoch > last_epoch
            if new_epoch:
                last_epoch = cur_epoch
            periodic = (n_iter + 1) % snapshot == 0 or (n_iter + 1) >= max_iter
            if new_epoch or periodic:
                snapshot_saver.save(sess, snapshot_file, global_step=train_step)
                print('snapshot saved at iteration {}'.format(n_iter))
            if (n_iter + 1) >= stop_iter:
                print('stop training at iter ' + str(stop_iter))
                break
    finally:
        # Flush pending summaries and release the session, even on error.
        train_writer.close()
        sess.close()

    print('Optimization done.')