def train(weights_name, data_base_dir, captions_base_dir):
    """Train the RMI referring-segmentation model on the sketch dataset.

    Resumes from the latest snapshot in FLAGS.snapshot_root if one exists;
    otherwise initializes the backbone from the matching pretrained
    SketchyScene segmentation checkpoint.

    Args:
        weights_name: backbone identifier ('deeplab', 'fcn_8s', 'segnet',
            or 'deeplab_v3plus'); selects the pretrained checkpoint used
            for a fresh run.
        data_base_dir: dataset root; images are read from '<root>/train'.
        captions_base_dir: directory holding sentence_instance_train.json.

    Raises:
        ValueError: if `weights_name` is not one of the supported backbones.
    """
    dataset_base_dir = os.path.join(data_base_dir, 'train')
    caption_json_path = os.path.join(captions_base_dir,
                                     'sentence_instance_train.json')
    vocab_dict = text_processing.load_vocab_dict_from_file(FLAGS.vocab_path)

    snapshot_file = os.path.join(
        FLAGS.snapshot_root,
        weights_name + '_' + FLAGS.model_name + '_iter_%d.tfmodel')
    os.makedirs(FLAGS.snapshot_root, exist_ok=True)
    os.makedirs(FLAGS.log_root, exist_ok=True)

    cls_loss_avg = 0  # exponential moving average of the classification loss
    decay = 0.99

    start_iter = 0

    model = RMI_model(mode='train',
                      vocab_size=FLAGS.vocab_size,
                      weights=weights_name)

    print('-' * 100)

    # Calculate trainable params.
    print('Network params:')
    t_vars = tf.trainable_variables()
    count_t_vars = 0
    for var in t_vars:
        num_param = np.prod(var.get_shape().as_list())
        count_t_vars += num_param
        print('%s | shape: %s | num_param: %i' %
              (var.name, str(var.get_shape()), num_param))
    print('Total network variables %i.' % count_t_vars)
    print('-' * 100)

    print('Optimizing params:')
    for var in model.optim_params:
        print('%s | shape: %s' % (var.name, str(var.get_shape())))
    print('-' * 100)

    snapshot_saver = tf.train.Saver(max_to_keep=10)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Resume from a previous snapshot if available; otherwise restore only
    # the backbone variables from the matching pretrained model.
    ckpt = tf.train.get_checkpoint_state(FLAGS.snapshot_root)
    if not ckpt:
        if weights_name == 'deeplab':
            ckpt = tf.train.get_checkpoint_state(
                './models/SketchyScene_DeepLabv2')
            load_var = {
                var.op.name: var
                for var in tf.global_variables()
                if var.op.name.startswith('ResNet/group')
            }
        elif weights_name == 'fcn_8s':
            ckpt = tf.train.get_checkpoint_state('./models/SketchyScene_FCN8s')
            load_var = {
                var.op.name: var
                for var in tf.global_variables()
                if var.op.name.startswith('FCN_8s')
            }
        elif weights_name == 'segnet':
            ckpt = tf.train.get_checkpoint_state(
                './models/SketchyScene_SegNet')
            load_var = {
                var.op.name: var
                for var in tf.global_variables()
                if var.op.name.startswith('SegNet')
            }
        elif weights_name == 'deeplab_v3plus':
            ckpt = tf.train.get_checkpoint_state(
                './models/SketchyScene_DeepLabv3plus')
            load_var = {
                var.op.name: var
                for var in tf.global_variables()
                if var.op.name.startswith('resnet_v1_101')
            }
        else:
            raise ValueError('Unknown weights_name %s' % weights_name)

        snapshot_loader = tf.train.Saver(load_var)
        print('firstly train, loaded', ckpt.model_checkpoint_path)
        snapshot_loader.restore(sess,
                                ckpt.model_checkpoint_path)  # pretrained_model
    else:
        snapshot_path = ckpt.model_checkpoint_path
        print('loaded', snapshot_path)
        snapshot_saver.restore(sess, snapshot_path)
        # Snapshot files end in '_iter_<N>.tfmodel'; recover N to resume.
        start_iter = int(snapshot_path[snapshot_path.rfind('_') +
                                       1:snapshot_path.rfind('.')])

    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.log_root, graph=sess.graph)

    duration_time_n_step = 0

    # BUGFIX: use a context manager so the annotation file is closed
    # (the handle was previously leaked).
    with open(caption_json_path, "r") as fp:
        json_data = json.loads(fp.read())
    print('data_len', len(json_data))

    # Flatten the per-image {caption -> instance indices} maps into one
    # list of (img_idx, inst_indices, caption) training tuples.
    train_info_list = []
    for record in json_data:
        img_idx = record['key']
        for caption, inst_indices in record['sen_instIdx_map'].items():
            train_info_list.append({
                'img_idx': img_idx,
                'inst_indices': inst_indices,
                'caption': caption
            })

    train_info_indices = np.arange(len(train_info_list))
    temp_data_idx = -1
    print(len(train_info_list), 'tuples of data.')

    print('start_iter', start_iter)

    for n_iter in range(start_iter, FLAGS.max_iteration):
        start_time = time.time()

        # Cycle through the data; reshuffle at the start of each epoch.
        temp_data_idx = (temp_data_idx + 1) % len(train_info_list)
        if temp_data_idx == 0:
            random.shuffle(train_info_indices)

        data_idx = train_info_indices[temp_data_idx]
        img_idx = train_info_list[data_idx]['img_idx']
        inst_indices = train_info_list[data_idx]['inst_indices']
        caption_thin = train_info_list[data_idx]['caption']
        # print('img_idx', img_idx, '; inst_indices', inst_indices, ':', caption_thin)

        # Load image, and target mask
        sketch_image, target_mask = sketch_data_processing.load_data_gt(
            dataset_base_dir,
            img_idx,
            fast_version=True,
            inst_indices=inst_indices)
        sketch_image -= mu  # mean subtraction (module-level `mu`)

        # load text and augment the caption with random attributes
        caption = text_processing.augment_the_caption_with_attr(caption_thin)
        vocab_indices, seq_len = text_processing.preprocess_sentence(
            caption, vocab_dict, FLAGS.MAX_LEN)
        # print(caption, vocab_indices, seq_len)

        feed_dict = {
            model.words:
            np.expand_dims(vocab_indices, axis=0),  # [N, MAX_LEN]
            model.sequence_lengths: [seq_len],  # [N]
            model.im:
            np.expand_dims(sketch_image, axis=0),  # [N, H, W, 3]
            model.target_mask:
            np.expand_dims(np.expand_dims(target_mask.astype(np.float32),
                                          axis=0),
                           axis=3),  # [N, H, W, 1]
        }

        _, cls_loss_, lr_, scores_ = sess.run([
            model.train_step, model.cls_loss, model.learning_rate, model.pred
        ],
                                              feed_dict=feed_dict)

        duration_time = time.time() - start_time
        duration_time_n_step += duration_time

        if n_iter % FLAGS.count_left_time_freq == 0:
            if n_iter != 0:
                cls_loss_avg = decay * cls_loss_avg + (1 - decay) * cls_loss_
                print('iter = %d, loss (cur) = %f, loss (avg) = %f, lr = %f' %
                      (n_iter, cls_loss_, cls_loss_avg, lr_))

                left_step = FLAGS.max_iteration - n_iter
                left_sec = left_step / FLAGS.count_left_time_freq * duration_time_n_step
                print("### duration_time_%d_step:%.3f(sec), left time:%s\n" %
                      (FLAGS.count_left_time_freq, duration_time_n_step,
                       str(timedelta(seconds=left_sec))))
                duration_time_n_step = 0

        if n_iter % FLAGS.summary_write_freq == 0:
            if n_iter != 0:
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, n_iter)
                summary_writer.flush()

        # Save model
        if (n_iter + 1) % FLAGS.save_model_freq == 0 or (
                n_iter + 1) >= FLAGS.max_iteration:
            snapshot_saver.save(sess, snapshot_file % (n_iter + 1))
            print('model saved to ' + snapshot_file % (n_iter + 1))

    # Flush and release the event-file handle.
    summary_writer.close()
    print('Optimization done.')
def inference(weights_name, data_base_dir, dataset_split, seg_data_base_dir,
              image_id, input_text):
    """Run the trained RMI model on one sketch image with a text query and
    visualize the predicted semantic/instance masks.

    Args:
        weights_name: backbone identifier passed to RMI_model.
        data_base_dir: root directory of the sketch dataset.
        dataset_split: split folder name (e.g. 'train'/'val'/'test').
        seg_data_base_dir: root of the precomputed '<id>_datas.npz' files.
        image_id: id of the sketch image to process.
        input_text: referring expression (caption) describing the target.
    """
    sketch_image_base_dir = os.path.join(data_base_dir, dataset_split,
                                         'DRAWING_GT')
    vocab_dict = text_processing.load_vocab_dict_from_file(FLAGS.vocab_path)

    # Build the class-name list from the dataset color map
    # (index 0 is background, followed by 46 categories).
    dataset_class_names = ['bg']
    color_map_mat_path = os.path.join(data_base_dir, 'colorMapC46.mat')
    colorMap = scipy.io.loadmat(color_map_mat_path)['colorMap']
    for i in range(46):
        cat_name = colorMap[i][0][0]
        dataset_class_names.append(cat_name)

    # Any response above ~0 counts as foreground.
    score_thresh = 1e-9

    model = RMI_model(mode='eval',
                      vocab_size=FLAGS.vocab_size,
                      weights=weights_name)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Load pretrained model
    snapshot_restorer = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(FLAGS.snapshot_root)
    print('Restore:', ckpt.model_checkpoint_path)
    snapshot_restorer.restore(sess, ckpt.model_checkpoint_path)

    # Load image
    sketch_image = sketch_data_processing.load_image(
        sketch_image_base_dir, image_id)  # [768, 768, 3], float32
    sketch_image_vis = np.array(np.squeeze(sketch_image), dtype=np.uint8)
    sketch_image -= mu  # mean subtraction (module-level `mu`)

    # Binary stroke mask: drawn pixels (value 0) -> 1, white background -> 0.
    bin_drawing = sketch_image_vis.copy()[:, :, 0]
    bin_drawing[bin_drawing == 0] = 1
    bin_drawing[bin_drawing == 255] = 0

    caption = input_text
    vocab_indices, seq_len = text_processing.preprocess_sentence(
        caption, vocab_dict, FLAGS.MAX_LEN)

    up_val, sigm_val = sess.run(
        [model.up, model.sigm],
        feed_dict={
            model.words: np.expand_dims(vocab_indices, axis=0),  # [N, T]
            model.sequence_lengths: [seq_len],  # [N]
            model.im: np.expand_dims(sketch_image, axis=0),  # [N, H, W, 3]
        })

    up_val = np.squeeze(up_val)  # shape = [768, 768]
    predicts = (up_val >= score_thresh).astype(np.float32)  # 0.0/1.0
    # Restrict the prediction to actual stroke pixels.
    predicts = predicts * bin_drawing  # [768, 768] {0, 1}

    # get pred_instance_mask by segm_data and predicts
    segm_data_npz_path = os.path.join(seg_data_base_dir, dataset_split,
                                      'seg_data',
                                      str(image_id) + '_datas.npz')
    pred_masks, pred_scores, pred_boxes, pred_class_ids = sketch_data_processing.get_pred_instance_mask(
        segm_data_npz_path, predicts.copy())
    print('pred_masks', pred_masks.shape)
    print('pred_scores', pred_scores.shape, pred_scores)
    print('pred_boxes', pred_boxes.shape)
    print('pred_class_ids', pred_class_ids.shape, pred_class_ids)

    # visualization_util.visualize_sem_seg(sketch_image_vis.copy(), predicts, 'Binary pred: ' + caption,
    #                                      save_path=os.path.join(FLAGS.match_result_root, 'seg_vis_bin.png'))
    # visualization_util.visualize_inst_seg(sketch_image_vis.copy(), pred_masks, 'Instance pred: ' + caption)
    # visualize_instance.display_instances(sketch_image_vis.copy(), pred_boxes, pred_masks, pred_class_ids,
    #                                      dataset_class_names, scores=None, title=caption,
    #                                      save_path=save_path, fix_color=False)

    # Show semantic + instance predictions overlaid on the sketch.
    visualization_util.visualize_sem_inst_mask(sketch_image_vis.copy(),
                                               predicts, pred_boxes,
                                               pred_masks, pred_class_ids,
                                               dataset_class_names, caption)
def test(weights_name, data_base_dir, dataset_split, captions_base_dir,
         seg_data_base_dir, cal_mask_AP, visualize):
    """Evaluate the RMI model on a sketch dataset split.

    Reports cumulative IoU and precision@{0.5..0.9} over all
    (image, caption) pairs, and optionally instance-mask AP. Results are
    appended to a text file under FLAGS.eval_result_root.

    Args:
        weights_name: backbone identifier passed to RMI_model.
        data_base_dir: root directory of the sketch dataset.
        dataset_split: split to evaluate ('val' or 'test').
        captions_base_dir: directory with sentence_instance_<split>.json.
        seg_data_base_dir: root of the precomputed '<id>_datas.npz' files.
        cal_mask_AP: if True, also compute instance-mask AP.
        visualize: if True, save GT/prediction visualizations to disk.
    """
    dataset_base_dir = os.path.join(data_base_dir, dataset_split)
    caption_json_path = os.path.join(
        captions_base_dir, 'sentence_instance_' + dataset_split + '.json')
    vocab_dict = text_processing.load_vocab_dict_from_file(FLAGS.vocab_path)

    os.makedirs(FLAGS.eval_result_root, exist_ok=True)

    # Any response above ~0 counts as foreground.
    score_thresh = 1e-9
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    cum_I, cum_U = 0, 0
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)

    APs = []
    iou_threshold = None  # None for mAP@[0.5:0.95]

    seg_total = 0.

    model = RMI_model(mode='eval',
                      vocab_size=FLAGS.vocab_size,
                      weights=weights_name)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Load pretrained model
    snapshot_restorer = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(FLAGS.snapshot_root)
    print('Restore:', ckpt.model_checkpoint_path)
    snapshot_restorer.restore(sess, ckpt.model_checkpoint_path)

    # BUGFIX: use a context manager so the annotation file is closed
    # (the handle was previously leaked).
    with open(caption_json_path, "r") as fp:
        json_data = json.loads(fp.read())
    print('data_len', len(json_data))

    for data_idx in range(len(json_data)):
        img_idx = json_data[data_idx]['key']
        print('Processing', data_idx + 1, '/', len(json_data), ', img_idx:',
              img_idx)

        # Load image, and target mask
        sketch_image, gt_class_ids, gt_bboxes, gt_masks = sketch_data_processing.load_data_gt(
            dataset_base_dir, img_idx)
        sketch_image_vis = np.array(np.squeeze(sketch_image), dtype=np.uint8)
        sketch_image -= mu  # mean subtraction (module-level `mu`)

        # Binary stroke mask: drawn pixels -> 1, white background -> 0.
        bin_drawing = sketch_image_vis.copy()[:, :, 0]
        bin_drawing[bin_drawing == 0] = 1
        bin_drawing[bin_drawing == 255] = 0

        # load text and target_mask
        sen_instIdx_map = json_data[data_idx]['sen_instIdx_map']
        sen_instIdx_map_keys = list(sen_instIdx_map.keys())

        segm_data_npz_path = os.path.join(seg_data_base_dir, dataset_split,
                                          'seg_data',
                                          str(img_idx) + '_datas.npz')

        for inst_data_idx in range(len(sen_instIdx_map_keys)):
            caption = sen_instIdx_map_keys[inst_data_idx]
            inst_indices = sen_instIdx_map[caption]
            # Union of all referenced instances is the semantic GT mask;
            # the per-instance masks are kept separately for AP.
            target_mask = np.zeros((gt_masks.shape[0], gt_masks.shape[1]),
                                   dtype=np.int32)
            caption_gt_masks = np.zeros(
                (gt_masks.shape[0], gt_masks.shape[1], len(inst_indices)),
                dtype=np.int32)

            for t_i, inst_idx in enumerate(inst_indices):
                target_mask = np.logical_or(target_mask, gt_masks[:, :,
                                                                  inst_idx])
                caption_gt_masks[:, :, t_i] = gt_masks[:, :, inst_idx]

            # augment the caption with random attributes
            caption = text_processing.augment_the_caption_with_attr(caption)

            vocab_indices, seq_len = text_processing.preprocess_sentence(
                caption, vocab_dict, FLAGS.MAX_LEN)

            scores_val, up_val, sigm_val = sess.run(
                [model.pred, model.up, model.sigm],
                feed_dict={
                    model.words: np.expand_dims(vocab_indices,
                                                axis=0),  # [N, MAX_LEN]
                    model.sequence_lengths: [seq_len],  # [N]
                    model.im: np.expand_dims(sketch_image,
                                             axis=0),  # [N, H, W, 3]
                })

            up_val = np.squeeze(up_val)  # shape = [768, 768]
            pred_raw = (up_val >= score_thresh).astype(np.float32)  # 0.0/1.0
            predicts = im_processing.resize_and_crop(pred_raw,
                                                     target_mask.shape[0],
                                                     target_mask.shape[1])
            # Restrict the prediction to actual stroke pixels.
            predicts = predicts * bin_drawing

            if visualize:
                save_dir = os.path.join(FLAGS.visualize_pred_base_dir,
                                        dataset_split, str(img_idx))
                os.makedirs(save_dir, exist_ok=True)

                save_path_gt = os.path.join(save_dir,
                                            str(inst_data_idx) + '_gt.png')
                visualization_util.visualize_sem_seg(sketch_image_vis.copy(),
                                                     target_mask,
                                                     'GT: ' + caption,
                                                     save_path_gt)

                save_path = os.path.join(save_dir,
                                         str(inst_data_idx) + '_out.png')
                visualization_util.visualize_sem_seg(sketch_image_vis.copy(),
                                                     predicts,
                                                     'Pred: ' + caption,
                                                     save_path)

            I, U = eval_tools.compute_mask_IU(predicts.copy(), target_mask)
            cum_I += I
            cum_U += U
            msg = 'cumulative IoU = %f' % (cum_I / cum_U)
            for n_eval_iou in range(len(eval_seg_iou_list)):
                eval_seg_iou = eval_seg_iou_list[n_eval_iou]
                seg_correct[n_eval_iou] += (I / U >= eval_seg_iou)

            # Compute mask AP
            if cal_mask_AP:
                # get pred_instance_mask by segm_data and predicts
                pred_masks, pred_scores, _, _ = sketch_data_processing.get_pred_instance_mask(
                    segm_data_npz_path, predicts.copy())
                # print('caption_gt_masks', caption_gt_masks.shape)
                # print('pred_masks', pred_masks.shape)
                # print('pred_scores', pred_scores.shape, pred_scores)
                # visualization_util.visualize_inst_seg(sketch_image_vis.copy(), pred_masks,
                #                                       'Instance pred: ' + caption)

                if iou_threshold is None:
                    # BUGFIX: np.linspace requires an integer `num`;
                    # np.round returns a float (TypeError on modern NumPy).
                    iou_thresholds = np.linspace(.5,
                                                 0.95,
                                                 int(np.round(
                                                     (0.95 - .5) / .05)) + 1,
                                                 endpoint=True)
                    AP_list = np.zeros([len(iou_thresholds)], dtype=np.float32)
                    if pred_scores.shape[0] != 0:
                        for j in range(len(iou_thresholds)):
                            iouThr = iou_thresholds[j]
                            AP_single_iouThr, precisions, recalls, overlaps = \
                                eval_tools.compute_ap(caption_gt_masks, pred_scores, pred_masks,
                                                      iou_threshold=iouThr)
                            AP_list[j] = AP_single_iouThr

                    AP = AP_list
                    # print('iou_thresholds', str(iou_thresholds), ', AP', AP)
                else:
                    if pred_scores.shape[0] != 0:
                        AP, precisions, recalls, overlaps = \
                            eval_tools.compute_ap(caption_gt_masks, pred_scores, pred_masks,
                                                  iou_threshold=iou_threshold)
                    else:
                        AP = 0
                    # print('iou_threshold', str(iou_threshold), ', AP', AP)

                APs.append(AP)

            # print(msg)
            seg_total += 1

    # Print results
    result_str = '\n' + ckpt.model_checkpoint_path + '\n'
    result_str += 'Segmentation evaluation (without DenseCRF):\n'
    for n_eval_iou in range(len(eval_seg_iou_list)):
        result_str += 'precision@%s = %f\n' % \
                      (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] / seg_total)
    result_str += 'overall IoU = %f\n' % (cum_I / cum_U)
    print(ckpt.model_checkpoint_path)
    print(dataset_split, 'overall IoU = %f\n' % (cum_I / cum_U))

    if cal_mask_AP:
        mAP = np.mean(APs)
        mAP_list = np.mean(APs, axis=0)
        if iou_threshold is None:
            iou_str = '@[0.5:0.95]'
        else:
            iou_str = '@[' + str(iou_threshold) + ']'

        print("iou_threshold: ", iou_str, ", mAP = ", mAP)
        result_str += 'iou_threshold %s,  mAP = %s\n' % (iou_str, str(mAP))

        if iou_threshold is None:
            print("mAP_list: ", mAP_list)
            result_str += 'mAP_list = %s\n' % (str(mAP_list))

    # write validation result to txt (append, so repeated runs accumulate)
    write_path = os.path.join(
        FLAGS.eval_result_root,
        weights_name + '_' + FLAGS.model_name + '_iter_' +
        str(FLAGS.max_iteration) + '_' + dataset_split + '_result.txt')
    with open(write_path, 'a') as fp:
        fp.write(result_str)
def rmi_refvg_predictor(split='val',
                        eval_img_count=-1,
                        out_path='output/eval_refvg/rmi',
                        model_iter=750000,
                        dcrf=True,
                        mu=the_mu):
    """Run the trained RMI model over a RefVG split and collect predicted masks.

    Args:
        split: RefVG split name ('val', 'test', ...).
        eval_img_count: if > 0, appended to the output file name.
        out_path: directory for the saved .npy predictions (None to skip saving).
        model_iter: checkpoint iteration number to restore.
        dcrf: if True, refine each mask with DenseCRF post-processing.
        mu: per-channel mean subtracted from the (BGR) input image.

    Returns:
        dict mapping img_id -> {task_id: {'pred_mask': packed bit mask}}.
    """
    pretrained_model = './_rmi/refvg/tfmodel/refvg_resnet_RMI_iter_' + str(
        model_iter) + '.tfmodel'

    data_loader = RMIRefVGLoader(split=split)
    vocab_size = len(data_loader.vocab_dict)

    # Any response above ~0 counts as foreground.
    score_thresh = 1e-9
    H, W = 320, 320

    model = RMI_model(H=H,
                      W=W,
                      mode='eval',
                      vocab_size=vocab_size,
                      weights='resnet')

    # Load pretrained model
    snapshot_restorer = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    snapshot_restorer.restore(sess, pretrained_model)

    predictions = dict()

    while not data_loader.is_end:
        img_id, task_id, im, mask, sent, text = data_loader.get_img_data(
            rand=False, is_train=False)
        mask = mask.astype(np.float32)

        # Resize/pad to network resolution, convert RGB->BGR, subtract mean.
        proc_im = skimage.img_as_ubyte(im_processing.resize_and_pad(im, H, W))
        proc_im_ = proc_im.astype(np.float32)
        proc_im_ = proc_im_[:, :, ::-1]
        proc_im_ -= mu

        scores_val, up_val, sigm_val = sess.run(
            [model.pred, model.up, model.sigm],
            feed_dict={
                model.words: np.expand_dims(text, axis=0),
                model.im: np.expand_dims(proc_im_, axis=0)
            })

        # scores_val = np.squeeze(scores_val)
        # pred_raw = (scores_val >= score_thresh).astype(np.float32)
        up_val = np.squeeze(up_val)
        pred_raw = (up_val >= score_thresh).astype(np.float32)
        predicts = im_processing.resize_and_crop(pred_raw, mask.shape[0],
                                                 mask.shape[1])
        pred_mask = predicts
        if dcrf:
            # Dense CRF post-processing on the sigmoid probabilities.
            sigm_val = np.squeeze(sigm_val)
            d = densecrf.DenseCRF2D(W, H, 2)
            U = np.expand_dims(-np.log(sigm_val), axis=0)
            U_ = np.expand_dims(-np.log(1 - sigm_val), axis=0)
            unary = np.concatenate((U_, U), axis=0)
            unary = unary.reshape((2, -1))
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=3, compat=3)
            d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=proc_im, compat=10)
            Q = d.inference(5)
            pred_raw_dcrf = np.argmax(Q, axis=0).reshape(
                (H, W)).astype(np.float32)
            predicts_dcrf = im_processing.resize_and_crop(
                pred_raw_dcrf, mask.shape[0], mask.shape[1])
            pred_mask = predicts_dcrf

        if img_id not in predictions:
            predictions[img_id] = dict()
        # Pack the binary mask to bits to keep the output file small.
        # BUGFIX: np.bool was removed in NumPy 1.24; use the builtin bool.
        pred_mask = np.packbits(pred_mask.astype(bool))
        predictions[img_id][task_id] = {'pred_mask': pred_mask}
        # BUGFIX: was a Python 2 print statement (SyntaxError on Python 3).
        print(data_loader.img_idx, img_id, task_id)

    if out_path is not None:
        print('rmi_refvg_predictor: saving predictions to %s ...' % out_path)
        os.makedirs(out_path, exist_ok=True)
        fname = split
        if eval_img_count > 0:
            fname += '_%d' % eval_img_count
        fname += '.npy'
        f_path = os.path.join(out_path, fname)
        np.save(f_path, predictions)
    print('RMI refvg predictor done!')
    return predictions
# Beispiel #5
# 0
def test(modelname, iter, dataset, visualize, weights, setname, dcrf, mu):
    """Evaluate a trained LSTM/RMI referring-segmentation model on a split.

    Prints cumulative IoU and precision@{0.5..0.9}; when `dcrf` is set the
    DenseCRF-refined metrics are reported as well.

    NOTE(review): the `iter` parameter shadows the builtin; kept as-is
    because renaming would break keyword callers.
    """
    data_folder = './' + dataset + '/' + setname + '_batch/'
    data_prefix = dataset + '_' + setname
    if visualize:
        save_dir = './' + dataset + '/visualization/' + modelname + '_' + str(iter) + '/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    pretrained_model = ('./' + dataset + '/tfmodel/' + dataset + '_' + weights
                        + '_' + modelname + '_iter_' + str(iter) + '.tfmodel')

    # Any response above ~0 counts as foreground.
    score_thresh = 1e-9
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    cum_I = 0
    cum_U = 0
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    if dcrf:
        cum_I_dcrf = 0
        cum_U_dcrf = 0
        seg_correct_dcrf = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0.
    H, W = 320, 320
    if dataset == 'referit':
        vocab_size = 8803
    else:
        vocab_size = 12112

    if modelname == 'LSTM':
        model = LSTM_model(H=H, W=W, mode='eval', vocab_size=vocab_size,
                           weights=weights)
    elif modelname == 'RMI':
        model = RMI_model(H=H, W=W, mode='eval', vocab_size=vocab_size,
                          weights=weights)
    else:
        raise ValueError('Unknown model name %s' % (modelname))

    # Restore the trained weights and open the (unshuffled) batch reader.
    snapshot_restorer = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    snapshot_restorer.restore(sess, pretrained_model)
    reader = data_reader.DataReader(data_folder, data_prefix, shuffle=False)

    for _ in range(reader.num_batch):
        batch = reader.read_batch()
        text = batch['text_batch']
        im = batch['im_batch']
        mask = batch['mask_batch'].astype(np.float32)

        # Resize/pad to network resolution, RGB->BGR, subtract channel mean.
        proc_im = skimage.img_as_ubyte(im_processing.resize_and_pad(im, H, W))
        net_input = proc_im.astype(np.float32)[:, :, ::-1] - mu

        feed = {
            model.words: np.expand_dims(text, axis=0),
            model.im: np.expand_dims(net_input, axis=0),
        }
        scores_val, up_val, sigm_val = sess.run(
            [model.pred, model.up, model.sigm], feed_dict=feed)

        # scores_val = np.squeeze(scores_val)
        # pred_raw = (scores_val >= score_thresh).astype(np.float32)
        up_val = np.squeeze(up_val)
        pred_raw = (up_val >= score_thresh).astype(np.float32)
        predicts = im_processing.resize_and_crop(pred_raw, mask.shape[0],
                                                 mask.shape[1])
        if dcrf:
            # DenseCRF refinement on the sigmoid probabilities.
            sigm_val = np.squeeze(sigm_val)
            d = densecrf.DenseCRF2D(W, H, 2)
            fg_cost = np.expand_dims(-np.log(sigm_val), axis=0)
            bg_cost = np.expand_dims(-np.log(1 - sigm_val), axis=0)
            unary = np.concatenate((bg_cost, fg_cost), axis=0).reshape((2, -1))
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=3, compat=3)
            d.addPairwiseBilateral(sxy=20, srgb=3, rgbim=proc_im, compat=10)
            Q = d.inference(5)
            pred_raw_dcrf = np.argmax(Q, axis=0).reshape((H, W)).astype(np.float32)
            predicts_dcrf = im_processing.resize_and_crop(pred_raw_dcrf,
                                                          mask.shape[0],
                                                          mask.shape[1])

        if visualize:
            sent = batch['sent_batch'][0]
            visualize_seg(im, predicts, sent)
            if dcrf:
                visualize_seg(im, predicts_dcrf, sent)

        # Accumulate IoU statistics for this sample.
        I, U = eval_tools.compute_mask_IU(predicts, mask)
        cum_I += I
        cum_U += U
        msg = 'cumulative IoU = %f' % (cum_I / cum_U)
        for idx, thr in enumerate(eval_seg_iou_list):
            seg_correct[idx] += (I / U >= thr)
        if dcrf:
            I_dcrf, U_dcrf = eval_tools.compute_mask_IU(predicts_dcrf, mask)
            cum_I_dcrf += I_dcrf
            cum_U_dcrf += U_dcrf
            msg += '\tcumulative IoU (dcrf) = %f' % (cum_I_dcrf / cum_U_dcrf)
            for idx, thr in enumerate(eval_seg_iou_list):
                seg_correct_dcrf[idx] += (I_dcrf / U_dcrf >= thr)
        print(msg)
        seg_total += 1

    # Print results
    print('Segmentation evaluation (without DenseCRF):')
    result_str = ''
    for idx, thr in enumerate(eval_seg_iou_list):
        result_str += 'precision@%s = %f\n' % \
            (str(thr), seg_correct[idx] / seg_total)
    result_str += 'overall IoU = %f\n' % (cum_I / cum_U)
    print(result_str)
    if dcrf:
        print('Segmentation evaluation (with DenseCRF):')
        result_str = ''
        for idx, thr in enumerate(eval_seg_iou_list):
            result_str += 'precision@%s = %f\n' % \
                (str(thr), seg_correct_dcrf[idx] / seg_total)
        result_str += 'overall IoU = %f\n' % (cum_I_dcrf / cum_U_dcrf)
        print(result_str)
# Beispiel #6
# 0
def train(modelname, max_iter, snapshot, dataset, weights, setname, mu):
    """Train an LSTM/RMI referring-segmentation model from a pretrained backbone.

    Args:
        modelname: 'LSTM' or 'RMI'.
        max_iter: total number of training iterations.
        snapshot: save a checkpoint every `snapshot` iterations.
        dataset: dataset name ('referit' selects the smaller vocabulary).
        weights: backbone initialization, 'resnet' or 'deeplab'.
        setname: split name used to locate the '<dataset>_<setname>_batch' folder.
        mu: per-channel mean subtracted from the (BGR) input image.

    Raises:
        ValueError: if `modelname` or `weights` is not a supported value.
    """
    data_folder = './' + dataset + '/' + setname + '_batch/'
    data_prefix = dataset + '_' + setname
    tfmodel_folder = './' + dataset + '/tfmodel/'
    snapshot_file = tfmodel_folder + dataset + '_' + weights + '_' + modelname + '_iter_%d.tfmodel'
    if not os.path.isdir(tfmodel_folder):
        os.makedirs(tfmodel_folder)

    # Exponential moving averages of loss / accuracy for logging.
    cls_loss_avg = 0
    avg_accuracy_all, avg_accuracy_pos, avg_accuracy_neg = 0, 0, 0
    decay = 0.99
    vocab_size = 8803 if dataset == 'referit' else 12112

    if modelname == 'LSTM':
        model = LSTM_model(mode='train', vocab_size=vocab_size, weights=weights)
    elif modelname == 'RMI':
        model = RMI_model(mode='train', vocab_size=vocab_size, weights=weights)
    else:
        raise ValueError('Unknown model name %s' % (modelname))

    # Pick the pretrained backbone checkpoint and the variables to restore.
    if weights == 'resnet':
        pretrained_model = './external/TF-resnet/model/ResNet101_init.tfmodel'
        load_var = {var.op.name: var for var in tf.global_variables()
                    if var.op.name.startswith('ResNet')}
    elif weights == 'deeplab':
        pretrained_model = './external/TF-deeplab/model/ResNet101_train.tfmodel'
        load_var = {var.op.name: var for var in tf.global_variables()
                    if var.op.name.startswith('DeepLab/group')}
    else:
        # BUGFIX: previously fell through and crashed with NameError on
        # `load_var`; fail fast with a clear message instead.
        raise ValueError('Unknown weights %s' % (weights))

    snapshot_loader = tf.train.Saver(load_var)
    snapshot_saver = tf.train.Saver(max_to_keep=1000)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    snapshot_loader.restore(sess, pretrained_model)

    reader = data_reader.DataReader(data_folder, data_prefix)
    for n_iter in range(max_iter):

        batch = reader.read_batch()
        text = batch['text_batch']
        im = batch['im_batch'].astype(np.float32)
        mask = np.expand_dims(batch['mask_batch'].astype(np.float32), axis=2)

        # RGB -> BGR and mean subtraction, matching the backbone's training.
        im = im[:, :, ::-1]
        im -= mu

        _, cls_loss_val, lr_val, scores_val, label_val = sess.run(
            [model.train_step,
             model.cls_loss,
             model.learning_rate,
             model.pred,
             model.target],
            feed_dict={
                model.words: np.expand_dims(text, axis=0),
                model.im: np.expand_dims(im, axis=0),
                model.target_fine: np.expand_dims(mask, axis=0)
            })
        cls_loss_avg = decay * cls_loss_avg + (1 - decay) * cls_loss_val
        print('iter = %d, loss (cur) = %f, loss (avg) = %f, lr = %f' % (n_iter, cls_loss_val, cls_loss_avg, lr_val))

        # Accuracy
        accuracy_all, accuracy_pos, accuracy_neg = compute_accuracy(scores_val, label_val)
        avg_accuracy_all = decay * avg_accuracy_all + (1 - decay) * accuracy_all
        avg_accuracy_pos = decay * avg_accuracy_pos + (1 - decay) * accuracy_pos
        avg_accuracy_neg = decay * avg_accuracy_neg + (1 - decay) * accuracy_neg
        print('iter = %d, accuracy (cur) = %f (all), %f (pos), %f (neg)'
              % (n_iter, accuracy_all, accuracy_pos, accuracy_neg))
        print('iter = %d, accuracy (avg) = %f (all), %f (pos), %f (neg)'
              % (n_iter, avg_accuracy_all, avg_accuracy_pos, avg_accuracy_neg))

        # Save snapshot
        if (n_iter + 1) % snapshot == 0 or (n_iter + 1) >= max_iter:
            snapshot_saver.save(sess, snapshot_file % (n_iter + 1))
            print('snapshot saved to ' + snapshot_file % (n_iter + 1))

    print('Optimization done.')